edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +197 -116
- edsl/__init__.py +15 -7
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +351 -147
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +101 -50
- edsl/agents/InvigilatorBase.py +62 -70
- edsl/agents/PromptConstructor.py +143 -225
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +125 -47
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +45 -27
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +31 -12
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +27 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -19
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +7 -12
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +120 -79
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +47 -35
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -10
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +134 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +356 -431
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +6 -4
- edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +143 -408
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
- edsl/jobs/runners/JobsRunnerStatus.py +133 -165
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +38 -18
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +194 -236
- edsl/language_models/ModelList.py +28 -19
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +1 -2
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +68 -214
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +2 -3
- edsl/questions/QuestionList.py +10 -18
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +67 -23
- edsl/questions/QuestionNumerical.py +2 -4
- edsl/questions/QuestionRank.py +7 -17
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +2 -1
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
- edsl/questions/data_structures.py +20 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
- edsl/questions/question_registry.py +1 -1
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +168 -305
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +298 -206
- edsl/results/Results.py +149 -131
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/{Selector.py → results_selector.py} +23 -13
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +150 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioList.py +415 -244
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +270 -791
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
- edsl-0.1.39.dist-info/RECORD +358 -0
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/registry.py +0 -190
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.38.dev4.dist-info/RECORD +0 -277
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
edsl/agents/AgentList.py
CHANGED
@@ -1,38 +1,41 @@
|
|
1
|
-
"""A list of
|
2
|
-
|
3
|
-
Example usage:
|
4
|
-
|
5
|
-
.. code-block:: python
|
6
|
-
|
7
|
-
al = AgentList([Agent.example(), Agent.example()])
|
8
|
-
len(al)
|
9
|
-
2
|
10
|
-
|
1
|
+
"""A list of Agents
|
11
2
|
"""
|
12
3
|
|
13
4
|
from __future__ import annotations
|
14
5
|
import csv
|
15
|
-
import
|
6
|
+
import sys
|
16
7
|
from collections import UserList
|
8
|
+
from collections.abc import Iterable
|
9
|
+
|
17
10
|
from typing import Any, List, Optional, Union, TYPE_CHECKING
|
18
|
-
from rich import print_json
|
19
|
-
from rich.table import Table
|
20
|
-
from simpleeval import EvalWithCompoundTypes
|
21
|
-
from edsl.Base import Base
|
22
|
-
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
23
11
|
|
24
|
-
from
|
12
|
+
from simpleeval import EvalWithCompoundTypes, NameNotDefined
|
25
13
|
|
14
|
+
from edsl.Base import Base
|
15
|
+
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
26
16
|
from edsl.exceptions.agents import AgentListError
|
17
|
+
from edsl.utilities.is_notebook import is_notebook
|
18
|
+
from edsl.results.ResultsExportMixin import ResultsExportMixin
|
19
|
+
import logging
|
20
|
+
|
21
|
+
logger = logging.getLogger(__name__)
|
27
22
|
|
28
23
|
if TYPE_CHECKING:
|
29
24
|
from edsl.scenarios.ScenarioList import ScenarioList
|
25
|
+
from edsl.agents.Agent import Agent
|
26
|
+
from pandas import DataFrame
|
30
27
|
|
31
28
|
|
32
29
|
def is_iterable(obj):
|
33
30
|
return isinstance(obj, Iterable)
|
34
31
|
|
35
32
|
|
33
|
+
class EmptyAgentList:
|
34
|
+
def __repr__(self):
|
35
|
+
return "Empty AgentList"
|
36
|
+
|
37
|
+
|
38
|
+
# ResultsExportMixin,
|
36
39
|
class AgentList(UserList, Base):
|
37
40
|
"""A list of Agents."""
|
38
41
|
|
@@ -50,14 +53,15 @@ class AgentList(UserList, Base):
|
|
50
53
|
else:
|
51
54
|
super().__init__()
|
52
55
|
|
53
|
-
def shuffle(self, seed: Optional[str] =
|
56
|
+
def shuffle(self, seed: Optional[str] = None) -> AgentList:
|
54
57
|
"""Shuffle the AgentList.
|
55
58
|
|
56
59
|
:param seed: The seed for the random number generator.
|
57
60
|
"""
|
58
61
|
import random
|
59
62
|
|
60
|
-
|
63
|
+
if seed is not None:
|
64
|
+
random.seed(seed)
|
61
65
|
random.shuffle(self.data)
|
62
66
|
return self
|
63
67
|
|
@@ -73,22 +77,60 @@ class AgentList(UserList, Base):
|
|
73
77
|
random.seed(seed)
|
74
78
|
return AgentList(random.sample(self.data, n))
|
75
79
|
|
76
|
-
def to_pandas(self):
|
77
|
-
"""Return a pandas DataFrame.
|
80
|
+
def to_pandas(self) -> "DataFrame":
|
81
|
+
"""Return a pandas DataFrame.
|
82
|
+
|
83
|
+
>>> from edsl.agents.Agent import Agent
|
84
|
+
>>> al = AgentList([Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5}), Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5})])
|
85
|
+
>>> al.to_pandas()
|
86
|
+
age hair height
|
87
|
+
0 22 brown 5.5
|
88
|
+
1 22 brown 5.5
|
89
|
+
"""
|
78
90
|
return self.to_scenario_list().to_pandas()
|
79
91
|
|
80
|
-
def tally(
|
81
|
-
|
92
|
+
def tally(
|
93
|
+
self, *fields: Optional[str], top_n: Optional[int] = None, output="Dataset"
|
94
|
+
) -> Union[dict, "Dataset"]:
|
95
|
+
"""Tally the values of a field or perform a cross-tab of multiple fields.
|
96
|
+
|
97
|
+
:param fields: The field(s) to tally, multiple fields for cross-tabulation.
|
82
98
|
|
83
|
-
|
99
|
+
>>> al = AgentList.example()
|
100
|
+
>>> al.tally('age')
|
101
|
+
Dataset([{'age': [22]}, {'count': [2]}])
|
102
|
+
"""
|
103
|
+
return self.to_scenario_list().tally(*fields, top_n=top_n, output=output)
|
104
|
+
|
105
|
+
def duplicate(self):
|
106
|
+
"""Duplicate the AgentList.
|
107
|
+
|
108
|
+
>>> al = AgentList.example()
|
109
|
+
>>> al2 = al.duplicate()
|
110
|
+
>>> al2 == al
|
111
|
+
True
|
112
|
+
>>> id(al2) == id(al)
|
113
|
+
False
|
114
|
+
"""
|
115
|
+
return AgentList([a.duplicate() for a in self.data])
|
116
|
+
|
117
|
+
def rename(self, old_name, new_name) -> AgentList:
|
84
118
|
"""Rename a trait in the AgentList.
|
85
119
|
|
86
120
|
:param old_name: The old name of the trait.
|
87
121
|
:param new_name: The new name of the trait.
|
122
|
+
:param inplace: Whether to rename the trait in place.
|
123
|
+
|
124
|
+
>>> from edsl.agents.Agent import Agent
|
125
|
+
>>> al = AgentList([Agent(traits = {'a': 1, 'b': 1}), Agent(traits = {'a': 1, 'b': 2})])
|
126
|
+
>>> al2 = al.rename('a', 'c')
|
127
|
+
>>> assert al2 == AgentList([Agent(traits = {'c': 1, 'b': 1}), Agent(traits = {'c': 1, 'b': 2})])
|
128
|
+
>>> assert al != al2
|
88
129
|
"""
|
89
|
-
|
90
|
-
|
91
|
-
|
130
|
+
newagents = []
|
131
|
+
for agent in self:
|
132
|
+
newagents.append(agent.rename(old_name, new_name))
|
133
|
+
return AgentList(newagents)
|
92
134
|
|
93
135
|
def select(self, *traits) -> AgentList:
|
94
136
|
"""Selects agents with only the references traits.
|
@@ -123,19 +165,36 @@ class AgentList(UserList, Base):
|
|
123
165
|
"""
|
124
166
|
return EvalWithCompoundTypes(names=agent.traits)
|
125
167
|
|
126
|
-
try:
|
127
168
|
# iterates through all the results and evaluates the expression
|
169
|
+
|
170
|
+
try:
|
128
171
|
new_data = [
|
129
172
|
agent for agent in self.data if create_evaluator(agent).eval(expression)
|
130
173
|
]
|
131
|
-
except
|
132
|
-
|
133
|
-
|
174
|
+
except NameNotDefined as e:
|
175
|
+
e = AgentListError(f"'{expression}' is not a valid expression.")
|
176
|
+
if is_notebook():
|
177
|
+
print(e, file=sys.stderr)
|
178
|
+
else:
|
179
|
+
raise e
|
180
|
+
|
181
|
+
return EmptyAgentList()
|
182
|
+
|
183
|
+
if len(new_data) == 0:
|
184
|
+
return EmptyAgentList()
|
134
185
|
|
135
186
|
return AgentList(new_data)
|
136
187
|
|
137
188
|
@property
|
138
|
-
def all_traits(self):
|
189
|
+
def all_traits(self) -> list[str]:
|
190
|
+
"""Return all traits in the AgentList.
|
191
|
+
>>> from edsl.agents.Agent import Agent
|
192
|
+
>>> agent_1 = Agent(traits = {'age': 22})
|
193
|
+
>>> agent_2 = Agent(traits = {'hair': 'brown'})
|
194
|
+
>>> al = AgentList([agent_1, agent_2])
|
195
|
+
>>> al.all_traits
|
196
|
+
['age', 'hair']
|
197
|
+
"""
|
139
198
|
d = {}
|
140
199
|
for agent in self:
|
141
200
|
d.update(agent.traits)
|
@@ -180,14 +239,20 @@ class AgentList(UserList, Base):
|
|
180
239
|
agent_list.append(Agent(traits=row))
|
181
240
|
return cls(agent_list)
|
182
241
|
|
183
|
-
def translate_traits(self,
|
242
|
+
def translate_traits(self, codebook: dict[str, str]):
|
184
243
|
"""Translate traits to a new codebook.
|
185
244
|
|
186
245
|
:param codebook: The new codebook.
|
246
|
+
|
247
|
+
>>> al = AgentList.example()
|
248
|
+
>>> codebook = {'hair': {'brown':'Secret word for green'}}
|
249
|
+
>>> al.translate_traits(codebook)
|
250
|
+
AgentList([Agent(traits = {'age': 22, 'hair': 'Secret word for green', 'height': 5.5}), Agent(traits = {'age': 22, 'hair': 'Secret word for green', 'height': 5.5})])
|
187
251
|
"""
|
252
|
+
new_agents = []
|
188
253
|
for agent in self.data:
|
189
|
-
agent.translate_traits(codebook)
|
190
|
-
return
|
254
|
+
new_agents.append(agent.translate_traits(codebook))
|
255
|
+
return AgentList(new_agents)
|
191
256
|
|
192
257
|
def remove_trait(self, trait: str):
|
193
258
|
"""Remove traits from the AgentList.
|
@@ -198,20 +263,21 @@ class AgentList(UserList, Base):
|
|
198
263
|
>>> al.remove_trait('age')
|
199
264
|
AgentList([Agent(traits = {'hair': 'brown', 'height': 5.5}), Agent(traits = {'hair': 'brown', 'height': 5.5})])
|
200
265
|
"""
|
201
|
-
|
202
|
-
|
203
|
-
|
266
|
+
agents = []
|
267
|
+
new_al = self.duplicate()
|
268
|
+
for agent in new_al.data:
|
269
|
+
agents.append(agent.remove_trait(trait))
|
270
|
+
return AgentList(agents)
|
204
271
|
|
205
|
-
def add_trait(self, trait, values):
|
272
|
+
def add_trait(self, trait: str, values: List[Any]) -> AgentList:
|
206
273
|
"""Adds a new trait to every agent, with values taken from values.
|
207
274
|
|
208
275
|
:param trait: The name of the trait.
|
209
276
|
:param values: The valeues(s) of the trait. If a single value is passed, it is used for all agents.
|
210
277
|
|
211
278
|
>>> al = AgentList.example()
|
212
|
-
>>> al.add_trait('new_trait', 1)
|
213
|
-
|
214
|
-
>>> al.select('new_trait').to_scenario_list().to_list()
|
279
|
+
>>> new_al = al.add_trait('new_trait', 1)
|
280
|
+
>>> new_al.select('new_trait').to_scenario_list().to_list()
|
215
281
|
[1, 1]
|
216
282
|
>>> al.add_trait('new_trait', [1, 2, 3])
|
217
283
|
Traceback (most recent call last):
|
@@ -220,18 +286,24 @@ class AgentList(UserList, Base):
|
|
220
286
|
...
|
221
287
|
"""
|
222
288
|
if not is_iterable(values):
|
289
|
+
new_agents = []
|
223
290
|
value = values
|
224
291
|
for agent in self.data:
|
225
|
-
agent.add_trait(trait, value)
|
226
|
-
return
|
292
|
+
new_agents.append(agent.add_trait(trait, value))
|
293
|
+
return AgentList(new_agents)
|
227
294
|
|
228
295
|
if len(values) != len(self):
|
229
|
-
|
296
|
+
e = AgentListError(
|
230
297
|
"The passed values have to be the same length as the agent list."
|
231
298
|
)
|
299
|
+
if is_notebook():
|
300
|
+
print(e, file=sys.stderr)
|
301
|
+
else:
|
302
|
+
raise e
|
303
|
+
new_agents = []
|
232
304
|
for agent, value in zip(self.data, values):
|
233
|
-
agent.add_trait(trait, value)
|
234
|
-
return
|
305
|
+
new_agents.append(agent.add_trait(trait, value))
|
306
|
+
return AgentList(new_agents)
|
235
307
|
|
236
308
|
@staticmethod
|
237
309
|
def get_codebook(file_path: str):
|
@@ -244,12 +316,23 @@ class AgentList(UserList, Base):
|
|
244
316
|
return {field: None for field in reader.fieldnames}
|
245
317
|
|
246
318
|
def __hash__(self) -> int:
|
319
|
+
"""Return the hash of the AgentList.
|
320
|
+
|
321
|
+
>>> al = AgentList.example()
|
322
|
+
>>> hash(al)
|
323
|
+
1681154913465662422
|
324
|
+
"""
|
247
325
|
from edsl.utilities.utilities import dict_hash
|
248
326
|
|
249
327
|
return dict_hash(self.to_dict(add_edsl_version=False, sorted=True))
|
250
328
|
|
251
329
|
def to_dict(self, sorted=False, add_edsl_version=True):
|
252
|
-
"""Serialize the AgentList to a dictionary.
|
330
|
+
"""Serialize the AgentList to a dictionary.
|
331
|
+
|
332
|
+
>>> AgentList.example().to_dict(add_edsl_version=False)
|
333
|
+
{'agent_list': [{'traits': {'age': 22, 'hair': 'brown', 'height': 5.5}}, {'traits': {'age': 22, 'hair': 'brown', 'height': 5.5}}]}
|
334
|
+
|
335
|
+
"""
|
253
336
|
if sorted:
|
254
337
|
data = self.data[:]
|
255
338
|
data.sort(key=lambda x: hash(x))
|
@@ -279,15 +362,26 @@ class AgentList(UserList, Base):
|
|
279
362
|
|
280
363
|
def _summary(self):
|
281
364
|
return {
|
282
|
-
"
|
283
|
-
"Number of agents": len(self),
|
284
|
-
"Agent trait fields": self.all_traits,
|
365
|
+
"agents": len(self),
|
285
366
|
}
|
286
367
|
|
287
|
-
def
|
288
|
-
"""
|
289
|
-
|
290
|
-
|
368
|
+
def set_codebook(self, codebook: dict[str, str]) -> AgentList:
|
369
|
+
"""Set the codebook for the AgentList.
|
370
|
+
|
371
|
+
>>> from edsl.agents.Agent import Agent
|
372
|
+
>>> a = Agent(traits = {'hair': 'brown'})
|
373
|
+
>>> al = AgentList([a, a])
|
374
|
+
>>> _ = al.set_codebook({'hair': "Color of hair on driver's license"})
|
375
|
+
>>> al[0].codebook
|
376
|
+
{'hair': "Color of hair on driver's license"}
|
377
|
+
|
378
|
+
|
379
|
+
:param codebook: The codebook.
|
380
|
+
"""
|
381
|
+
for agent in self.data:
|
382
|
+
agent.codebook = codebook
|
383
|
+
|
384
|
+
return self
|
291
385
|
|
292
386
|
def to_csv(self, file_path: str):
|
293
387
|
"""Save the AgentList to a CSV file.
|
@@ -300,19 +394,33 @@ class AgentList(UserList, Base):
|
|
300
394
|
"""Return a list of tuples."""
|
301
395
|
return self.to_scenario_list(include_agent_name).to_list()
|
302
396
|
|
303
|
-
def to_scenario_list(
|
304
|
-
|
397
|
+
def to_scenario_list(
|
398
|
+
self, include_agent_name: bool = False, include_instruction: bool = False
|
399
|
+
) -> ScenarioList:
|
400
|
+
"""Converts the agent to a scenario list."""
|
305
401
|
from edsl.scenarios.ScenarioList import ScenarioList
|
306
402
|
from edsl.scenarios.Scenario import Scenario
|
307
403
|
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
404
|
+
# raise NotImplementedError("This method is not implemented yet.")
|
405
|
+
|
406
|
+
scenario_list = ScenarioList()
|
407
|
+
for agent in self.data:
|
408
|
+
d = agent.traits
|
409
|
+
if include_agent_name:
|
410
|
+
d["agent_name"] = agent.name
|
411
|
+
if include_instruction:
|
412
|
+
d["instruction"] = agent.instruction
|
413
|
+
scenario_list.append(Scenario(d))
|
414
|
+
return scenario_list
|
415
|
+
|
416
|
+
# if include_agent_name:
|
417
|
+
# return ScenarioList(
|
418
|
+
# [
|
419
|
+
# Scenario(agent.traits | {"agent_name": agent.name} | })
|
420
|
+
# for agent in self.data
|
421
|
+
# ]
|
422
|
+
# )
|
423
|
+
# return ScenarioList([Scenario(agent.traits) for agent in self.data])
|
316
424
|
|
317
425
|
def table(
|
318
426
|
self,
|
@@ -320,12 +428,50 @@ class AgentList(UserList, Base):
|
|
320
428
|
tablefmt: Optional[str] = None,
|
321
429
|
pretty_labels: Optional[dict] = None,
|
322
430
|
) -> Table:
|
431
|
+
if len(self) == 0:
|
432
|
+
e = AgentListError("Cannot create a table from an empty AgentList.")
|
433
|
+
if is_notebook():
|
434
|
+
print(e, file=sys.stderr)
|
435
|
+
return None
|
436
|
+
else:
|
437
|
+
raise e
|
323
438
|
return (
|
324
439
|
self.to_scenario_list()
|
325
440
|
.to_dataset()
|
326
441
|
.table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
|
327
442
|
)
|
328
443
|
|
444
|
+
def to_dataset(self, traits_only: bool = True):
|
445
|
+
"""
|
446
|
+
Convert the AgentList to a Dataset.
|
447
|
+
|
448
|
+
>>> from edsl.agents.AgentList import AgentList
|
449
|
+
>>> al = AgentList.example()
|
450
|
+
>>> al.to_dataset()
|
451
|
+
Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}])
|
452
|
+
>>> al.to_dataset(traits_only = False)
|
453
|
+
Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}, {'agent_parameters': [{'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}, {'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}]}])
|
454
|
+
"""
|
455
|
+
from edsl.results.Dataset import Dataset
|
456
|
+
from collections import defaultdict
|
457
|
+
|
458
|
+
agent_trait_keys = []
|
459
|
+
for agent in self:
|
460
|
+
agent_keys = list(agent.traits.keys())
|
461
|
+
for key in agent_keys:
|
462
|
+
if key not in agent_trait_keys:
|
463
|
+
agent_trait_keys.append(key)
|
464
|
+
|
465
|
+
data = defaultdict(list)
|
466
|
+
for agent in self:
|
467
|
+
for trait_key in agent_trait_keys:
|
468
|
+
data[trait_key].append(agent.traits.get(trait_key, None))
|
469
|
+
if not traits_only:
|
470
|
+
data["agent_parameters"].append(
|
471
|
+
{"instruction": agent.instruction, "agent_name": agent.name}
|
472
|
+
)
|
473
|
+
return Dataset([{key: entry} for key, entry in data.items()])
|
474
|
+
|
329
475
|
def tree(self, node_order: Optional[List[str]] = None):
|
330
476
|
return self.to_scenario_list().tree(node_order)
|
331
477
|
|
@@ -398,14 +544,6 @@ class AgentList(UserList, Base):
|
|
398
544
|
return "\n".join(lines)
|
399
545
|
return lines
|
400
546
|
|
401
|
-
def rich_print(self) -> Table:
|
402
|
-
"""Display an object as a rich table."""
|
403
|
-
table = Table(title="AgentList")
|
404
|
-
table.add_column("Agents", style="bold")
|
405
|
-
for agent in self.data:
|
406
|
-
table.add_row(agent.rich_print())
|
407
|
-
return table
|
408
|
-
|
409
547
|
|
410
548
|
if __name__ == "__main__":
|
411
549
|
import doctest
|
edsl/agents/Invigilator.py
CHANGED
@@ -1,38 +1,29 @@
|
|
1
1
|
"""Module for creating Invigilators, which are objects to administer a question to an Agent."""
|
2
2
|
|
3
|
-
from typing import Dict, Any, Optional
|
3
|
+
from typing import Dict, Any, Optional, TYPE_CHECKING
|
4
4
|
|
5
|
-
from edsl.
|
6
|
-
from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
|
7
|
-
|
8
|
-
# from edsl.prompts.registry import get_classes as prompt_lookup
|
5
|
+
from edsl.utilities.decorators import sync_wrapper
|
9
6
|
from edsl.exceptions.questions import QuestionAnswerValidationError
|
10
7
|
from edsl.agents.InvigilatorBase import InvigilatorBase
|
11
8
|
from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
12
|
-
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from edsl.prompts.Prompt import Prompt
|
12
|
+
from edsl.scenarios.Scenario import Scenario
|
13
|
+
from edsl.surveys.Survey import Survey
|
13
14
|
|
14
15
|
|
15
|
-
|
16
|
-
def __new__(cls):
|
17
|
-
instance = super().__new__(cls, "Not Applicable")
|
18
|
-
instance.literal = "Not Applicable"
|
19
|
-
return instance
|
16
|
+
NA = "Not Applicable"
|
20
17
|
|
21
18
|
|
22
19
|
class InvigilatorAI(InvigilatorBase):
|
23
20
|
"""An invigilator that uses an AI model to answer questions."""
|
24
21
|
|
25
|
-
def get_prompts(self) -> Dict[str, Prompt]:
|
22
|
+
def get_prompts(self) -> Dict[str, "Prompt"]:
|
26
23
|
"""Return the prompts used."""
|
27
24
|
return self.prompt_constructor.get_prompts()
|
28
25
|
|
29
|
-
async def
|
30
|
-
"""Answer a question using the AI model.
|
31
|
-
|
32
|
-
>>> i = InvigilatorAI.example()
|
33
|
-
>>> i.answer_question()
|
34
|
-
{'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
|
35
|
-
"""
|
26
|
+
async def async_get_agent_response(self) -> AgentResponseDict:
|
36
27
|
prompts = self.get_prompts()
|
37
28
|
params = {
|
38
29
|
"user_prompt": prompts["user_prompt"].text,
|
@@ -40,33 +31,95 @@ class InvigilatorAI(InvigilatorBase):
|
|
40
31
|
}
|
41
32
|
if "encoded_image" in prompts:
|
42
33
|
params["encoded_image"] = prompts["encoded_image"]
|
34
|
+
raise NotImplementedError("encoded_image not implemented")
|
35
|
+
|
43
36
|
if "files_list" in prompts:
|
44
37
|
params["files_list"] = prompts["files_list"]
|
45
38
|
|
46
39
|
params.update({"iteration": self.iteration, "cache": self.cache})
|
47
|
-
|
48
40
|
params.update({"invigilator": self})
|
49
|
-
# if hasattr(self.question, "answer_template"):
|
50
|
-
# breakpoint()
|
51
41
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
42
|
+
if self.key_lookup:
|
43
|
+
self.model.set_key_lookup(self.key_lookup)
|
44
|
+
|
45
|
+
return await self.model.async_get_response(**params)
|
46
|
+
|
47
|
+
def store_response(self, agent_response_dict: AgentResponseDict) -> None:
|
48
|
+
"""Store the response in the invigilator, in case it is needed later because of validation failure."""
|
56
49
|
self.raw_model_response = agent_response_dict.model_outputs.response
|
57
50
|
self.generated_tokens = agent_response_dict.edsl_dict.generated_tokens
|
58
51
|
|
59
|
-
|
52
|
+
async def async_answer_question(self) -> AgentResponseDict:
|
53
|
+
"""Answer a question using the AI model.
|
54
|
+
|
55
|
+
>>> i = InvigilatorAI.example()
|
56
|
+
"""
|
57
|
+
agent_response_dict = await self.async_get_agent_response()
|
58
|
+
self.store_response(agent_response_dict)
|
59
|
+
return self._extract_edsl_result_entry_and_validate(agent_response_dict)
|
60
60
|
|
61
61
|
def _remove_from_cache(self, cache_key) -> None:
|
62
62
|
"""Remove an entry from the cache."""
|
63
63
|
if cache_key:
|
64
64
|
del self.cache.data[cache_key]
|
65
65
|
|
66
|
-
def
|
67
|
-
|
66
|
+
def _determine_answer(self, raw_answer: str) -> Any:
|
67
|
+
"""Determine the answer from the raw answer.
|
68
|
+
|
69
|
+
>>> i = InvigilatorAI.example()
|
70
|
+
>>> i._determine_answer("SPAM!")
|
71
|
+
'SPAM!'
|
72
|
+
|
73
|
+
>>> from edsl.questions import QuestionMultipleChoice
|
74
|
+
>>> q = QuestionMultipleChoice(question_text = "How are you?", question_name = "how_are_you", question_options = ["Good", "Bad"], use_code = True)
|
75
|
+
>>> i = InvigilatorAI.example(question = q)
|
76
|
+
>>> i._determine_answer("1")
|
77
|
+
'Bad'
|
78
|
+
>>> i._determine_answer("0")
|
79
|
+
'Good'
|
80
|
+
|
81
|
+
This shows how the answer can depend on scenario details
|
82
|
+
|
83
|
+
>>> from edsl import Scenario
|
84
|
+
>>> s = Scenario({'feeling_options':['Good', 'Bad']})
|
85
|
+
>>> q = QuestionMultipleChoice(question_text = "How are you?", question_name = "how_are_you", question_options = "{{ feeling_options }}", use_code = True)
|
86
|
+
>>> i = InvigilatorAI.example(question = q, scenario = s)
|
87
|
+
>>> i._determine_answer("1")
|
88
|
+
'Bad'
|
89
|
+
|
90
|
+
>>> from edsl import QuestionList, QuestionMultipleChoice, Survey
|
91
|
+
>>> q1 = QuestionList(question_name = "favs", question_text = "What are your top 3 colors?")
|
92
|
+
>>> q2 = QuestionMultipleChoice(question_text = "What is your favorite color?", question_name = "best", question_options = "{{ favs.answer }}", use_code = True)
|
93
|
+
>>> survey = Survey([q1, q2])
|
94
|
+
>>> i = InvigilatorAI.example(question = q2, scenario = s, survey = survey)
|
95
|
+
>>> i.current_answers = {"favs": ["Green", "Blue", "Red"]}
|
96
|
+
>>> i._determine_answer("2")
|
97
|
+
'Red'
|
98
|
+
"""
|
99
|
+
substitution_dict = self._prepare_substitution_dict(
|
100
|
+
self.survey, self.current_answers, self.scenario
|
101
|
+
)
|
102
|
+
return self.question._translate_answer_code_to_answer(
|
103
|
+
raw_answer, substitution_dict
|
104
|
+
)
|
105
|
+
|
106
|
+
@staticmethod
|
107
|
+
def _prepare_substitution_dict(
|
108
|
+
survey: "Survey", current_answers: dict, scenario: "Scenario"
|
109
|
+
) -> Dict[str, Any]:
|
110
|
+
"""Prepares a substitution dictionary for the question based on the survey, current answers, and scenario.
|
111
|
+
|
112
|
+
This is necessary beause sometimes the model's answer to a question could depend on details in
|
113
|
+
the prompt that were provided by the answer to a previous question or a scenario detail.
|
114
|
+
|
115
|
+
Note that the question object is getting the answer & a the comment appended to it, as the
|
116
|
+
jinja2 template might be referencing these values with a dot notation.
|
117
|
+
|
118
|
+
"""
|
119
|
+
question_dict = survey.duplicate().question_names_to_questions()
|
120
|
+
|
68
121
|
# iterates through the current answers and updates the question_dict (which is all questions)
|
69
|
-
for other_question, answer in
|
122
|
+
for other_question, answer in current_answers.items():
|
70
123
|
if other_question in question_dict:
|
71
124
|
question_dict[other_question].answer = answer
|
72
125
|
else:
|
@@ -76,13 +129,12 @@ class InvigilatorAI(InvigilatorBase):
|
|
76
129
|
) in question_dict:
|
77
130
|
question_dict[new_question].comment = answer
|
78
131
|
|
79
|
-
|
80
|
-
# sometimes the answer is a code, so we need to translate it
|
81
|
-
return self.question._translate_answer_code_to_answer(raw_answer, combined_dict)
|
132
|
+
return {**question_dict, **scenario}
|
82
133
|
|
83
|
-
def
|
134
|
+
def _extract_edsl_result_entry_and_validate(
|
84
135
|
self, agent_response_dict: AgentResponseDict
|
85
136
|
) -> EDSLResultObjectInput:
|
137
|
+
"""Extract the EDSL result entry and validate it."""
|
86
138
|
edsl_dict = agent_response_dict.edsl_dict._asdict()
|
87
139
|
exception_occurred = None
|
88
140
|
validated = False
|
@@ -94,10 +146,8 @@ class InvigilatorAI(InvigilatorBase):
|
|
94
146
|
# question options have be treated differently because of dynamic question
|
95
147
|
# this logic is all in the prompt constructor
|
96
148
|
if "question_options" in self.question.data:
|
97
|
-
new_question_options = (
|
98
|
-
self.
|
99
|
-
self.question.data
|
100
|
-
)
|
149
|
+
new_question_options = self.prompt_constructor.get_question_options(
|
150
|
+
self.question.data
|
101
151
|
)
|
102
152
|
if new_question_options != self.question.data["question_options"]:
|
103
153
|
# I don't love this direct writing but it seems to work
|
@@ -110,9 +160,8 @@ class InvigilatorAI(InvigilatorBase):
|
|
110
160
|
else:
|
111
161
|
question_with_validators = self.question
|
112
162
|
|
113
|
-
# breakpoint()
|
114
163
|
validated_edsl_dict = question_with_validators._validate_answer(edsl_dict)
|
115
|
-
answer = self.
|
164
|
+
answer = self._determine_answer(validated_edsl_dict["answer"])
|
116
165
|
comment = validated_edsl_dict.get("comment", "")
|
117
166
|
validated = True
|
118
167
|
except QuestionAnswerValidationError as e:
|
@@ -182,13 +231,13 @@ class InvigilatorHuman(InvigilatorBase):
|
|
182
231
|
exception_occurred = e
|
183
232
|
finally:
|
184
233
|
data = {
|
185
|
-
"generated_tokens": NotApplicable(),
|
234
|
+
"generated_tokens": NA, # NotApplicable(),
|
186
235
|
"question_name": self.question.question_name,
|
187
236
|
"prompts": self.get_prompts(),
|
188
|
-
"cached_response":
|
189
|
-
"raw_model_response":
|
190
|
-
"cache_used":
|
191
|
-
"cache_key":
|
237
|
+
"cached_response": NA,
|
238
|
+
"raw_model_response": NA,
|
239
|
+
"cache_used": NA,
|
240
|
+
"cache_key": NA,
|
192
241
|
"answer": answer,
|
193
242
|
"comment": comment,
|
194
243
|
"validated": validated,
|
@@ -209,17 +258,19 @@ class InvigilatorFunctional(InvigilatorBase):
|
|
209
258
|
generated_tokens=str(answer),
|
210
259
|
question_name=self.question.question_name,
|
211
260
|
prompts=self.get_prompts(),
|
212
|
-
cached_response=
|
213
|
-
raw_model_response=
|
214
|
-
cache_used=
|
215
|
-
cache_key=
|
261
|
+
cached_response=NA,
|
262
|
+
raw_model_response=NA,
|
263
|
+
cache_used=NA,
|
264
|
+
cache_key=NA,
|
216
265
|
answer=answer["answer"],
|
217
266
|
comment="This is the result of a functional question.",
|
218
267
|
validated=True,
|
219
268
|
exception_occurred=None,
|
220
269
|
)
|
221
270
|
|
222
|
-
def get_prompts(self) -> Dict[str, Prompt]:
|
271
|
+
def get_prompts(self) -> Dict[str, "Prompt"]:
|
272
|
+
from edsl.prompts.Prompt import Prompt
|
273
|
+
|
223
274
|
"""Return the prompts used."""
|
224
275
|
return {
|
225
276
|
"user_prompt": Prompt("NA"),
|