edsl 0.1.36.dev5__py3-none-any.whl → 0.1.36.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +47 -47
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +804 -804
- edsl/agents/AgentList.py +337 -337
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +294 -294
- edsl/agents/PromptConstructor.py +312 -312
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +86 -86
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +289 -289
- edsl/config.py +149 -149
- edsl/conjure/AgentConstructionMixin.py +152 -152
- edsl/conjure/Conjure.py +62 -62
- edsl/conjure/InputData.py +659 -659
- edsl/conjure/InputDataCSV.py +48 -48
- edsl/conjure/InputDataMixinQuestionStats.py +182 -182
- edsl/conjure/InputDataPyRead.py +91 -91
- edsl/conjure/InputDataSPSS.py +8 -8
- edsl/conjure/InputDataStata.py +8 -8
- edsl/conjure/QuestionOptionMixin.py +76 -76
- edsl/conjure/QuestionTypeMixin.py +23 -23
- edsl/conjure/RawQuestion.py +65 -65
- edsl/conjure/SurveyResponses.py +7 -7
- edsl/conjure/__init__.py +9 -9
- edsl/conjure/naming_utilities.py +263 -263
- edsl/conjure/utilities.py +201 -201
- edsl/conversation/Conversation.py +238 -238
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +849 -849
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +527 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +83 -83
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +40 -40
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +26 -26
- edsl/exceptions/surveys.py +34 -34
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +115 -115
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +72 -68
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -94
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1112 -1112
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +248 -248
- edsl/jobs/interviews/Interview.py +651 -651
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +182 -182
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +337 -337
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +441 -441
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/LanguageModel.py +718 -718
- edsl/language_models/ModelList.py +102 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +2 -2
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +259 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +358 -358
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +616 -616
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +266 -266
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +113 -113
- edsl/questions/QuestionFunctional.py +159 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +91 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +418 -418
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +693 -693
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +433 -433
- edsl/results/Results.py +1158 -1158
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +118 -118
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +443 -443
- edsl/scenarios/Scenario.py +507 -507
- edsl/scenarios/ScenarioHtmlMixin.py +59 -59
- edsl/scenarios/ScenarioList.py +1101 -1101
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +2 -2
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +324 -324
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1772 -1772
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +47 -47
- edsl/surveys/instructions/Instruction.py +51 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +391 -391
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev6.dist-info}/LICENSE +21 -21
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev6.dist-info}/METADATA +1 -1
- edsl-0.1.36.dev6.dist-info/RECORD +279 -0
- edsl-0.1.36.dev5.dist-info/RECORD +0 -279
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev6.dist-info}/WHEEL +0 -0
@@ -1,238 +1,238 @@
|
|
1
|
-
from collections import UserList
|
2
|
-
import asyncio
|
3
|
-
import inspect
|
4
|
-
from typing import Optional, Callable
|
5
|
-
from edsl import Agent, QuestionFreeText, Results, AgentList, ScenarioList, Scenario
|
6
|
-
from edsl.questions import QuestionBase
|
7
|
-
from edsl.results.Result import Result
|
8
|
-
|
9
|
-
from edsl.data import Cache
|
10
|
-
|
11
|
-
from edsl.conversation.next_speaker_utilities import (
|
12
|
-
default_turn_taking_generator,
|
13
|
-
speaker_closure,
|
14
|
-
)
|
15
|
-
|
16
|
-
|
17
|
-
class AgentStatement:
    """A single utterance in a conversation, wrapping the ``Result`` that produced it."""

    def __init__(self, statement: Result):
        # The underlying Result; the agent name and dialogue text are read from it.
        self.statement = statement

    @classmethod
    def from_dict(cls, data):
        """Rebuild a statement from the output of ``to_dict``."""
        result = Result.from_dict(data)
        return cls(result)

    def to_dict(self):
        """Serialize by delegating to the wrapped ``Result``."""
        return self.statement.to_dict()

    @property
    def agent_name(self):
        """Name of the agent who made this statement."""
        agent = self.statement["agent"]
        return agent["name"]

    @property
    def text(self):
        """The statement text: the answer stored under ``"dialogue"``."""
        answer = self.statement["answer"]
        return answer["dialogue"]
