edsl 0.1.38.dev1__py3-none-any.whl → 0.1.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +3 -3
- edsl/BaseDiff.py +7 -7
- edsl/__init__.py +2 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +17 -14
- edsl/agents/AgentList.py +29 -17
- edsl/auto/SurveyCreatorPipeline.py +1 -1
- edsl/auto/utilities.py +1 -1
- edsl/base/Base.py +3 -13
- edsl/coop/coop.py +3 -0
- edsl/data/Cache.py +18 -15
- edsl/exceptions/agents.py +4 -0
- edsl/exceptions/cache.py +5 -0
- edsl/jobs/Jobs.py +22 -11
- edsl/jobs/buckets/TokenBucket.py +3 -0
- edsl/jobs/interviews/Interview.py +18 -18
- edsl/jobs/runners/JobsRunnerAsyncio.py +38 -15
- edsl/jobs/runners/JobsRunnerStatus.py +196 -196
- edsl/jobs/tasks/TaskHistory.py +12 -3
- edsl/language_models/LanguageModel.py +9 -7
- edsl/language_models/ModelList.py +20 -13
- edsl/notebooks/Notebook.py +7 -8
- edsl/questions/QuestionBase.py +21 -17
- edsl/questions/QuestionBaseGenMixin.py +1 -1
- edsl/questions/QuestionBasePromptsMixin.py +0 -17
- edsl/questions/QuestionFunctional.py +10 -3
- edsl/questions/derived/QuestionTopK.py +2 -0
- edsl/results/Result.py +31 -25
- edsl/results/Results.py +22 -22
- edsl/scenarios/Scenario.py +12 -14
- edsl/scenarios/ScenarioList.py +16 -16
- edsl/surveys/MemoryPlan.py +1 -1
- edsl/surveys/Rule.py +1 -5
- edsl/surveys/RuleCollection.py +1 -1
- edsl/surveys/Survey.py +9 -17
- edsl/surveys/instructions/ChangeInstruction.py +9 -7
- edsl/surveys/instructions/Instruction.py +9 -7
- edsl/{conjure → utilities}/naming_utilities.py +1 -1
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev2.dist-info}/METADATA +1 -1
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev2.dist-info}/RECORD +42 -56
- edsl/conjure/AgentConstructionMixin.py +0 -160
- edsl/conjure/Conjure.py +0 -62
- edsl/conjure/InputData.py +0 -659
- edsl/conjure/InputDataCSV.py +0 -48
- edsl/conjure/InputDataMixinQuestionStats.py +0 -182
- edsl/conjure/InputDataPyRead.py +0 -91
- edsl/conjure/InputDataSPSS.py +0 -8
- edsl/conjure/InputDataStata.py +0 -8
- edsl/conjure/QuestionOptionMixin.py +0 -76
- edsl/conjure/QuestionTypeMixin.py +0 -23
- edsl/conjure/RawQuestion.py +0 -65
- edsl/conjure/SurveyResponses.py +0 -7
- edsl/conjure/__init__.py +0 -9
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/utilities.py +0 -201
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev2.dist-info}/WHEEL +0 -0
edsl/Base.py
CHANGED
@@ -256,10 +256,10 @@ class Base(
|
|
256
256
|
|
257
257
|
if not isinstance(other, self.__class__):
|
258
258
|
return False
|
259
|
-
if "sort" in inspect.signature(self.
|
260
|
-
return self.
|
259
|
+
if "sort" in inspect.signature(self.to_dict).parameters:
|
260
|
+
return self.to_dict(sort=True) == other.to_dict(sort=True)
|
261
261
|
else:
|
262
|
-
return self.
|
262
|
+
return self.to_dict() == other.to_dict()
|
263
263
|
|
264
264
|
@abstractmethod
|
265
265
|
def example():
|
edsl/BaseDiff.py
CHANGED
@@ -25,7 +25,7 @@ class DummyObject:
|
|
25
25
|
def __init__(self, object_dict):
|
26
26
|
self.object_dict = object_dict
|
27
27
|
|
28
|
-
def
|
28
|
+
def to_dict(self):
|
29
29
|
return self.object_dict
|
30
30
|
|
31
31
|
|
@@ -38,12 +38,12 @@ class BaseDiff:
|
|
38
38
|
self.obj1 = obj1
|
39
39
|
self.obj2 = obj2
|
40
40
|
|
41
|
-
if "sort" in inspect.signature(obj1.
|
42
|
-
self._dict1 = obj1.
|
43
|
-
self._dict2 = obj2.
|
41
|
+
if "sort" in inspect.signature(obj1.to_dict).parameters:
|
42
|
+
self._dict1 = obj1.to_dict(sort=True)
|
43
|
+
self._dict2 = obj2.to_dict(sort=True)
|
44
44
|
else:
|
45
|
-
self._dict1 = obj1.
|
46
|
-
self._dict2 = obj2.
|
45
|
+
self._dict1 = obj1.to_dict()
|
46
|
+
self._dict2 = obj2.to_dict()
|
47
47
|
self._obj_class = type(obj1)
|
48
48
|
|
49
49
|
self.added = added
|
@@ -139,7 +139,7 @@ class BaseDiff:
|
|
139
139
|
def apply(self, obj: Any):
|
140
140
|
"""Apply the diff to the object."""
|
141
141
|
|
142
|
-
new_obj_dict = obj.
|
142
|
+
new_obj_dict = obj.to_dict()
|
143
143
|
for k, v in self.added.items():
|
144
144
|
new_obj_dict[k] = v
|
145
145
|
for k in self.removed.keys():
|
edsl/__init__.py
CHANGED
@@ -41,7 +41,8 @@ from edsl.shared import shared_globals
|
|
41
41
|
from edsl.jobs.Jobs import Jobs
|
42
42
|
from edsl.notebooks.Notebook import Notebook
|
43
43
|
from edsl.study.Study import Study
|
44
|
-
|
44
|
+
|
45
|
+
# from edsl.conjure.Conjure import Conjure
|
45
46
|
from edsl.coop.coop import Coop
|
46
47
|
|
47
48
|
from edsl.surveys.instructions.Instruction import Instruction
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.38.
|
1
|
+
__version__ = "0.1.38.dev2"
|
edsl/agents/Agent.py
CHANGED
@@ -669,9 +669,9 @@ class Agent(Base):
|
|
669
669
|
if dynamic_traits_func:
|
670
670
|
func = inspect.getsource(dynamic_traits_func)
|
671
671
|
raw_data["dynamic_traits_function_source_code"] = func
|
672
|
-
raw_data[
|
673
|
-
|
674
|
-
|
672
|
+
raw_data["dynamic_traits_function_name"] = (
|
673
|
+
self.dynamic_traits_function_name
|
674
|
+
)
|
675
675
|
if hasattr(self, "answer_question_directly"):
|
676
676
|
raw_data.pop(
|
677
677
|
"answer_question_directly", None
|
@@ -685,23 +685,19 @@ class Agent(Base):
|
|
685
685
|
raw_data["answer_question_directly_source_code"] = inspect.getsource(
|
686
686
|
answer_question_directly_func
|
687
687
|
)
|
688
|
-
raw_data[
|
689
|
-
|
690
|
-
|
688
|
+
raw_data["answer_question_directly_function_name"] = (
|
689
|
+
self.answer_question_directly_function_name
|
690
|
+
)
|
691
691
|
|
692
692
|
return raw_data
|
693
693
|
|
694
694
|
def __hash__(self) -> int:
|
695
695
|
from edsl.utilities.utilities import dict_hash
|
696
696
|
|
697
|
-
return dict_hash(self.
|
698
|
-
|
699
|
-
def _to_dict(self) -> dict[str, Union[dict, bool]]:
|
700
|
-
"""Serialize to a dictionary without EDSL info"""
|
701
|
-
return self.data
|
697
|
+
return dict_hash(self.to_dict(add_edsl_version=False))
|
702
698
|
|
703
|
-
@add_edsl_version
|
704
|
-
def to_dict(self) -> dict[str, Union[dict, bool]]:
|
699
|
+
# @add_edsl_version
|
700
|
+
def to_dict(self, add_edsl_version=True) -> dict[str, Union[dict, bool]]:
|
705
701
|
"""Serialize to a dictionary with EDSL info.
|
706
702
|
|
707
703
|
Example usage:
|
@@ -710,7 +706,14 @@ class Agent(Base):
|
|
710
706
|
>>> a.to_dict()
|
711
707
|
{'name': 'Steve', 'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}, 'edsl_version': '...', 'edsl_class_name': 'Agent'}
|
712
708
|
"""
|
713
|
-
|
709
|
+
d = copy.deepcopy(self.data)
|
710
|
+
if add_edsl_version:
|
711
|
+
from edsl import __version__
|
712
|
+
|
713
|
+
d["edsl_version"] = __version__
|
714
|
+
d["edsl_class_name"] = self.__class__.__name__
|
715
|
+
|
716
|
+
return d
|
714
717
|
|
715
718
|
@classmethod
|
716
719
|
@remove_edsl_version
|
edsl/agents/AgentList.py
CHANGED
@@ -14,7 +14,7 @@ from __future__ import annotations
|
|
14
14
|
import csv
|
15
15
|
import json
|
16
16
|
from collections import UserList
|
17
|
-
from typing import Any, List, Optional, Union
|
17
|
+
from typing import Any, List, Optional, Union, TYPE_CHECKING
|
18
18
|
from rich import print_json
|
19
19
|
from rich.table import Table
|
20
20
|
from simpleeval import EvalWithCompoundTypes
|
@@ -23,6 +23,11 @@ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
|
23
23
|
|
24
24
|
from collections.abc import Iterable
|
25
25
|
|
26
|
+
from edsl.exceptions.agents import AgentListError
|
27
|
+
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
30
|
+
|
26
31
|
|
27
32
|
def is_iterable(obj):
|
28
33
|
return isinstance(obj, Iterable)
|
@@ -113,7 +118,7 @@ class AgentList(UserList, Base):
|
|
113
118
|
]
|
114
119
|
except Exception as e:
|
115
120
|
print(f"Exception:{e}")
|
116
|
-
raise
|
121
|
+
raise AgentListError(f"Error in filter. Exception:{e}")
|
117
122
|
|
118
123
|
return AgentList(new_data)
|
119
124
|
|
@@ -199,7 +204,8 @@ class AgentList(UserList, Base):
|
|
199
204
|
>>> al.add_trait('new_trait', [1, 2, 3])
|
200
205
|
Traceback (most recent call last):
|
201
206
|
...
|
202
|
-
|
207
|
+
edsl.exceptions.agents.AgentListError: The passed values have to be the same length as the agent list.
|
208
|
+
...
|
203
209
|
"""
|
204
210
|
if not is_iterable(values):
|
205
211
|
value = values
|
@@ -208,7 +214,7 @@ class AgentList(UserList, Base):
|
|
208
214
|
return self
|
209
215
|
|
210
216
|
if len(values) != len(self):
|
211
|
-
raise
|
217
|
+
raise AgentListError(
|
212
218
|
"The passed values have to be the same length as the agent list."
|
213
219
|
)
|
214
220
|
for agent, value in zip(self.data, values):
|
@@ -228,33 +234,39 @@ class AgentList(UserList, Base):
|
|
228
234
|
def __hash__(self) -> int:
|
229
235
|
from edsl.utilities.utilities import dict_hash
|
230
236
|
|
231
|
-
|
232
|
-
# data['agent_list'] = sorted(data['agent_list'], key=lambda x: dict_hash(x)
|
233
|
-
return dict_hash(self._to_dict(sorted=True))
|
237
|
+
return dict_hash(self.to_dict(add_edsl_version=False, sorted=True))
|
234
238
|
|
235
|
-
def
|
239
|
+
def to_dict(self, sorted=False, add_edsl_version=True):
|
236
240
|
if sorted:
|
237
241
|
data = self.data[:]
|
238
242
|
data.sort(key=lambda x: hash(x))
|
239
243
|
else:
|
240
244
|
data = self.data
|
241
245
|
|
242
|
-
|
246
|
+
d = {
|
247
|
+
"agent_list": [
|
248
|
+
agent.to_dict(add_edsl_version=add_edsl_version) for agent in data
|
249
|
+
]
|
250
|
+
}
|
251
|
+
if add_edsl_version:
|
252
|
+
from edsl import __version__
|
243
253
|
|
244
|
-
|
245
|
-
|
254
|
+
d["edsl_version"] = __version__
|
255
|
+
d["edsl_class_name"] = "AgentList"
|
256
|
+
|
257
|
+
return d
|
246
258
|
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
259
|
+
def __eq__(self, other: AgentList) -> bool:
|
260
|
+
return self.to_dict(sorted=True, add_edsl_version=False) == other.to_dict(
|
261
|
+
sorted=True, add_edsl_version=False
|
262
|
+
)
|
251
263
|
|
252
264
|
def __repr__(self):
|
253
265
|
return f"AgentList({self.data})"
|
254
266
|
|
255
267
|
def print(self, format: Optional[str] = None):
|
256
268
|
"""Print the AgentList."""
|
257
|
-
print_json(json.dumps(self.
|
269
|
+
print_json(json.dumps(self.to_dict(add_edsl_version=False)))
|
258
270
|
|
259
271
|
def _repr_html_(self):
|
260
272
|
"""Return an HTML representation of the AgentList."""
|
@@ -262,7 +274,7 @@ class AgentList(UserList, Base):
|
|
262
274
|
|
263
275
|
return data_to_html(self.to_dict()["agent_list"])
|
264
276
|
|
265
|
-
def to_scenario_list(self) ->
|
277
|
+
def to_scenario_list(self) -> ScenarioList:
|
266
278
|
"""Return a list of scenarios."""
|
267
279
|
from edsl.scenarios.ScenarioList import ScenarioList
|
268
280
|
from edsl.scenarios.Scenario import Scenario
|
edsl/auto/SurveyCreatorPipeline.py
CHANGED
@@ -15,7 +15,7 @@ from edsl.surveys.Survey import Survey
|
|
15
15
|
from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
|
16
16
|
from edsl.questions.QuestionFreeText import QuestionFreeText
|
17
17
|
from edsl.auto.utilities import gen_pipeline
|
18
|
-
from edsl.
|
18
|
+
from edsl.utilities.naming_utilities import sanitize_string
|
19
19
|
|
20
20
|
|
21
21
|
m = Model()
|
edsl/auto/utilities.py
CHANGED
@@ -2,7 +2,7 @@ from textwrap import dedent
|
|
2
2
|
import random
|
3
3
|
from typing import List, TypeVar, Generator, Optional
|
4
4
|
from edsl.auto.StageBase import StageBase
|
5
|
-
from edsl.
|
5
|
+
from edsl.utilities.naming_utilities import sanitize_string
|
6
6
|
from edsl import Agent, Survey, Model, Cache, AgentList
|
7
7
|
from edsl import QuestionFreeText, Scenario
|
8
8
|
from edsl import QuestionMultipleChoice, Scenario, Agent, ScenarioList
|
edsl/base/Base.py
CHANGED
@@ -229,26 +229,16 @@ class Base(
|
|
229
229
|
|
230
230
|
return data_to_html(self.to_dict())
|
231
231
|
|
232
|
-
# def html(self):
|
233
|
-
# html_string = self._repr_html_()
|
234
|
-
# import tempfile
|
235
|
-
# import webbrowser
|
236
|
-
|
237
|
-
# with tempfile.NamedTemporaryFile("w", delete=False, suffix=".html") as f:
|
238
|
-
# # print("Writing HTML to", f.name)
|
239
|
-
# f.write(html_string)
|
240
|
-
# webbrowser.open(f.name)
|
241
|
-
|
242
232
|
def __eq__(self, other):
|
243
233
|
"""Return whether two objects are equal."""
|
244
234
|
import inspect
|
245
235
|
|
246
236
|
if not isinstance(other, self.__class__):
|
247
237
|
return False
|
248
|
-
if "sort" in inspect.signature(self.
|
249
|
-
return self.
|
238
|
+
if "sort" in inspect.signature(self.to_dict).parameters:
|
239
|
+
return self.to_dict(sort=True) == other.to_dict(sort=True)
|
250
240
|
else:
|
251
|
-
return self.
|
241
|
+
return self.to_dict() == other.to_dict()
|
252
242
|
|
253
243
|
@abstractmethod
|
254
244
|
def example():
|
edsl/coop/coop.py
CHANGED
edsl/data/Cache.py
CHANGED
@@ -11,7 +11,8 @@ from typing import Optional, Union
|
|
11
11
|
from edsl.Base import Base
|
12
12
|
from edsl.data.CacheEntry import CacheEntry
|
13
13
|
from edsl.utilities.utilities import dict_hash
|
14
|
-
from edsl.utilities.decorators import
|
14
|
+
from edsl.utilities.decorators import remove_edsl_version
|
15
|
+
from edsl.exceptions.cache import CacheError
|
15
16
|
|
16
17
|
|
17
18
|
class Cache(Base):
|
@@ -58,7 +59,7 @@ class Cache(Base):
|
|
58
59
|
|
59
60
|
self.filename = filename
|
60
61
|
if filename and data:
|
61
|
-
raise
|
62
|
+
raise CacheError("Cannot provide both filename and data")
|
62
63
|
if filename is None and data is None:
|
63
64
|
data = {}
|
64
65
|
if data is not None:
|
@@ -76,7 +77,7 @@ class Cache(Base):
|
|
76
77
|
if os.path.exists(filename):
|
77
78
|
self.add_from_sqlite(filename)
|
78
79
|
else:
|
79
|
-
raise
|
80
|
+
raise CacheError("Invalid file extension. Must be .jsonl or .db")
|
80
81
|
|
81
82
|
self._perform_checks()
|
82
83
|
|
@@ -116,7 +117,7 @@ class Cache(Base):
|
|
116
117
|
from edsl.data.CacheEntry import CacheEntry
|
117
118
|
|
118
119
|
if any(not isinstance(value, CacheEntry) for value in self.data.values()):
|
119
|
-
raise
|
120
|
+
raise CacheError("Not all values are CacheEntry instances")
|
120
121
|
if self.method is not None:
|
121
122
|
warnings.warn("Argument `method` is deprecated", DeprecationWarning)
|
122
123
|
|
@@ -227,9 +228,9 @@ class Cache(Base):
|
|
227
228
|
for key, value in new_data.items():
|
228
229
|
if key in self.data:
|
229
230
|
if value != self.data[key]:
|
230
|
-
raise
|
231
|
+
raise CacheError("Mismatch in values")
|
231
232
|
if not isinstance(value, CacheEntry):
|
232
|
-
raise
|
233
|
+
raise CacheError(f"Wrong type - the observed type is {type(value)}")
|
233
234
|
|
234
235
|
self.new_entries.update(new_data)
|
235
236
|
if write_now:
|
@@ -338,7 +339,7 @@ class Cache(Base):
|
|
338
339
|
elif filename.endswith(".db"):
|
339
340
|
self.write_sqlite_db(filename)
|
340
341
|
else:
|
341
|
-
raise
|
342
|
+
raise CacheError("Invalid file extension. Must be .jsonl or .db")
|
342
343
|
|
343
344
|
def write_jsonl(self, filename: str) -> None:
|
344
345
|
"""
|
@@ -396,15 +397,17 @@ class Cache(Base):
|
|
396
397
|
####################
|
397
398
|
def __hash__(self):
|
398
399
|
"""Return the hash of the Cache."""
|
399
|
-
return dict_hash(self.
|
400
|
+
return dict_hash(self.to_dict(add_edsl_version=False))
|
400
401
|
|
401
|
-
def
|
402
|
-
|
402
|
+
def to_dict(self, add_edsl_version=True) -> dict:
|
403
|
+
d = {k: v.to_dict() for k, v in self.data.items()}
|
404
|
+
if add_edsl_version:
|
405
|
+
from edsl import __version__
|
406
|
+
|
407
|
+
d["edsl_version"] = __version__
|
408
|
+
d["edsl_class_name"] = "Cache"
|
403
409
|
|
404
|
-
|
405
|
-
def to_dict(self) -> dict:
|
406
|
-
"""Return the Cache as a dictionary."""
|
407
|
-
return self._to_dict()
|
410
|
+
return d
|
408
411
|
|
409
412
|
def _repr_html_(self):
|
410
413
|
from edsl.utilities.utilities import data_to_html
|
@@ -438,7 +441,7 @@ class Cache(Base):
|
|
438
441
|
Combine two caches.
|
439
442
|
"""
|
440
443
|
if not isinstance(other, Cache):
|
441
|
-
raise
|
444
|
+
raise CacheError("Can only add two caches together")
|
442
445
|
self.data.update(other.data)
|
443
446
|
return self
|
444
447
|
|
edsl/exceptions/agents.py
CHANGED
@@ -1,6 +1,10 @@
|
|
1
1
|
from edsl.exceptions.BaseException import BaseException
|
2
2
|
|
3
3
|
|
4
|
+
class AgentListError(BaseException):
|
5
|
+
relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html#agent-lists"
|
6
|
+
|
7
|
+
|
4
8
|
class AgentErrors(BaseException):
|
5
9
|
relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html"
|
6
10
|
|
edsl/exceptions/cache.py
ADDED
edsl/jobs/Jobs.py
CHANGED
@@ -641,7 +641,7 @@ class Jobs(Base):
|
|
641
641
|
"""
|
642
642
|
from edsl.utilities.utilities import dict_hash
|
643
643
|
|
644
|
-
return dict_hash(self.
|
644
|
+
return dict_hash(self.to_dict(add_edsl_version=False))
|
645
645
|
|
646
646
|
def _output(self, message) -> None:
|
647
647
|
"""Check if a Job is verbose. If so, print the message."""
|
@@ -1188,18 +1188,29 @@ class Jobs(Base):
|
|
1188
1188
|
# Serialization methods
|
1189
1189
|
#######################
|
1190
1190
|
|
1191
|
-
def
|
1192
|
-
|
1193
|
-
"survey": self.survey.
|
1194
|
-
"agents": [
|
1195
|
-
|
1196
|
-
|
1191
|
+
def to_dict(self, add_edsl_version=True):
|
1192
|
+
d = {
|
1193
|
+
"survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
|
1194
|
+
"agents": [
|
1195
|
+
agent.to_dict(add_edsl_version=add_edsl_version)
|
1196
|
+
for agent in self.agents
|
1197
|
+
],
|
1198
|
+
"models": [
|
1199
|
+
model.to_dict(add_edsl_version=add_edsl_version)
|
1200
|
+
for model in self.models
|
1201
|
+
],
|
1202
|
+
"scenarios": [
|
1203
|
+
scenario.to_dict(add_edsl_version=add_edsl_version)
|
1204
|
+
for scenario in self.scenarios
|
1205
|
+
],
|
1197
1206
|
}
|
1207
|
+
if add_edsl_version:
|
1208
|
+
from edsl import __version__
|
1209
|
+
|
1210
|
+
d["edsl_version"] = __version__
|
1211
|
+
d["edsl_class_name"] = "Jobs"
|
1198
1212
|
|
1199
|
-
|
1200
|
-
def to_dict(self) -> dict:
|
1201
|
-
"""Convert the Jobs instance to a dictionary."""
|
1202
|
-
return self._to_dict()
|
1213
|
+
return d
|
1203
1214
|
|
1204
1215
|
@classmethod
|
1205
1216
|
@remove_edsl_version
|
edsl/jobs/buckets/TokenBucket.py
CHANGED
edsl/jobs/interviews/Interview.py
CHANGED
@@ -110,9 +110,9 @@ class Interview:
|
|
110
110
|
self.debug = debug
|
111
111
|
self.iteration = iteration
|
112
112
|
self.cache = cache
|
113
|
-
self.answers: dict[
|
114
|
-
|
115
|
-
|
113
|
+
self.answers: dict[str, str] = (
|
114
|
+
Answers()
|
115
|
+
) # will get filled in as interview progresses
|
116
116
|
self.sidecar_model = sidecar_model
|
117
117
|
|
118
118
|
# Trackers
|
@@ -143,9 +143,9 @@ class Interview:
|
|
143
143
|
The keys are the question names; the values are the lists of status log changes for each task.
|
144
144
|
"""
|
145
145
|
for task_creator in self.task_creators.values():
|
146
|
-
self._task_status_log_dict[
|
147
|
-
task_creator.
|
148
|
-
|
146
|
+
self._task_status_log_dict[task_creator.question.question_name] = (
|
147
|
+
task_creator.status_log
|
148
|
+
)
|
149
149
|
return self._task_status_log_dict
|
150
150
|
|
151
151
|
@property
|
@@ -159,7 +159,7 @@ class Interview:
|
|
159
159
|
return self.task_creators.interview_status
|
160
160
|
|
161
161
|
# region: Serialization
|
162
|
-
def
|
162
|
+
def to_dict(self, include_exceptions=True, add_edsl_version=True) -> dict[str, Any]:
|
163
163
|
"""Return a dictionary representation of the Interview instance.
|
164
164
|
This is just for hashing purposes.
|
165
165
|
|
@@ -168,10 +168,10 @@ class Interview:
|
|
168
168
|
1217840301076717434
|
169
169
|
"""
|
170
170
|
d = {
|
171
|
-
"agent": self.agent.
|
172
|
-
"survey": self.survey.
|
173
|
-
"scenario": self.scenario.
|
174
|
-
"model": self.model.
|
171
|
+
"agent": self.agent.to_dict(add_edsl_version=add_edsl_version),
|
172
|
+
"survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
|
173
|
+
"scenario": self.scenario.to_dict(add_edsl_version=add_edsl_version),
|
174
|
+
"model": self.model.to_dict(add_edsl_version=add_edsl_version),
|
175
175
|
"iteration": self.iteration,
|
176
176
|
"exceptions": {},
|
177
177
|
}
|
@@ -202,11 +202,11 @@ class Interview:
|
|
202
202
|
def __hash__(self) -> int:
|
203
203
|
from edsl.utilities.utilities import dict_hash
|
204
204
|
|
205
|
-
return dict_hash(self.
|
205
|
+
return dict_hash(self.to_dict(include_exceptions=False, add_edsl_version=False))
|
206
206
|
|
207
207
|
def __eq__(self, other: "Interview") -> bool:
|
208
208
|
"""
|
209
|
-
>>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i.
|
209
|
+
>>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i.to_dict(); i2 = Interview.from_dict(d); i == i2
|
210
210
|
True
|
211
211
|
"""
|
212
212
|
return hash(self) == hash(other)
|
@@ -486,11 +486,11 @@ class Interview:
|
|
486
486
|
"""
|
487
487
|
current_question_index: int = self.to_index[current_question.question_name]
|
488
488
|
|
489
|
-
next_question: Union[
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
489
|
+
next_question: Union[int, EndOfSurvey] = (
|
490
|
+
self.survey.rule_collection.next_question(
|
491
|
+
q_now=current_question_index,
|
492
|
+
answers=self.answers | self.scenario | self.agent["traits"],
|
493
|
+
)
|
494
494
|
)
|
495
495
|
|
496
496
|
next_question_index = next_question.next_q
|
edsl/jobs/runners/JobsRunnerAsyncio.py
CHANGED
@@ -2,13 +2,13 @@ from __future__ import annotations
|
|
2
2
|
import time
|
3
3
|
import asyncio
|
4
4
|
import threading
|
5
|
-
|
6
|
-
from
|
5
|
+
import warnings
|
6
|
+
from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator, Type
|
7
7
|
from collections import UserList
|
8
8
|
|
9
9
|
from edsl.results.Results import Results
|
10
10
|
from edsl.jobs.interviews.Interview import Interview
|
11
|
-
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
|
11
|
+
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus, JobsRunnerStatusBase
|
12
12
|
|
13
13
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
14
14
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
@@ -41,6 +41,7 @@ class JobsRunnerAsyncio:
|
|
41
41
|
self.interviews: List["Interview"] = jobs.interviews()
|
42
42
|
self.bucket_collection: "BucketCollection" = jobs.bucket_collection
|
43
43
|
self.total_interviews: List["Interview"] = []
|
44
|
+
self._initialized = threading.Event()
|
44
45
|
|
45
46
|
async def run_async_generator(
|
46
47
|
self,
|
@@ -69,6 +70,8 @@ class JobsRunnerAsyncio:
|
|
69
70
|
self._populate_total_interviews(n=n)
|
70
71
|
) # Populate self.total_interviews before creating tasks
|
71
72
|
|
73
|
+
self._initialized.set() # Signal that we're ready
|
74
|
+
|
72
75
|
for interview in self.total_interviews:
|
73
76
|
interviewing_task = self._build_interview_task(
|
74
77
|
interview=interview,
|
@@ -165,20 +168,20 @@ class JobsRunnerAsyncio:
|
|
165
168
|
|
166
169
|
prompt_dictionary = {}
|
167
170
|
for answer_key_name in answer_key_names:
|
168
|
-
prompt_dictionary[
|
169
|
-
answer_key_name
|
170
|
-
|
171
|
-
prompt_dictionary[
|
172
|
-
answer_key_name
|
173
|
-
|
171
|
+
prompt_dictionary[answer_key_name + "_user_prompt"] = (
|
172
|
+
question_name_to_prompts[answer_key_name]["user_prompt"]
|
173
|
+
)
|
174
|
+
prompt_dictionary[answer_key_name + "_system_prompt"] = (
|
175
|
+
question_name_to_prompts[answer_key_name]["system_prompt"]
|
176
|
+
)
|
174
177
|
|
175
178
|
raw_model_results_dictionary = {}
|
176
179
|
cache_used_dictionary = {}
|
177
180
|
for result in valid_results:
|
178
181
|
question_name = result.question_name
|
179
|
-
raw_model_results_dictionary[
|
180
|
-
|
181
|
-
|
182
|
+
raw_model_results_dictionary[question_name + "_raw_model_response"] = (
|
183
|
+
result.raw_model_response
|
184
|
+
)
|
182
185
|
raw_model_results_dictionary[question_name + "_cost"] = result.cost
|
183
186
|
one_use_buys = (
|
184
187
|
"NA"
|
@@ -275,6 +278,7 @@ class JobsRunnerAsyncio:
|
|
275
278
|
stop_on_exception: bool = False,
|
276
279
|
progress_bar: bool = False,
|
277
280
|
sidecar_model: Optional[LanguageModel] = None,
|
281
|
+
jobs_runner_status: Optional[Type[JobsRunnerStatusBase]] = None,
|
278
282
|
print_exceptions: bool = True,
|
279
283
|
raise_validation_errors: bool = False,
|
280
284
|
) -> "Coroutine":
|
@@ -286,7 +290,21 @@ class JobsRunnerAsyncio:
|
|
286
290
|
self.cache = cache
|
287
291
|
self.sidecar_model = sidecar_model
|
288
292
|
|
289
|
-
|
293
|
+
from edsl.coop import Coop
|
294
|
+
|
295
|
+
coop = Coop()
|
296
|
+
endpoint_url = coop.get_progress_bar_url()
|
297
|
+
|
298
|
+
if jobs_runner_status is not None:
|
299
|
+
self.jobs_runner_status = jobs_runner_status(
|
300
|
+
self, n=n, endpoint_url=endpoint_url
|
301
|
+
)
|
302
|
+
else:
|
303
|
+
self.jobs_runner_status = JobsRunnerStatus(
|
304
|
+
self,
|
305
|
+
n=n,
|
306
|
+
endpoint_url=endpoint_url,
|
307
|
+
)
|
290
308
|
|
291
309
|
stop_event = threading.Event()
|
292
310
|
|
@@ -306,11 +324,16 @@ class JobsRunnerAsyncio:
|
|
306
324
|
"""Runs the progress bar in a separate thread."""
|
307
325
|
self.jobs_runner_status.update_progress(stop_event)
|
308
326
|
|
309
|
-
if progress_bar:
|
327
|
+
if progress_bar and self.jobs_runner_status.has_ep_api_key():
|
328
|
+
self.jobs_runner_status.setup()
|
310
329
|
progress_thread = threading.Thread(
|
311
330
|
target=run_progress_bar, args=(stop_event,)
|
312
331
|
)
|
313
332
|
progress_thread.start()
|
333
|
+
elif progress_bar:
|
334
|
+
warnings.warn(
|
335
|
+
"You need an Expected Parrot API key to view job progress bars."
|
336
|
+
)
|
314
337
|
|
315
338
|
exception_to_raise = None
|
316
339
|
try:
|
@@ -325,7 +348,7 @@ class JobsRunnerAsyncio:
|
|
325
348
|
stop_event.set()
|
326
349
|
finally:
|
327
350
|
stop_event.set()
|
328
|
-
if progress_bar:
|
351
|
+
if progress_bar and self.jobs_runner_status.has_ep_api_key():
|
329
352
|
# self.jobs_runner_status.stop_event.set()
|
330
353
|
if progress_thread:
|
331
354
|
progress_thread.join()
|