edsl 0.1.36.dev5__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +92 -41
- edsl/agents/AgentList.py +15 -2
- edsl/agents/InvigilatorBase.py +15 -25
- edsl/agents/PromptConstructor.py +149 -108
- edsl/agents/descriptors.py +17 -4
- edsl/conjure/AgentConstructionMixin.py +11 -3
- edsl/conversation/Conversation.py +66 -14
- edsl/conversation/chips.py +95 -0
- edsl/coop/coop.py +148 -39
- edsl/data/Cache.py +1 -1
- edsl/data/RemoteCacheSync.py +25 -12
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +7 -3
- edsl/exceptions/agents.py +17 -19
- edsl/exceptions/results.py +11 -8
- edsl/exceptions/scenarios.py +22 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AwsBedrock.py +7 -2
- edsl/inference_services/InferenceServicesCollection.py +42 -13
- edsl/inference_services/models_available_cache.py +25 -1
- edsl/jobs/Jobs.py +306 -71
- edsl/jobs/interviews/Interview.py +24 -14
- edsl/jobs/interviews/InterviewExceptionCollection.py +1 -1
- edsl/jobs/interviews/InterviewExceptionEntry.py +17 -13
- edsl/jobs/interviews/ReportErrors.py +2 -2
- edsl/jobs/runners/JobsRunnerAsyncio.py +10 -9
- edsl/jobs/tasks/TaskHistory.py +1 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +47 -59
- edsl/language_models/__init__.py +1 -0
- edsl/prompts/Prompt.py +11 -12
- edsl/questions/QuestionBase.py +53 -13
- edsl/questions/QuestionBasePromptsMixin.py +1 -33
- edsl/questions/QuestionFreeText.py +1 -0
- edsl/questions/QuestionFunctional.py +2 -2
- edsl/questions/descriptors.py +23 -28
- edsl/results/DatasetExportMixin.py +25 -1
- edsl/results/Result.py +27 -10
- edsl/results/Results.py +34 -121
- edsl/results/ResultsDBMixin.py +1 -1
- edsl/results/Selector.py +18 -1
- edsl/scenarios/FileStore.py +20 -5
- edsl/scenarios/Scenario.py +52 -13
- edsl/scenarios/ScenarioHtmlMixin.py +7 -2
- edsl/scenarios/ScenarioList.py +12 -1
- edsl/scenarios/__init__.py +2 -0
- edsl/surveys/Rule.py +10 -4
- edsl/surveys/Survey.py +100 -77
- edsl/utilities/utilities.py +18 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/METADATA +1 -1
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/RECORD +55 -51
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/LICENSE +0 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/WHEEL +0 -0
edsl/agents/PromptConstructor.py
CHANGED
@@ -1,17 +1,27 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
from typing import Dict, Any, Optional, Set
|
3
|
-
from collections import UserList
|
4
|
-
import pdb
|
5
3
|
|
6
4
|
from jinja2 import Environment, meta
|
7
5
|
|
8
6
|
from edsl.prompts.Prompt import Prompt
|
9
|
-
from edsl.
|
7
|
+
from edsl.agents.prompt_helpers import PromptPlan
|
10
8
|
|
11
|
-
# from edsl.prompts.registry import get_classes as prompt_lookup
|
12
|
-
from edsl.exceptions import QuestionScenarioRenderError
|
13
9
|
|
14
|
-
|
10
|
+
class PlaceholderAnswer:
|
11
|
+
"""A placeholder answer for when a question is not yet answered."""
|
12
|
+
|
13
|
+
def __init__(self):
|
14
|
+
self.answer = "N/A"
|
15
|
+
self.comment = "Will be populated by prior answer"
|
16
|
+
|
17
|
+
def __getitem__(self, index):
|
18
|
+
return ""
|
19
|
+
|
20
|
+
def __str__(self):
|
21
|
+
return "<<PlaceholderAnswer>>"
|
22
|
+
|
23
|
+
def __repr__(self):
|
24
|
+
return "<<PlaceholderAnswer>>"
|
15
25
|
|
16
26
|
|
17
27
|
def get_jinja2_variables(template_str: str) -> Set[str]:
|
@@ -40,7 +50,7 @@ class PromptConstructor:
|
|
40
50
|
This is mixed into the Invigilator class.
|
41
51
|
"""
|
42
52
|
|
43
|
-
def __init__(self, invigilator):
|
53
|
+
def __init__(self, invigilator, prompt_plan: Optional["PromptPlan"] = None):
|
44
54
|
self.invigilator = invigilator
|
45
55
|
self.agent = invigilator.agent
|
46
56
|
self.question = invigilator.question
|
@@ -49,7 +59,7 @@ class PromptConstructor:
|
|
49
59
|
self.model = invigilator.model
|
50
60
|
self.current_answers = invigilator.current_answers
|
51
61
|
self.memory_plan = invigilator.memory_plan
|
52
|
-
self.prompt_plan = PromptPlan()
|
62
|
+
self.prompt_plan = prompt_plan or PromptPlan()
|
53
63
|
|
54
64
|
@property
|
55
65
|
def scenario_file_keys(self) -> list:
|
@@ -95,15 +105,20 @@ class PromptConstructor:
|
|
95
105
|
return self.agent.prompt()
|
96
106
|
|
97
107
|
def prior_answers_dict(self) -> dict:
|
108
|
+
# this is all questions
|
98
109
|
d = self.survey.question_names_to_questions()
|
99
110
|
# This attaches the answer to the question
|
100
|
-
for question
|
101
|
-
if question in
|
102
|
-
d[question].answer =
|
111
|
+
for question in d:
|
112
|
+
if question in self.current_answers:
|
113
|
+
d[question].answer = self.current_answers[question]
|
103
114
|
else:
|
104
|
-
|
105
|
-
|
106
|
-
|
115
|
+
d[question].answer = PlaceholderAnswer()
|
116
|
+
|
117
|
+
# if (new_question := question.split("_comment")[0]) in d:
|
118
|
+
# d[new_question].comment = answer
|
119
|
+
# d[question].answer = PlaceholderAnswer()
|
120
|
+
|
121
|
+
# breakpoint()
|
107
122
|
return d
|
108
123
|
|
109
124
|
@property
|
@@ -116,6 +131,123 @@ class PromptConstructor:
|
|
116
131
|
question_file_keys.append(var)
|
117
132
|
return question_file_keys
|
118
133
|
|
134
|
+
def build_replacement_dict(self, question_data: dict):
|
135
|
+
"""
|
136
|
+
Builds a dictionary of replacement values by combining multiple data sources.
|
137
|
+
"""
|
138
|
+
# File references dictionary
|
139
|
+
file_refs = {key: f"<see file {key}>" for key in self.scenario_file_keys}
|
140
|
+
|
141
|
+
# Scenario items excluding file keys
|
142
|
+
scenario_items = {
|
143
|
+
k: v for k, v in self.scenario.items() if k not in self.scenario_file_keys
|
144
|
+
}
|
145
|
+
|
146
|
+
# Question settings with defaults
|
147
|
+
question_settings = {
|
148
|
+
"use_code": getattr(self.question, "_use_code", True),
|
149
|
+
"include_comment": getattr(self.question, "_include_comment", False),
|
150
|
+
}
|
151
|
+
|
152
|
+
# Combine all dictionaries using dict.update() for clarity
|
153
|
+
replacement_dict = {}
|
154
|
+
for d in [
|
155
|
+
file_refs,
|
156
|
+
question_data,
|
157
|
+
scenario_items,
|
158
|
+
self.prior_answers_dict(),
|
159
|
+
{"agent": self.agent},
|
160
|
+
question_settings,
|
161
|
+
]:
|
162
|
+
replacement_dict.update(d)
|
163
|
+
|
164
|
+
return replacement_dict
|
165
|
+
|
166
|
+
def _get_question_options(self, question_data):
|
167
|
+
question_options_entry = question_data.get("question_options", None)
|
168
|
+
question_options = question_options_entry
|
169
|
+
|
170
|
+
placeholder = ["<< Option 1 - Placholder >>", "<< Option 2 - Placholder >>"]
|
171
|
+
|
172
|
+
if isinstance(question_options_entry, str):
|
173
|
+
env = Environment()
|
174
|
+
parsed_content = env.parse(question_options_entry)
|
175
|
+
question_option_key = list(meta.find_undeclared_variables(parsed_content))[
|
176
|
+
0
|
177
|
+
]
|
178
|
+
if isinstance(self.scenario.get(question_option_key), list):
|
179
|
+
question_options = self.scenario.get(question_option_key)
|
180
|
+
|
181
|
+
# might be getting it from the prior answers
|
182
|
+
if self.prior_answers_dict().get(question_option_key) is not None:
|
183
|
+
prior_question = self.prior_answers_dict().get(question_option_key)
|
184
|
+
if hasattr(prior_question, "answer"):
|
185
|
+
if isinstance(prior_question.answer, list):
|
186
|
+
question_options = prior_question.answer
|
187
|
+
else:
|
188
|
+
question_options = placeholder
|
189
|
+
else:
|
190
|
+
question_options = placeholder
|
191
|
+
|
192
|
+
return question_options
|
193
|
+
|
194
|
+
def build_question_instructions_prompt(self):
|
195
|
+
"""Buils the question instructions prompt."""
|
196
|
+
|
197
|
+
question_prompt = Prompt(self.question.get_instructions(model=self.model.model))
|
198
|
+
|
199
|
+
# Get the data for the question - this is a dictionary of the question data
|
200
|
+
# e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
|
201
|
+
question_data = self.question.data.copy()
|
202
|
+
|
203
|
+
if "question_options" in question_data:
|
204
|
+
question_options = self._get_question_options(question_data)
|
205
|
+
question_data["question_options"] = question_options
|
206
|
+
|
207
|
+
# check to see if the question_options is actually a string
|
208
|
+
# This is used when the user is using the question_options as a variable from a scenario
|
209
|
+
# if "question_options" in question_data:
|
210
|
+
replacement_dict = self.build_replacement_dict(question_data)
|
211
|
+
rendered_instructions = question_prompt.render(replacement_dict)
|
212
|
+
|
213
|
+
# is there anything left to render?
|
214
|
+
undefined_template_variables = (
|
215
|
+
rendered_instructions.undefined_template_variables({})
|
216
|
+
)
|
217
|
+
|
218
|
+
# Check if it's the name of a question in the survey
|
219
|
+
for question_name in self.survey.question_names:
|
220
|
+
if question_name in undefined_template_variables:
|
221
|
+
print(
|
222
|
+
"Question name found in undefined_template_variables: ",
|
223
|
+
question_name,
|
224
|
+
)
|
225
|
+
|
226
|
+
if undefined_template_variables:
|
227
|
+
msg = f"Question instructions still has variables: {undefined_template_variables}."
|
228
|
+
import warnings
|
229
|
+
|
230
|
+
warnings.warn(msg)
|
231
|
+
# raise QuestionScenarioRenderError(
|
232
|
+
# f"Question instructions still has variables: {undefined_template_variables}."
|
233
|
+
# )
|
234
|
+
|
235
|
+
# Check if question has instructions - these are instructions in a Survey that can apply to multiple follow-on questions
|
236
|
+
relevant_instructions = self.survey.relevant_instructions(
|
237
|
+
self.question.question_name
|
238
|
+
)
|
239
|
+
|
240
|
+
if relevant_instructions != []:
|
241
|
+
# preamble_text = Prompt(
|
242
|
+
# text="You were given the following instructions: "
|
243
|
+
# )
|
244
|
+
preamble_text = Prompt(text="")
|
245
|
+
for instruction in relevant_instructions:
|
246
|
+
preamble_text += instruction.text
|
247
|
+
rendered_instructions = preamble_text + rendered_instructions
|
248
|
+
|
249
|
+
return rendered_instructions
|
250
|
+
|
119
251
|
@property
|
120
252
|
def question_instructions_prompt(self) -> Prompt:
|
121
253
|
"""
|
@@ -125,102 +257,11 @@ class PromptConstructor:
|
|
125
257
|
Prompt(text=\"""...
|
126
258
|
...
|
127
259
|
"""
|
128
|
-
# The user might have passed a custom prompt, which would be stored in _question_instructions_prompt
|
129
260
|
if not hasattr(self, "_question_instructions_prompt"):
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
# Get the data for the question - this is a dictionary of the question data
|
134
|
-
# e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
|
135
|
-
question_data = self.question.data.copy()
|
136
|
-
|
137
|
-
# check to see if the question_options is actually a string
|
138
|
-
# This is used when the user is using the question_options as a variable from a scenario
|
139
|
-
# if "question_options" in question_data:
|
140
|
-
if isinstance(self.question.data.get("question_options", None), str):
|
141
|
-
env = Environment()
|
142
|
-
parsed_content = env.parse(self.question.data["question_options"])
|
143
|
-
question_option_key = list(
|
144
|
-
meta.find_undeclared_variables(parsed_content)
|
145
|
-
)[0]
|
146
|
-
|
147
|
-
# look to see if the question_option_key is in the scenario
|
148
|
-
if isinstance(
|
149
|
-
question_options := self.scenario.get(question_option_key), list
|
150
|
-
):
|
151
|
-
question_data["question_options"] = question_options
|
152
|
-
self.question.question_options = question_options
|
153
|
-
|
154
|
-
# might be getting it from the prior answers
|
155
|
-
if self.prior_answers_dict().get(question_option_key) is not None:
|
156
|
-
if isinstance(
|
157
|
-
question_options := self.prior_answers_dict()
|
158
|
-
.get(question_option_key)
|
159
|
-
.answer,
|
160
|
-
list,
|
161
|
-
):
|
162
|
-
question_data["question_options"] = question_options
|
163
|
-
self.question.question_options = question_options
|
164
|
-
|
165
|
-
replacement_dict = (
|
166
|
-
{key: f"<see file {key}>" for key in self.scenario_file_keys}
|
167
|
-
| question_data
|
168
|
-
| {
|
169
|
-
k: v
|
170
|
-
for k, v in self.scenario.items()
|
171
|
-
if k not in self.scenario_file_keys
|
172
|
-
} # don't include images in the replacement dict
|
173
|
-
| self.prior_answers_dict()
|
174
|
-
| {"agent": self.agent}
|
175
|
-
| {
|
176
|
-
"use_code": getattr(self.question, "_use_code", True),
|
177
|
-
"include_comment": getattr(
|
178
|
-
self.question, "_include_comment", False
|
179
|
-
),
|
180
|
-
}
|
181
|
-
)
|
182
|
-
|
183
|
-
rendered_instructions = question_prompt.render(replacement_dict)
|
184
|
-
|
185
|
-
# is there anything left to render?
|
186
|
-
undefined_template_variables = (
|
187
|
-
rendered_instructions.undefined_template_variables({})
|
261
|
+
self._question_instructions_prompt = (
|
262
|
+
self.build_question_instructions_prompt()
|
188
263
|
)
|
189
264
|
|
190
|
-
# Check if it's the name of a question in the survey
|
191
|
-
for question_name in self.survey.question_names:
|
192
|
-
if question_name in undefined_template_variables:
|
193
|
-
print(
|
194
|
-
"Question name found in undefined_template_variables: ",
|
195
|
-
question_name,
|
196
|
-
)
|
197
|
-
|
198
|
-
if undefined_template_variables:
|
199
|
-
msg = f"Question instructions still has variables: {undefined_template_variables}."
|
200
|
-
import warnings
|
201
|
-
|
202
|
-
warnings.warn(msg)
|
203
|
-
# raise QuestionScenarioRenderError(
|
204
|
-
# f"Question instructions still has variables: {undefined_template_variables}."
|
205
|
-
# )
|
206
|
-
|
207
|
-
####################################
|
208
|
-
# Check if question has instructions - these are instructions in a Survey that can apply to multiple follow-on questions
|
209
|
-
####################################
|
210
|
-
relevant_instructions = self.survey.relevant_instructions(
|
211
|
-
self.question.question_name
|
212
|
-
)
|
213
|
-
|
214
|
-
if relevant_instructions != []:
|
215
|
-
# preamble_text = Prompt(
|
216
|
-
# text="You were given the following instructions: "
|
217
|
-
# )
|
218
|
-
preamble_text = Prompt(text="")
|
219
|
-
for instruction in relevant_instructions:
|
220
|
-
preamble_text += instruction.text
|
221
|
-
rendered_instructions = preamble_text + rendered_instructions
|
222
|
-
|
223
|
-
self._question_instructions_prompt = rendered_instructions
|
224
265
|
return self._question_instructions_prompt
|
225
266
|
|
226
267
|
@property
|
@@ -285,7 +326,7 @@ class PromptConstructor:
|
|
285
326
|
prompts = self.prompt_plan.get_prompts(
|
286
327
|
agent_instructions=self.agent_instructions_prompt,
|
287
328
|
agent_persona=self.agent_persona_prompt,
|
288
|
-
question_instructions=self.question_instructions_prompt,
|
329
|
+
question_instructions=Prompt(self.question_instructions_prompt),
|
289
330
|
prior_question_memory=self.prior_question_memory_prompt,
|
290
331
|
)
|
291
332
|
if self.question_file_keys:
|
edsl/agents/descriptors.py
CHANGED
@@ -4,6 +4,20 @@ from typing import Dict
|
|
4
4
|
from edsl.exceptions.agents import AgentNameError, AgentTraitKeyError
|
5
5
|
|
6
6
|
|
7
|
+
def convert_agent_name(x):
|
8
|
+
# potentially a numpy int64
|
9
|
+
import numpy as np
|
10
|
+
|
11
|
+
if isinstance(x, np.int64):
|
12
|
+
return int(x)
|
13
|
+
elif x is None:
|
14
|
+
return None
|
15
|
+
elif isinstance(x, int):
|
16
|
+
return x
|
17
|
+
else:
|
18
|
+
return str(x)
|
19
|
+
|
20
|
+
|
7
21
|
class NameDescriptor:
|
8
22
|
"""ABC for something."""
|
9
23
|
|
@@ -13,7 +27,7 @@ class NameDescriptor:
|
|
13
27
|
|
14
28
|
def __set__(self, instance, name: str) -> None:
|
15
29
|
"""Set the value of the attribute."""
|
16
|
-
instance.__dict__[self.name] = name
|
30
|
+
instance.__dict__[self.name] = convert_agent_name(name)
|
17
31
|
|
18
32
|
def __set_name__(self, owner, name: str) -> None:
|
19
33
|
"""Set the name of the attribute."""
|
@@ -34,9 +48,8 @@ class TraitsDescriptor:
|
|
34
48
|
for key, value in traits_dict.items():
|
35
49
|
if key == "name":
|
36
50
|
raise AgentNameError(
|
37
|
-
"
|
38
|
-
Agent(name="my_agent", traits={"trait1": "value1", "trait2": "value2"})
|
39
|
-
"""
|
51
|
+
"Trait keys cannot be 'name'. Instead, use the 'name' attribute directly e.g.,\n"
|
52
|
+
'Agent(name="my_agent", traits={"trait1": "value1", "trait2": "value2"})'
|
40
53
|
)
|
41
54
|
|
42
55
|
if not is_valid_variable_name(key):
|
@@ -99,6 +99,8 @@ class AgentConstructionMixin:
|
|
99
99
|
sample_size: int = None,
|
100
100
|
seed: str = "edsl",
|
101
101
|
dryrun=False,
|
102
|
+
disable_remote_cache: bool = False,
|
103
|
+
disable_remote_inference: bool = False,
|
102
104
|
) -> Union[Results, None]:
|
103
105
|
"""Return the results of the survey.
|
104
106
|
|
@@ -109,7 +111,7 @@ class AgentConstructionMixin:
|
|
109
111
|
|
110
112
|
>>> from edsl.conjure.InputData import InputDataABC
|
111
113
|
>>> id = InputDataABC.example()
|
112
|
-
>>> r = id.to_results()
|
114
|
+
>>> r = id.to_results(disable_remote_cache = True, disable_remote_inference = True)
|
113
115
|
>>> len(r) == id.num_observations
|
114
116
|
True
|
115
117
|
"""
|
@@ -125,7 +127,10 @@ class AgentConstructionMixin:
|
|
125
127
|
import time
|
126
128
|
|
127
129
|
start = time.time()
|
128
|
-
_ = survey.by(agent_list.sample(DRYRUN_SAMPLE)).run(
|
130
|
+
_ = survey.by(agent_list.sample(DRYRUN_SAMPLE)).run(
|
131
|
+
disable_remote_cache=disable_remote_cache,
|
132
|
+
disable_remote_inference=disable_remote_inference,
|
133
|
+
)
|
129
134
|
end = time.time()
|
130
135
|
print(f"Time to run {DRYRUN_SAMPLE} agents (s): {round(end - start, 2)}")
|
131
136
|
time_per_agent = (end - start) / DRYRUN_SAMPLE
|
@@ -143,7 +148,10 @@ class AgentConstructionMixin:
|
|
143
148
|
f"Full sample will take about {round(full_sample_time / 3600, 2)} hours."
|
144
149
|
)
|
145
150
|
return None
|
146
|
-
return survey.by(agent_list).run(
|
151
|
+
return survey.by(agent_list).run(
|
152
|
+
disable_remote_cache=disable_remote_cache,
|
153
|
+
disable_remote_inference=disable_remote_inference,
|
154
|
+
)
|
147
155
|
|
148
156
|
|
149
157
|
if __name__ == "__main__":
|
@@ -5,7 +5,7 @@ from typing import Optional, Callable
|
|
5
5
|
from edsl import Agent, QuestionFreeText, Results, AgentList, ScenarioList, Scenario
|
6
6
|
from edsl.questions import QuestionBase
|
7
7
|
from edsl.results.Result import Result
|
8
|
-
|
8
|
+
from jinja2 import Template
|
9
9
|
from edsl.data import Cache
|
10
10
|
|
11
11
|
from edsl.conversation.next_speaker_utilities import (
|
@@ -54,6 +54,9 @@ class Conversation:
|
|
54
54
|
"""A conversation between a list of agents. The first agent in the list is the first speaker.
|
55
55
|
After that, order is determined by the next_speaker function.
|
56
56
|
The question asked to each agent is determined by the next_statement_question.
|
57
|
+
|
58
|
+
If the user has passed in a "per_round_message_template", this will be displayed at the beginning of each round.
|
59
|
+
{{ round_message }} must be in the question_text.
|
57
60
|
"""
|
58
61
|
|
59
62
|
def __init__(
|
@@ -64,28 +67,62 @@ class Conversation:
|
|
64
67
|
next_statement_question: Optional[QuestionBase] = None,
|
65
68
|
next_speaker_generator: Optional[Callable] = None,
|
66
69
|
verbose: bool = False,
|
70
|
+
per_round_message_template: Optional[str] = None,
|
67
71
|
conversation_index: Optional[int] = None,
|
68
72
|
cache=None,
|
73
|
+
disable_remote_inference=False,
|
74
|
+
default_model: Optional["LanguageModel"] = None,
|
69
75
|
):
|
76
|
+
self.disable_remote_inference = disable_remote_inference
|
77
|
+
self.per_round_message_template = per_round_message_template
|
78
|
+
|
70
79
|
if cache is None:
|
71
80
|
self.cache = Cache()
|
72
81
|
else:
|
73
82
|
self.cache = cache
|
74
83
|
|
75
84
|
self.agent_list = agent_list
|
85
|
+
|
86
|
+
from edsl import Model
|
87
|
+
|
88
|
+
for agent in self.agent_list:
|
89
|
+
if not hasattr(agent, "model"):
|
90
|
+
if default_model is not None:
|
91
|
+
agent.model = default_model
|
92
|
+
else:
|
93
|
+
agent.model = Model()
|
94
|
+
|
76
95
|
self.verbose = verbose
|
77
96
|
self.agent_statements = []
|
78
97
|
self._conversation_index = conversation_index
|
79
|
-
|
80
98
|
self.agent_statements = AgentStatements()
|
81
99
|
|
82
100
|
self.max_turns = max_turns
|
83
101
|
|
84
102
|
if next_statement_question is None:
|
103
|
+
import textwrap
|
104
|
+
|
105
|
+
base_question = textwrap.dedent(
|
106
|
+
"""\
|
107
|
+
You are {{ agent_name }}. This is the conversation so far: {{ conversation }}
|
108
|
+
{% if round_message is not none %}
|
109
|
+
{{ round_message }}
|
110
|
+
{% endif %}
|
111
|
+
What do you say next?"""
|
112
|
+
)
|
85
113
|
self.next_statement_question = QuestionFreeText(
|
86
|
-
question_text=
|
114
|
+
question_text=base_question,
|
87
115
|
question_name="dialogue",
|
88
116
|
)
|
117
|
+
else:
|
118
|
+
self.next_statement_question = next_statement_question
|
119
|
+
if (
|
120
|
+
per_round_message_template
|
121
|
+
and "{{ round_message }}" not in next_statement_question.question_text
|
122
|
+
):
|
123
|
+
raise ValueError(
|
124
|
+
"If you pass in a per_round_message_template, you must include {{ round_message }} in the question_text."
|
125
|
+
)
|
89
126
|
|
90
127
|
# Determine how the next speaker is chosen
|
91
128
|
if next_speaker_generator is None:
|
@@ -93,6 +130,7 @@ class Conversation:
|
|
93
130
|
else:
|
94
131
|
func = next_speaker_generator
|
95
132
|
|
133
|
+
# Choose the next speaker
|
96
134
|
self.next_speaker = speaker_closure(
|
97
135
|
agent_list=self.agent_list, generator_function=func
|
98
136
|
)
|
@@ -158,17 +196,32 @@ class Conversation:
|
|
158
196
|
}
|
159
197
|
return Scenario(d)
|
160
198
|
|
161
|
-
async def get_next_statement(self, *, index, speaker, conversation):
|
199
|
+
async def get_next_statement(self, *, index, speaker, conversation) -> "Result":
|
200
|
+
"""Get the next statement from the speaker."""
|
162
201
|
q = self.next_statement_question
|
163
|
-
assert q.parameters == {"agent_name", "conversation"}, q.parameters
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
202
|
+
# assert q.parameters == {"agent_name", "conversation"}, q.parameters
|
203
|
+
from edsl import Scenario
|
204
|
+
|
205
|
+
if self.per_round_message_template is None:
|
206
|
+
round_message = None
|
207
|
+
else:
|
208
|
+
round_message = Template(self.per_round_message_template).render(
|
209
|
+
{"max_turns": self.max_turns, "current_turn": index}
|
210
|
+
)
|
211
|
+
|
212
|
+
s = Scenario(
|
213
|
+
{
|
214
|
+
"agent_name": speaker.name,
|
215
|
+
"conversation": conversation,
|
216
|
+
"conversation_index": self.conversation_index,
|
217
|
+
"index": index,
|
218
|
+
"round_message": round_message,
|
219
|
+
}
|
220
|
+
)
|
221
|
+
jobs = q.by(s).by(speaker).by(speaker.model)
|
222
|
+
jobs.show_prompts()
|
223
|
+
results = await jobs.run_async(
|
224
|
+
cache=self.cache, disable_remote_inference=self.disable_remote_inference
|
172
225
|
)
|
173
226
|
return results[0]
|
174
227
|
|
@@ -179,7 +232,6 @@ class Conversation:
|
|
179
232
|
i = 0
|
180
233
|
while await self.continue_conversation():
|
181
234
|
speaker = self.next_speaker()
|
182
|
-
# breakpoint()
|
183
235
|
|
184
236
|
next_statement = AgentStatement(
|
185
237
|
statement=await self.get_next_statement(
|
@@ -0,0 +1,95 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from edsl import Agent, AgentList, QuestionFreeText
|
4
|
+
from edsl import Cache
|
5
|
+
from edsl import QuestionList
|
6
|
+
from edsl import Model
|
7
|
+
|
8
|
+
from edsl.conversation.Conversation import Conversation, ConversationList
|
9
|
+
|
10
|
+
m = Model("gemini-1.5-flash")
|
11
|
+
|
12
|
+
|
13
|
+
class ChipLover(Agent):
|
14
|
+
def __init__(self, name, chip_values, initial_chips, model: Optional[Model] = None):
|
15
|
+
self.chip_values = chip_values
|
16
|
+
self.initial_chips = initial_chips
|
17
|
+
self.current_chips = initial_chips
|
18
|
+
self.model = model or Model()
|
19
|
+
super().__init__(
|
20
|
+
name=name,
|
21
|
+
traits={
|
22
|
+
"motivation": f"""
|
23
|
+
You are {name}. You are negotiating the trading of colored 'chips' with other players. You want to maximize your score.
|
24
|
+
When you want to accept a deal, say "DEAL!"
|
25
|
+
Note that different players can have different values for the chips.
|
26
|
+
""",
|
27
|
+
"chip_values": chip_values,
|
28
|
+
"initial_chips": initial_chips,
|
29
|
+
},
|
30
|
+
)
|
31
|
+
|
32
|
+
def trade(self, chips_given_dict, chips_received_dict):
|
33
|
+
for color, amount in chips_given_dict.items():
|
34
|
+
self.current_chips[color] -= amount
|
35
|
+
for color, amount in chips_received_dict.items():
|
36
|
+
self.current_chips[color] += amount
|
37
|
+
|
38
|
+
def get_score(self):
|
39
|
+
return sum(
|
40
|
+
self.chip_values[color] * self.current_chips[color]
|
41
|
+
for color in self.chip_values
|
42
|
+
)
|
43
|
+
|
44
|
+
|
45
|
+
a1 = ChipLover(
|
46
|
+
name="Alice",
|
47
|
+
chip_values={"Green": 7, "Blue": 1, "Red": 0},
|
48
|
+
model=Model("gemini-1.5-flash"),
|
49
|
+
initial_chips={"Green": 1, "Blue": 2, "Red": 3},
|
50
|
+
)
|
51
|
+
a2 = ChipLover(
|
52
|
+
name="Bob",
|
53
|
+
chip_values={"Green": 7, "Blue": 1, "Red": 0},
|
54
|
+
initial_chips={"Green": 1, "Blue": 2, "Red": 3},
|
55
|
+
)
|
56
|
+
|
57
|
+
c1 = Conversation(agent_list=AgentList([a1, a2]), max_turns=10, verbose=True)
|
58
|
+
c2 = Conversation(agent_list=AgentList([a1, a2]), max_turns=10, verbose=True)
|
59
|
+
|
60
|
+
with Cache() as c:
|
61
|
+
combo = ConversationList([c1, c2], cache=c)
|
62
|
+
combo.run()
|
63
|
+
results = combo.to_results()
|
64
|
+
results.select("conversation_index", "index", "agent_name", "dialogue").print(
|
65
|
+
format="rich"
|
66
|
+
)
|
67
|
+
|
68
|
+
q = QuestionFreeText(
|
69
|
+
question_text="""This was a conversation/negotiation: {{ transcript }}.
|
70
|
+
What trades occurred in the conversation?
|
71
|
+
""",
|
72
|
+
question_name="trades",
|
73
|
+
)
|
74
|
+
|
75
|
+
q_actors = QuestionList(
|
76
|
+
question_text="""Here is a transcript: {{ transcript }}.
|
77
|
+
Who were the actors in the conversation?
|
78
|
+
""",
|
79
|
+
question_name="actors",
|
80
|
+
)
|
81
|
+
|
82
|
+
from edsl import QuestionList
|
83
|
+
|
84
|
+
q_transfers = QuestionList(
|
85
|
+
question_text="""This was a conversation/negotiation: {{ transcript }}.
|
86
|
+
Extract all offers and their outcomes.
|
87
|
+
Use this format: {'proposing_agent':"Alice": 'receiving_agent': "Bob", 'gives':{"Green": 1, "Blue": 2}, 'receives':{"Green": 2, "Blue": 1}, 'accepted':True}
|
88
|
+
""",
|
89
|
+
question_name="transfers",
|
90
|
+
)
|
91
|
+
|
92
|
+
transcript_analysis = (
|
93
|
+
q.add_question(q_actors).add_question(q_transfers).by(combo.summarize()).run()
|
94
|
+
)
|
95
|
+
transcript_analysis.select("trades", "actors", "transfers").print(format="rich")
|