|
35
|
-
|
36
|
-
|
37
|
-
class AgentStatements(UserList):
    """An ordered list of ``AgentStatement`` objects making up a conversation."""

    def __init__(self, data=None):
        # Keep the ``data`` keyword so callers may write AgentStatements(data=...).
        super().__init__(data)

    @property
    def transcript(self):
        """One ``{agent_name: text}`` dict per statement, in speaking order."""
        lines = []
        for entry in self.data:
            lines.append({entry.agent_name: entry.text})
        return lines

    def to_dict(self):
        """Serialize every statement, preserving order."""
        return [entry.to_dict() for entry in self.data]

    @classmethod
    def from_dict(cls, data):
        """Rebuild the collection from the list produced by ``to_dict``."""
        restored = [AgentStatement.from_dict(item) for item in data]
        return cls(restored)
|
51
|
-
|
52
|
-
|
53
|
-
class Conversation:
    """A conversation between a list of agents.

    The first agent in the list is the first speaker. After that, the order is
    determined by the next-speaker generator, and the question asked to each
    speaker is ``next_statement_question``.

    Args:
        agent_list: Agents taking part in the conversation.
        max_turns: Upper bound on the number of statements; the conversation
            stops once this many have been made.
        stopping_function: Optional callable, plain or ``async``, called with
            the ``AgentStatements`` so far plus any keyword arguments given to
            ``continue_conversation``; a truthy return ends the conversation.
        next_statement_question: Question asked to the current speaker. It
            must take exactly the ``agent_name`` and ``conversation``
            parameters (checked in ``get_next_statement``).
            ``AgentStatement.text`` reads the ``"dialogue"`` answer, so a
            custom question should presumably keep ``question_name="dialogue"``.
            Defaults to a free-text "What do you say next?" question.
        next_speaker_generator: Generator function handed to
            ``speaker_closure`` to choose speakers; defaults to
            ``default_turn_taking_generator``.
        verbose: If true, print each statement as it is produced.
        conversation_index: Position of this conversation in a
            ``ConversationList`` (which sets it via ``add_index``).
        cache: ``Cache`` used for model calls; a new ``Cache()`` is created
            when omitted.
    """

    def __init__(
        self,
        agent_list: AgentList,
        max_turns: int = 20,
        stopping_function: Optional[Callable] = None,
        next_statement_question: Optional[QuestionBase] = None,
        next_speaker_generator: Optional[Callable] = None,
        verbose: bool = False,
        conversation_index: Optional[int] = None,
        cache=None,
    ):
        self.cache = Cache() if cache is None else cache

        self.agent_list = agent_list
        self.verbose = verbose
        self._conversation_index = conversation_index
        self.agent_statements = AgentStatements()
        self.max_turns = max_turns

        # A caller-supplied question used to be silently dropped, leaving the
        # attribute unset and making get_next_statement raise AttributeError.
        if next_statement_question is None:
            # NOTE: the "converstaion" misspelling is kept verbatim. This text is
            # sent to the model, so editing it changes the prompt (and presumably
            # any cached responses keyed on it).
            next_statement_question = QuestionFreeText(
                question_text="You are {{ agent_name }}. This is the converstaion so far: {{ conversation }}. What do you say next?",
                question_name="dialogue",
            )
        self.next_statement_question = next_statement_question

        # Determine how the next speaker is chosen.
        func = (
            default_turn_taking_generator
            if next_speaker_generator is None
            else next_speaker_generator
        )
        self.next_speaker = speaker_closure(
            agent_list=self.agent_list, generator_function=func
        )

        # Determine when the conversation ends (in addition to max_turns).
        if stopping_function is None:
            # Accept **kwargs: continue_conversation forwards its kwargs here.
            self.stopping_function = lambda agent_statements, **kwargs: False
        else:
            self.stopping_function = stopping_function

    async def continue_conversation(self, **kwargs) -> bool:
        """Return ``True`` while the conversation should go on.

        Returns ``False`` once ``max_turns`` statements exist, or when
        ``stopping_function`` (awaited if it is a coroutine function) returns
        a truthy value. Extra keyword arguments are forwarded to it.
        """
        if len(self.agent_statements) >= self.max_turns:
            return False

        if inspect.iscoroutinefunction(self.stopping_function):
            should_stop = await self.stopping_function(self.agent_statements, **kwargs)
        else:
            should_stop = self.stopping_function(self.agent_statements, **kwargs)

        return not should_stop

    def add_index(self, index) -> None:
        """Set this conversation's index (used by ``ConversationList``)."""
        self._conversation_index = index

    @property
    def conversation_index(self):
        """Index of this conversation, or ``None`` if never assigned."""
        return self._conversation_index

    def to_dict(self):
        """Serialize agents, settings and statements.

        The stopping function, speaker generator, question and cache are not
        included, so ``from_dict`` restores them to their defaults.
        """
        return {
            "agent_list": self.agent_list.to_dict(),
            "max_turns": self.max_turns,
            "verbose": self.verbose,
            "agent_statements": [d.to_dict() for d in self.agent_statements],
            "conversation_index": self.conversation_index,
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild a conversation from the output of ``to_dict``."""
        conversation = cls(
            agent_list=AgentList.from_dict(data["agent_list"]),
            max_turns=data["max_turns"],
            verbose=data["verbose"],
            conversation_index=data["conversation_index"],
        )
        # Statements are not a constructor argument (the previous code passed an
        # unknown ``results_data`` kwarg wrapped in a tuple); restore them here.
        conversation.agent_statements = AgentStatements.from_dict(
            data["agent_statements"]
        )
        return conversation

    def to_results(self):
        """Return the statements' underlying ``Result`` objects as ``Results``."""
        return Results(data=[s.statement for s in self.agent_statements])

    def summarize(self):
        """Return a ``Scenario`` describing this conversation and its transcript."""
        d = {
            "num_agents": len(self.agent_list),
            "max_turns": self.max_turns,
            "conversation_index": self.conversation_index,
            "transcript": self.to_results().select("agent_name", "dialogue").to_list(),
            "number_of_agent_statements": len(self.agent_statements),
        }
        return Scenario(d)

    async def get_next_statement(self, *, index, speaker, conversation):
        """Ask ``speaker`` the next-statement question and return the first result.

        Args:
            index: Turn number, passed through to ``run_async``.
            speaker: The agent whose turn it is; its ``name`` fills ``agent_name``.
            conversation: Transcript so far; fills the ``conversation`` parameter.
        """
        q = self.next_statement_question
        # NOTE(review): assert is stripped under ``python -O``; kept so callers
        # still see AssertionError for a question with the wrong parameters.
        assert q.parameters == {"agent_name", "conversation"}, q.parameters
        results = await q.run_async(
            index=index,
            conversation=conversation,
            conversation_index=self.conversation_index,
            agent_name=speaker.name,
            agent=speaker,
            just_answer=False,
            cache=self.cache,
        )
        return results[0]

    def converse(self):
        """Run the conversation to completion via ``asyncio.run``.

        Must not be called from inside a running event loop; use
        ``await self._converse()`` there instead.
        """
        return asyncio.run(self._converse())

    async def _converse(self):
        """Produce statements until ``continue_conversation`` returns ``False``."""
        i = 0
        while await self.continue_conversation():
            speaker = self.next_speaker()
            next_statement = AgentStatement(
                statement=await self.get_next_statement(
                    index=i,
                    speaker=speaker,
                    conversation=self.agent_statements.transcript,
                )
            )
            self.agent_statements.append(next_statement)
            if self.verbose:
                print(f"'{speaker.name}':{next_statement.text}")
                print("\n")
            i += 1
|
196
|
-
|
197
|
-
|
198
|
-
class ConversationList:
    """A collection of conversations to be run in parallel.

    Each conversation is re-indexed by its position in the list and is made
    to share this list's cache (a new ``Cache()`` when none is given).
    """

    def __init__(self, conversations: list[Conversation], cache=None):
        self.conversations = conversations
        self.cache = Cache() if cache is None else cache
        for i, conversation in enumerate(self.conversations):
            conversation.add_index(i)
            conversation.cache = self.cache

    async def run_conversations(self):
        """Run every conversation concurrently on the current event loop."""
        await asyncio.gather(*[c._converse() for c in self.conversations])

    def run(self) -> None:
        """Run all conversations in parallel"""
        asyncio.run(self.run_conversations())

    def to_dict(self) -> dict:
        """Serialize all conversations under the ``"conversations"`` key.

        The previous dict comprehension reused the constant key, so only the
        last conversation survived, and ``from_dict`` then iterated string keys.
        """
        return {"conversations": [c.to_dict() for c in self.conversations]}

    @classmethod
    def from_dict(cls, data):
        """Rebuild the list (with a fresh cache) from the output of ``to_dict``."""
        conversations = [Conversation.from_dict(d) for d in data["conversations"]]
        return cls(conversations)

    def to_results(self) -> Results:
        """Return the results of all conversations as a single Results"""
        if not self.conversations:
            # Previously an IndexError; an empty Results mirrors what
            # Conversation.to_results builds for a conversation with no statements.
            return Results(data=[])
        results = self.conversations[0].to_results()
        for conv in self.conversations[1:]:
            results += conv.to_results()
        return results

    def summarize(self) -> ScenarioList:
        """Return one summary ``Scenario`` per conversation."""
        return ScenarioList([c.summarize() for c in self.conversations])
|
1
|
+
from collections import UserList
|
2
|
+
import asyncio
|
3
|
+
import inspect
|
4
|
+
from typing import Optional, Callable
|
5
|
+
from edsl import Agent, QuestionFreeText, Results, AgentList, ScenarioList, Scenario
|
6
|
+
from edsl.questions import QuestionBase
|
7
|
+
from edsl.results.Result import Result
|
8
|
+
|
9
|
+
from edsl.data import Cache
|
10
|
+
|
11
|
+
from edsl.conversation.next_speaker_utilities import (
|
12
|
+
default_turn_taking_generator,
|
13
|
+
speaker_closure,
|
14
|
+
)
|
15
|
+
|
16
|
+
|
17
|
+
class AgentStatement:
    """A single utterance in a conversation, wrapping the ``Result`` that produced it."""

    def __init__(self, statement: Result):
        # The underlying Result; the agent name and dialogue text are read from it.
        self.statement = statement

    @classmethod
    def from_dict(cls, data):
        """Rebuild a statement from the output of ``to_dict``."""
        result = Result.from_dict(data)
        return cls(result)

    def to_dict(self):
        """Serialize by delegating to the wrapped ``Result``."""
        return self.statement.to_dict()

    @property
    def agent_name(self):
        """Name of the agent who made this statement."""
        agent = self.statement["agent"]
        return agent["name"]

    @property
    def text(self):
        """The statement text: the answer stored under ``"dialogue"``."""
        answer = self.statement["answer"]
        return answer["dialogue"]
|
35
|
+
|
36
|
+
|
37
|
+
class AgentStatements(UserList):
    """An ordered list of ``AgentStatement`` objects making up a conversation."""

    def __init__(self, data=None):
        # Keep the ``data`` keyword so callers may write AgentStatements(data=...).
        super().__init__(data)

    @property
    def transcript(self):
        """One ``{agent_name: text}`` dict per statement, in speaking order."""
        lines = []
        for entry in self.data:
            lines.append({entry.agent_name: entry.text})
        return lines

    def to_dict(self):
        """Serialize every statement, preserving order."""
        return [entry.to_dict() for entry in self.data]

    @classmethod
    def from_dict(cls, data):
        """Rebuild the collection from the list produced by ``to_dict``."""
        restored = [AgentStatement.from_dict(item) for item in data]
        return cls(restored)
|
51
|
+
|
52
|
+
|
53
|
+
class Conversation:
    """A conversation between a list of agents.

    The first agent in the list is the first speaker. After that, the order is
    determined by the next-speaker generator, and the question asked to each
    speaker is ``next_statement_question``.

    Args:
        agent_list: Agents taking part in the conversation.
        max_turns: Upper bound on the number of statements; the conversation
            stops once this many have been made.
        stopping_function: Optional callable, plain or ``async``, called with
            the ``AgentStatements`` so far plus any keyword arguments given to
            ``continue_conversation``; a truthy return ends the conversation.
        next_statement_question: Question asked to the current speaker. It
            must take exactly the ``agent_name`` and ``conversation``
            parameters (checked in ``get_next_statement``).
            ``AgentStatement.text`` reads the ``"dialogue"`` answer, so a
            custom question should presumably keep ``question_name="dialogue"``.
            Defaults to a free-text "What do you say next?" question.
        next_speaker_generator: Generator function handed to
            ``speaker_closure`` to choose speakers; defaults to
            ``default_turn_taking_generator``.
        verbose: If true, print each statement as it is produced.
        conversation_index: Position of this conversation in a
            ``ConversationList`` (which sets it via ``add_index``).
        cache: ``Cache`` used for model calls; a new ``Cache()`` is created
            when omitted.
    """

    def __init__(
        self,
        agent_list: AgentList,
        max_turns: int = 20,
        stopping_function: Optional[Callable] = None,
        next_statement_question: Optional[QuestionBase] = None,
        next_speaker_generator: Optional[Callable] = None,
        verbose: bool = False,
        conversation_index: Optional[int] = None,
        cache=None,
    ):
        self.cache = Cache() if cache is None else cache

        self.agent_list = agent_list
        self.verbose = verbose
        self._conversation_index = conversation_index
        self.agent_statements = AgentStatements()
        self.max_turns = max_turns

        # A caller-supplied question used to be silently dropped, leaving the
        # attribute unset and making get_next_statement raise AttributeError.
        if next_statement_question is None:
            # NOTE: the "converstaion" misspelling is kept verbatim. This text is
            # sent to the model, so editing it changes the prompt (and presumably
            # any cached responses keyed on it).
            next_statement_question = QuestionFreeText(
                question_text="You are {{ agent_name }}. This is the converstaion so far: {{ conversation }}. What do you say next?",
                question_name="dialogue",
            )
        self.next_statement_question = next_statement_question

        # Determine how the next speaker is chosen.
        func = (
            default_turn_taking_generator
            if next_speaker_generator is None
            else next_speaker_generator
        )
        self.next_speaker = speaker_closure(
            agent_list=self.agent_list, generator_function=func
        )

        # Determine when the conversation ends (in addition to max_turns).
        if stopping_function is None:
            # Accept **kwargs: continue_conversation forwards its kwargs here.
            self.stopping_function = lambda agent_statements, **kwargs: False
        else:
            self.stopping_function = stopping_function

    async def continue_conversation(self, **kwargs) -> bool:
        """Return ``True`` while the conversation should go on.

        Returns ``False`` once ``max_turns`` statements exist, or when
        ``stopping_function`` (awaited if it is a coroutine function) returns
        a truthy value. Extra keyword arguments are forwarded to it.
        """
        if len(self.agent_statements) >= self.max_turns:
            return False

        if inspect.iscoroutinefunction(self.stopping_function):
            should_stop = await self.stopping_function(self.agent_statements, **kwargs)
        else:
            should_stop = self.stopping_function(self.agent_statements, **kwargs)

        return not should_stop

    def add_index(self, index) -> None:
        """Set this conversation's index (used by ``ConversationList``)."""
        self._conversation_index = index

    @property
    def conversation_index(self):
        """Index of this conversation, or ``None`` if never assigned."""
        return self._conversation_index

    def to_dict(self):
        """Serialize agents, settings and statements.

        The stopping function, speaker generator, question and cache are not
        included, so ``from_dict`` restores them to their defaults.
        """
        return {
            "agent_list": self.agent_list.to_dict(),
            "max_turns": self.max_turns,
            "verbose": self.verbose,
            "agent_statements": [d.to_dict() for d in self.agent_statements],
            "conversation_index": self.conversation_index,
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild a conversation from the output of ``to_dict``."""
        conversation = cls(
            agent_list=AgentList.from_dict(data["agent_list"]),
            max_turns=data["max_turns"],
            verbose=data["verbose"],
            conversation_index=data["conversation_index"],
        )
        # Statements are not a constructor argument (the previous code passed an
        # unknown ``results_data`` kwarg wrapped in a tuple); restore them here.
        conversation.agent_statements = AgentStatements.from_dict(
            data["agent_statements"]
        )
        return conversation

    def to_results(self):
        """Return the statements' underlying ``Result`` objects as ``Results``."""
        return Results(data=[s.statement for s in self.agent_statements])

    def summarize(self):
        """Return a ``Scenario`` describing this conversation and its transcript."""
        d = {
            "num_agents": len(self.agent_list),
            "max_turns": self.max_turns,
            "conversation_index": self.conversation_index,
            "transcript": self.to_results().select("agent_name", "dialogue").to_list(),
            "number_of_agent_statements": len(self.agent_statements),
        }
        return Scenario(d)

    async def get_next_statement(self, *, index, speaker, conversation):
        """Ask ``speaker`` the next-statement question and return the first result.

        Args:
            index: Turn number, passed through to ``run_async``.
            speaker: The agent whose turn it is; its ``name`` fills ``agent_name``.
            conversation: Transcript so far; fills the ``conversation`` parameter.
        """
        q = self.next_statement_question
        # NOTE(review): assert is stripped under ``python -O``; kept so callers
        # still see AssertionError for a question with the wrong parameters.
        assert q.parameters == {"agent_name", "conversation"}, q.parameters
        results = await q.run_async(
            index=index,
            conversation=conversation,
            conversation_index=self.conversation_index,
            agent_name=speaker.name,
            agent=speaker,
            just_answer=False,
            cache=self.cache,
        )
        return results[0]

    def converse(self):
        """Run the conversation to completion via ``asyncio.run``.

        Must not be called from inside a running event loop; use
        ``await self._converse()`` there instead.
        """
        return asyncio.run(self._converse())

    async def _converse(self):
        """Produce statements until ``continue_conversation`` returns ``False``."""
        i = 0
        while await self.continue_conversation():
            speaker = self.next_speaker()
            next_statement = AgentStatement(
                statement=await self.get_next_statement(
                    index=i,
                    speaker=speaker,
                    conversation=self.agent_statements.transcript,
                )
            )
            self.agent_statements.append(next_statement)
            if self.verbose:
                print(f"'{speaker.name}':{next_statement.text}")
                print("\n")
            i += 1
|
196
|
+
|
197
|
+
|
198
|
+
class ConversationList:
    """A collection of conversations to be run in parallel.

    Each conversation is re-indexed by its position in the list and is made
    to share this list's cache (a new ``Cache()`` when none is given).
    """

    def __init__(self, conversations: list[Conversation], cache=None):
        self.conversations = conversations
        self.cache = Cache() if cache is None else cache
        for i, conversation in enumerate(self.conversations):
            conversation.add_index(i)
            conversation.cache = self.cache

    async def run_conversations(self):
        """Run every conversation concurrently on the current event loop."""
        await asyncio.gather(*[c._converse() for c in self.conversations])

    def run(self) -> None:
        """Run all conversations in parallel"""
        asyncio.run(self.run_conversations())

    def to_dict(self) -> dict:
        """Serialize all conversations under the ``"conversations"`` key.

        The previous dict comprehension reused the constant key, so only the
        last conversation survived, and ``from_dict`` then iterated string keys.
        """
        return {"conversations": [c.to_dict() for c in self.conversations]}

    @classmethod
    def from_dict(cls, data):
        """Rebuild the list (with a fresh cache) from the output of ``to_dict``."""
        conversations = [Conversation.from_dict(d) for d in data["conversations"]]
        return cls(conversations)

    def to_results(self) -> Results:
        """Return the results of all conversations as a single Results"""
        if not self.conversations:
            # Previously an IndexError; an empty Results mirrors what
            # Conversation.to_results builds for a conversation with no statements.
            return Results(data=[])
        results = self.conversations[0].to_results()
        for conv in self.conversations[1:]:
            results += conv.to_results()
        return results

    def summarize(self) -> ScenarioList:
        """Return one summary ``Scenario`` per conversation."""
        return ScenarioList([c.summarize() for c in self.conversations])
|
edsl/conversation/car_buying.py
CHANGED
@@ -1,58 +1,58 @@
|
|
1
|
-
"""Demo: run two car-buying conversations in parallel, then analyse the transcripts.

Run as a script; importing the module only builds the agents and questions.
"""

from pathlib import Path

from edsl import Agent, AgentList, Cache, QuestionFreeText, QuestionList
from edsl.conversation.Conversation import Conversation, ConversationList

CACHE_FILE = "car_talk.json.gz"

a1 = Agent(
    name="Alice",
    traits={
        "motivation": """
You are Alice. You want to buy a car. You are talking to Bob, a car salesman.
It is very important to you that the steering wheel not whiff out of the window while you are driving.
Also, the car can have no space for mother-in-law.
You brought your brother-in-law along, Paul, who you have an antagonistic relationship with.
"""
    },
)
a2 = Agent(
    name="Bob",
    traits={
        "motivation": "You are Bob, a car salesman. You are talking to Alice, who wants to buy a car."
    },
)
a3 = Agent(
    name="Paul",
    traits={
        "motivation": "You are Paul, you are Alice's brother. You think her concerns are foolish and you are critical of her."
    },
)

q = QuestionFreeText(
    question_text="""This was a conversation about buying a car: {{ transcript }}.
Was a brand or style of car mentioned? If so, what was it?
""",
    question_name="car_brand",
)

q_actors = QuestionList(
    question_text="""This was a conversation about buying a car: {{ transcript }}.
Who were the actors in the conversation?
""",
    question_name="actors",
)


def main() -> None:
    """Run both conversations, print their transcripts, then print the analysis."""
    c1 = Conversation(agent_list=AgentList([a1, a3, a2]), max_turns=5, verbose=True)
    c2 = Conversation(agent_list=AgentList([a1, a2]), max_turns=5, verbose=True)

    # Reuse previously recorded model responses when the cache file is present
    # (relative to the working directory); otherwise start with an empty cache.
    c = Cache.load(CACHE_FILE) if Path(CACHE_FILE).exists() else Cache()

    combo = ConversationList([c1, c2], cache=c)
    combo.run()

    results = combo.to_results()
    results.select("conversation_index", "index", "agent_name", "dialogue").print(
        format="rich"
    )

    transcript_analysis = q.add_question(q_actors).by(combo.summarize()).run()
    transcript_analysis.select("car_brand", "actors").print(format="rich")


if __name__ == "__main__":
    main()
|
1
|
+
"""Demo: run two car-buying conversations in parallel, then analyse the transcripts.

Run as a script; importing the module only builds the agents and questions.
"""

from pathlib import Path

from edsl import Agent, AgentList, Cache, QuestionFreeText, QuestionList
from edsl.conversation.Conversation import Conversation, ConversationList

CACHE_FILE = "car_talk.json.gz"

a1 = Agent(
    name="Alice",
    traits={
        "motivation": """
You are Alice. You want to buy a car. You are talking to Bob, a car salesman.
It is very important to you that the steering wheel not whiff out of the window while you are driving.
Also, the car can have no space for mother-in-law.
You brought your brother-in-law along, Paul, who you have an antagonistic relationship with.
"""
    },
)
a2 = Agent(
    name="Bob",
    traits={
        "motivation": "You are Bob, a car salesman. You are talking to Alice, who wants to buy a car."
    },
)
a3 = Agent(
    name="Paul",
    traits={
        "motivation": "You are Paul, you are Alice's brother. You think her concerns are foolish and you are critical of her."
    },
)

q = QuestionFreeText(
    question_text="""This was a conversation about buying a car: {{ transcript }}.
Was a brand or style of car mentioned? If so, what was it?
""",
    question_name="car_brand",
)

q_actors = QuestionList(
    question_text="""This was a conversation about buying a car: {{ transcript }}.
Who were the actors in the conversation?
""",
    question_name="actors",
)


def main() -> None:
    """Run both conversations, print their transcripts, then print the analysis."""
    c1 = Conversation(agent_list=AgentList([a1, a3, a2]), max_turns=5, verbose=True)
    c2 = Conversation(agent_list=AgentList([a1, a2]), max_turns=5, verbose=True)

    # Reuse previously recorded model responses when the cache file is present
    # (relative to the working directory); otherwise start with an empty cache.
    c = Cache.load(CACHE_FILE) if Path(CACHE_FILE).exists() else Cache()

    combo = ConversationList([c1, c2], cache=c)
    combo.run()

    results = combo.to_results()
    results.select("conversation_index", "index", "agent_name", "dialogue").print(
        format="rich"
    )

    transcript_analysis = q.add_question(q_actors).by(combo.summarize()).run()
    transcript_analysis.select("car_brand", "actors").print(format="rich")


if __name__ == "__main__":
    main()
|