edsl 0.1.36.dev5__py3-none-any.whl → 0.1.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. edsl/__init__.py +1 -0
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Agent.py +92 -41
  4. edsl/agents/AgentList.py +15 -2
  5. edsl/agents/InvigilatorBase.py +15 -25
  6. edsl/agents/PromptConstructor.py +149 -108
  7. edsl/agents/descriptors.py +17 -4
  8. edsl/conjure/AgentConstructionMixin.py +11 -3
  9. edsl/conversation/Conversation.py +66 -14
  10. edsl/conversation/chips.py +95 -0
  11. edsl/coop/coop.py +148 -39
  12. edsl/data/Cache.py +1 -1
  13. edsl/data/RemoteCacheSync.py +25 -12
  14. edsl/exceptions/BaseException.py +21 -0
  15. edsl/exceptions/__init__.py +7 -3
  16. edsl/exceptions/agents.py +17 -19
  17. edsl/exceptions/results.py +11 -8
  18. edsl/exceptions/scenarios.py +22 -0
  19. edsl/exceptions/surveys.py +13 -10
  20. edsl/inference_services/AwsBedrock.py +7 -2
  21. edsl/inference_services/InferenceServicesCollection.py +42 -13
  22. edsl/inference_services/models_available_cache.py +25 -1
  23. edsl/jobs/Jobs.py +306 -71
  24. edsl/jobs/interviews/Interview.py +24 -14
  25. edsl/jobs/interviews/InterviewExceptionCollection.py +1 -1
  26. edsl/jobs/interviews/InterviewExceptionEntry.py +17 -13
  27. edsl/jobs/interviews/ReportErrors.py +2 -2
  28. edsl/jobs/runners/JobsRunnerAsyncio.py +10 -9
  29. edsl/jobs/tasks/TaskHistory.py +1 -0
  30. edsl/language_models/KeyLookup.py +30 -0
  31. edsl/language_models/LanguageModel.py +47 -59
  32. edsl/language_models/__init__.py +1 -0
  33. edsl/prompts/Prompt.py +11 -12
  34. edsl/questions/QuestionBase.py +53 -13
  35. edsl/questions/QuestionBasePromptsMixin.py +1 -33
  36. edsl/questions/QuestionFreeText.py +1 -0
  37. edsl/questions/QuestionFunctional.py +2 -2
  38. edsl/questions/descriptors.py +23 -28
  39. edsl/results/DatasetExportMixin.py +25 -1
  40. edsl/results/Result.py +27 -10
  41. edsl/results/Results.py +34 -121
  42. edsl/results/ResultsDBMixin.py +1 -1
  43. edsl/results/Selector.py +18 -1
  44. edsl/scenarios/FileStore.py +20 -5
  45. edsl/scenarios/Scenario.py +52 -13
  46. edsl/scenarios/ScenarioHtmlMixin.py +7 -2
  47. edsl/scenarios/ScenarioList.py +12 -1
  48. edsl/scenarios/__init__.py +2 -0
  49. edsl/surveys/Rule.py +10 -4
  50. edsl/surveys/Survey.py +100 -77
  51. edsl/utilities/utilities.py +18 -0
  52. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/METADATA +1 -1
  53. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/RECORD +55 -51
  54. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/LICENSE +0 -0
  55. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/WHEEL +0 -0
@@ -1,17 +1,27 @@
1
1
  from __future__ import annotations
2
2
  from typing import Dict, Any, Optional, Set
3
- from collections import UserList
4
- import pdb
5
3
 
6
4
  from jinja2 import Environment, meta
7
5
 
8
6
  from edsl.prompts.Prompt import Prompt
9
- from edsl.data_transfer_models import ImageInfo
7
+ from edsl.agents.prompt_helpers import PromptPlan
10
8
 
11
- # from edsl.prompts.registry import get_classes as prompt_lookup
12
- from edsl.exceptions import QuestionScenarioRenderError
13
9
 
14
- from edsl.agents.prompt_helpers import PromptComponent, PromptList, PromptPlan
10
+ class PlaceholderAnswer:
11
+ """A placeholder answer for when a question is not yet answered."""
12
+
13
+ def __init__(self):
14
+ self.answer = "N/A"
15
+ self.comment = "Will be populated by prior answer"
16
+
17
+ def __getitem__(self, index):
18
+ return ""
19
+
20
+ def __str__(self):
21
+ return "<<PlaceholderAnswer>>"
22
+
23
+ def __repr__(self):
24
+ return "<<PlaceholderAnswer>>"
15
25
 
16
26
 
17
27
  def get_jinja2_variables(template_str: str) -> Set[str]:
@@ -40,7 +50,7 @@ class PromptConstructor:
40
50
  This is mixed into the Invigilator class.
41
51
  """
42
52
 
43
- def __init__(self, invigilator):
53
+ def __init__(self, invigilator, prompt_plan: Optional["PromptPlan"] = None):
44
54
  self.invigilator = invigilator
45
55
  self.agent = invigilator.agent
46
56
  self.question = invigilator.question
@@ -49,7 +59,7 @@ class PromptConstructor:
49
59
  self.model = invigilator.model
50
60
  self.current_answers = invigilator.current_answers
51
61
  self.memory_plan = invigilator.memory_plan
52
- self.prompt_plan = PromptPlan()
62
+ self.prompt_plan = prompt_plan or PromptPlan()
53
63
 
54
64
  @property
55
65
  def scenario_file_keys(self) -> list:
@@ -95,15 +105,20 @@ class PromptConstructor:
95
105
  return self.agent.prompt()
96
106
 
97
107
  def prior_answers_dict(self) -> dict:
108
+ # this is all questions
98
109
  d = self.survey.question_names_to_questions()
99
110
  # This attaches the answer to the question
100
- for question, answer in self.current_answers.items():
101
- if question in d:
102
- d[question].answer = answer
111
+ for question in d:
112
+ if question in self.current_answers:
113
+ d[question].answer = self.current_answers[question]
103
114
  else:
104
- # adds a comment to the question
105
- if (new_question := question.split("_comment")[0]) in d:
106
- d[new_question].comment = answer
115
+ d[question].answer = PlaceholderAnswer()
116
+
117
+ # if (new_question := question.split("_comment")[0]) in d:
118
+ # d[new_question].comment = answer
119
+ # d[question].answer = PlaceholderAnswer()
120
+
121
+ # breakpoint()
107
122
  return d
108
123
 
109
124
  @property
@@ -116,6 +131,123 @@ class PromptConstructor:
116
131
  question_file_keys.append(var)
117
132
  return question_file_keys
118
133
 
134
+ def build_replacement_dict(self, question_data: dict):
135
+ """
136
+ Builds a dictionary of replacement values by combining multiple data sources.
137
+ """
138
+ # File references dictionary
139
+ file_refs = {key: f"<see file {key}>" for key in self.scenario_file_keys}
140
+
141
+ # Scenario items excluding file keys
142
+ scenario_items = {
143
+ k: v for k, v in self.scenario.items() if k not in self.scenario_file_keys
144
+ }
145
+
146
+ # Question settings with defaults
147
+ question_settings = {
148
+ "use_code": getattr(self.question, "_use_code", True),
149
+ "include_comment": getattr(self.question, "_include_comment", False),
150
+ }
151
+
152
+ # Combine all dictionaries using dict.update() for clarity
153
+ replacement_dict = {}
154
+ for d in [
155
+ file_refs,
156
+ question_data,
157
+ scenario_items,
158
+ self.prior_answers_dict(),
159
+ {"agent": self.agent},
160
+ question_settings,
161
+ ]:
162
+ replacement_dict.update(d)
163
+
164
+ return replacement_dict
165
+
166
+ def _get_question_options(self, question_data):
167
+ question_options_entry = question_data.get("question_options", None)
168
+ question_options = question_options_entry
169
+
170
+ placeholder = ["<< Option 1 - Placholder >>", "<< Option 2 - Placholder >>"]
171
+
172
+ if isinstance(question_options_entry, str):
173
+ env = Environment()
174
+ parsed_content = env.parse(question_options_entry)
175
+ question_option_key = list(meta.find_undeclared_variables(parsed_content))[
176
+ 0
177
+ ]
178
+ if isinstance(self.scenario.get(question_option_key), list):
179
+ question_options = self.scenario.get(question_option_key)
180
+
181
+ # might be getting it from the prior answers
182
+ if self.prior_answers_dict().get(question_option_key) is not None:
183
+ prior_question = self.prior_answers_dict().get(question_option_key)
184
+ if hasattr(prior_question, "answer"):
185
+ if isinstance(prior_question.answer, list):
186
+ question_options = prior_question.answer
187
+ else:
188
+ question_options = placeholder
189
+ else:
190
+ question_options = placeholder
191
+
192
+ return question_options
193
+
194
+ def build_question_instructions_prompt(self):
195
+ """Buils the question instructions prompt."""
196
+
197
+ question_prompt = Prompt(self.question.get_instructions(model=self.model.model))
198
+
199
+ # Get the data for the question - this is a dictionary of the question data
200
+ # e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
201
+ question_data = self.question.data.copy()
202
+
203
+ if "question_options" in question_data:
204
+ question_options = self._get_question_options(question_data)
205
+ question_data["question_options"] = question_options
206
+
207
+ # check to see if the question_options is actually a string
208
+ # This is used when the user is using the question_options as a variable from a scenario
209
+ # if "question_options" in question_data:
210
+ replacement_dict = self.build_replacement_dict(question_data)
211
+ rendered_instructions = question_prompt.render(replacement_dict)
212
+
213
+ # is there anything left to render?
214
+ undefined_template_variables = (
215
+ rendered_instructions.undefined_template_variables({})
216
+ )
217
+
218
+ # Check if it's the name of a question in the survey
219
+ for question_name in self.survey.question_names:
220
+ if question_name in undefined_template_variables:
221
+ print(
222
+ "Question name found in undefined_template_variables: ",
223
+ question_name,
224
+ )
225
+
226
+ if undefined_template_variables:
227
+ msg = f"Question instructions still has variables: {undefined_template_variables}."
228
+ import warnings
229
+
230
+ warnings.warn(msg)
231
+ # raise QuestionScenarioRenderError(
232
+ # f"Question instructions still has variables: {undefined_template_variables}."
233
+ # )
234
+
235
+ # Check if question has instructions - these are instructions in a Survey that can apply to multiple follow-on questions
236
+ relevant_instructions = self.survey.relevant_instructions(
237
+ self.question.question_name
238
+ )
239
+
240
+ if relevant_instructions != []:
241
+ # preamble_text = Prompt(
242
+ # text="You were given the following instructions: "
243
+ # )
244
+ preamble_text = Prompt(text="")
245
+ for instruction in relevant_instructions:
246
+ preamble_text += instruction.text
247
+ rendered_instructions = preamble_text + rendered_instructions
248
+
249
+ return rendered_instructions
250
+
119
251
  @property
120
252
  def question_instructions_prompt(self) -> Prompt:
121
253
  """
@@ -125,102 +257,11 @@ class PromptConstructor:
125
257
  Prompt(text=\"""...
126
258
  ...
127
259
  """
128
- # The user might have passed a custom prompt, which would be stored in _question_instructions_prompt
129
260
  if not hasattr(self, "_question_instructions_prompt"):
130
- # Gets the instructions for the question - this is how the question should be answered
131
- question_prompt = self.question.get_instructions(model=self.model.model)
132
-
133
- # Get the data for the question - this is a dictionary of the question data
134
- # e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
135
- question_data = self.question.data.copy()
136
-
137
- # check to see if the question_options is actually a string
138
- # This is used when the user is using the question_options as a variable from a scenario
139
- # if "question_options" in question_data:
140
- if isinstance(self.question.data.get("question_options", None), str):
141
- env = Environment()
142
- parsed_content = env.parse(self.question.data["question_options"])
143
- question_option_key = list(
144
- meta.find_undeclared_variables(parsed_content)
145
- )[0]
146
-
147
- # look to see if the question_option_key is in the scenario
148
- if isinstance(
149
- question_options := self.scenario.get(question_option_key), list
150
- ):
151
- question_data["question_options"] = question_options
152
- self.question.question_options = question_options
153
-
154
- # might be getting it from the prior answers
155
- if self.prior_answers_dict().get(question_option_key) is not None:
156
- if isinstance(
157
- question_options := self.prior_answers_dict()
158
- .get(question_option_key)
159
- .answer,
160
- list,
161
- ):
162
- question_data["question_options"] = question_options
163
- self.question.question_options = question_options
164
-
165
- replacement_dict = (
166
- {key: f"<see file {key}>" for key in self.scenario_file_keys}
167
- | question_data
168
- | {
169
- k: v
170
- for k, v in self.scenario.items()
171
- if k not in self.scenario_file_keys
172
- } # don't include images in the replacement dict
173
- | self.prior_answers_dict()
174
- | {"agent": self.agent}
175
- | {
176
- "use_code": getattr(self.question, "_use_code", True),
177
- "include_comment": getattr(
178
- self.question, "_include_comment", False
179
- ),
180
- }
181
- )
182
-
183
- rendered_instructions = question_prompt.render(replacement_dict)
184
-
185
- # is there anything left to render?
186
- undefined_template_variables = (
187
- rendered_instructions.undefined_template_variables({})
261
+ self._question_instructions_prompt = (
262
+ self.build_question_instructions_prompt()
188
263
  )
189
264
 
190
- # Check if it's the name of a question in the survey
191
- for question_name in self.survey.question_names:
192
- if question_name in undefined_template_variables:
193
- print(
194
- "Question name found in undefined_template_variables: ",
195
- question_name,
196
- )
197
-
198
- if undefined_template_variables:
199
- msg = f"Question instructions still has variables: {undefined_template_variables}."
200
- import warnings
201
-
202
- warnings.warn(msg)
203
- # raise QuestionScenarioRenderError(
204
- # f"Question instructions still has variables: {undefined_template_variables}."
205
- # )
206
-
207
- ####################################
208
- # Check if question has instructions - these are instructions in a Survey that can apply to multiple follow-on questions
209
- ####################################
210
- relevant_instructions = self.survey.relevant_instructions(
211
- self.question.question_name
212
- )
213
-
214
- if relevant_instructions != []:
215
- # preamble_text = Prompt(
216
- # text="You were given the following instructions: "
217
- # )
218
- preamble_text = Prompt(text="")
219
- for instruction in relevant_instructions:
220
- preamble_text += instruction.text
221
- rendered_instructions = preamble_text + rendered_instructions
222
-
223
- self._question_instructions_prompt = rendered_instructions
224
265
  return self._question_instructions_prompt
225
266
 
226
267
  @property
@@ -285,7 +326,7 @@ class PromptConstructor:
285
326
  prompts = self.prompt_plan.get_prompts(
286
327
  agent_instructions=self.agent_instructions_prompt,
287
328
  agent_persona=self.agent_persona_prompt,
288
- question_instructions=self.question_instructions_prompt,
329
+ question_instructions=Prompt(self.question_instructions_prompt),
289
330
  prior_question_memory=self.prior_question_memory_prompt,
290
331
  )
291
332
  if self.question_file_keys:
@@ -4,6 +4,20 @@ from typing import Dict
4
4
  from edsl.exceptions.agents import AgentNameError, AgentTraitKeyError
5
5
 
6
6
 
7
+ def convert_agent_name(x):
8
+ # potentially a numpy int64
9
+ import numpy as np
10
+
11
+ if isinstance(x, np.int64):
12
+ return int(x)
13
+ elif x is None:
14
+ return None
15
+ elif isinstance(x, int):
16
+ return x
17
+ else:
18
+ return str(x)
19
+
20
+
7
21
  class NameDescriptor:
8
22
  """ABC for something."""
9
23
 
@@ -13,7 +27,7 @@ class NameDescriptor:
13
27
 
14
28
  def __set__(self, instance, name: str) -> None:
15
29
  """Set the value of the attribute."""
16
- instance.__dict__[self.name] = name
30
+ instance.__dict__[self.name] = convert_agent_name(name)
17
31
 
18
32
  def __set_name__(self, owner, name: str) -> None:
19
33
  """Set the name of the attribute."""
@@ -34,9 +48,8 @@ class TraitsDescriptor:
34
48
  for key, value in traits_dict.items():
35
49
  if key == "name":
36
50
  raise AgentNameError(
37
- """Trait keys cannot be 'name'. Instead, use the 'name' attribute directly e.g.,
38
- Agent(name="my_agent", traits={"trait1": "value1", "trait2": "value2"})
39
- """
51
+ "Trait keys cannot be 'name'. Instead, use the 'name' attribute directly e.g.,\n"
52
+ 'Agent(name="my_agent", traits={"trait1": "value1", "trait2": "value2"})'
40
53
  )
41
54
 
42
55
  if not is_valid_variable_name(key):
@@ -99,6 +99,8 @@ class AgentConstructionMixin:
99
99
  sample_size: int = None,
100
100
  seed: str = "edsl",
101
101
  dryrun=False,
102
+ disable_remote_cache: bool = False,
103
+ disable_remote_inference: bool = False,
102
104
  ) -> Union[Results, None]:
103
105
  """Return the results of the survey.
104
106
 
@@ -109,7 +111,7 @@ class AgentConstructionMixin:
109
111
 
110
112
  >>> from edsl.conjure.InputData import InputDataABC
111
113
  >>> id = InputDataABC.example()
112
- >>> r = id.to_results()
114
+ >>> r = id.to_results(disable_remote_cache = True, disable_remote_inference = True)
113
115
  >>> len(r) == id.num_observations
114
116
  True
115
117
  """
@@ -125,7 +127,10 @@ class AgentConstructionMixin:
125
127
  import time
126
128
 
127
129
  start = time.time()
128
- _ = survey.by(agent_list.sample(DRYRUN_SAMPLE)).run()
130
+ _ = survey.by(agent_list.sample(DRYRUN_SAMPLE)).run(
131
+ disable_remote_cache=disable_remote_cache,
132
+ disable_remote_inference=disable_remote_inference,
133
+ )
129
134
  end = time.time()
130
135
  print(f"Time to run {DRYRUN_SAMPLE} agents (s): {round(end - start, 2)}")
131
136
  time_per_agent = (end - start) / DRYRUN_SAMPLE
@@ -143,7 +148,10 @@ class AgentConstructionMixin:
143
148
  f"Full sample will take about {round(full_sample_time / 3600, 2)} hours."
144
149
  )
145
150
  return None
146
- return survey.by(agent_list).run()
151
+ return survey.by(agent_list).run(
152
+ disable_remote_cache=disable_remote_cache,
153
+ disable_remote_inference=disable_remote_inference,
154
+ )
147
155
 
148
156
 
149
157
  if __name__ == "__main__":
@@ -5,7 +5,7 @@ from typing import Optional, Callable
5
5
  from edsl import Agent, QuestionFreeText, Results, AgentList, ScenarioList, Scenario
6
6
  from edsl.questions import QuestionBase
7
7
  from edsl.results.Result import Result
8
-
8
+ from jinja2 import Template
9
9
  from edsl.data import Cache
10
10
 
11
11
  from edsl.conversation.next_speaker_utilities import (
@@ -54,6 +54,9 @@ class Conversation:
54
54
  """A conversation between a list of agents. The first agent in the list is the first speaker.
55
55
  After that, order is determined by the next_speaker function.
56
56
  The question asked to each agent is determined by the next_statement_question.
57
+
58
+ If the user has passed in a "per_round_message_template", this will be displayed at the beginning of each round.
59
+ {{ round_message }} must be in the question_text.
57
60
  """
58
61
 
59
62
  def __init__(
@@ -64,28 +67,62 @@ class Conversation:
64
67
  next_statement_question: Optional[QuestionBase] = None,
65
68
  next_speaker_generator: Optional[Callable] = None,
66
69
  verbose: bool = False,
70
+ per_round_message_template: Optional[str] = None,
67
71
  conversation_index: Optional[int] = None,
68
72
  cache=None,
73
+ disable_remote_inference=False,
74
+ default_model: Optional["LanguageModel"] = None,
69
75
  ):
76
+ self.disable_remote_inference = disable_remote_inference
77
+ self.per_round_message_template = per_round_message_template
78
+
70
79
  if cache is None:
71
80
  self.cache = Cache()
72
81
  else:
73
82
  self.cache = cache
74
83
 
75
84
  self.agent_list = agent_list
85
+
86
+ from edsl import Model
87
+
88
+ for agent in self.agent_list:
89
+ if not hasattr(agent, "model"):
90
+ if default_model is not None:
91
+ agent.model = default_model
92
+ else:
93
+ agent.model = Model()
94
+
76
95
  self.verbose = verbose
77
96
  self.agent_statements = []
78
97
  self._conversation_index = conversation_index
79
-
80
98
  self.agent_statements = AgentStatements()
81
99
 
82
100
  self.max_turns = max_turns
83
101
 
84
102
  if next_statement_question is None:
103
+ import textwrap
104
+
105
+ base_question = textwrap.dedent(
106
+ """\
107
+ You are {{ agent_name }}. This is the conversation so far: {{ conversation }}
108
+ {% if round_message is not none %}
109
+ {{ round_message }}
110
+ {% endif %}
111
+ What do you say next?"""
112
+ )
85
113
  self.next_statement_question = QuestionFreeText(
86
- question_text="You are {{ agent_name }}. This is the converstaion so far: {{ conversation }}. What do you say next?",
114
+ question_text=base_question,
87
115
  question_name="dialogue",
88
116
  )
117
+ else:
118
+ self.next_statement_question = next_statement_question
119
+ if (
120
+ per_round_message_template
121
+ and "{{ round_message }}" not in next_statement_question.question_text
122
+ ):
123
+ raise ValueError(
124
+ "If you pass in a per_round_message_template, you must include {{ round_message }} in the question_text."
125
+ )
89
126
 
90
127
  # Determine how the next speaker is chosen
91
128
  if next_speaker_generator is None:
@@ -93,6 +130,7 @@ class Conversation:
93
130
  else:
94
131
  func = next_speaker_generator
95
132
 
133
+ # Choose the next speaker
96
134
  self.next_speaker = speaker_closure(
97
135
  agent_list=self.agent_list, generator_function=func
98
136
  )
@@ -158,17 +196,32 @@ class Conversation:
158
196
  }
159
197
  return Scenario(d)
160
198
 
161
- async def get_next_statement(self, *, index, speaker, conversation):
199
+ async def get_next_statement(self, *, index, speaker, conversation) -> "Result":
200
+ """Get the next statement from the speaker."""
162
201
  q = self.next_statement_question
163
- assert q.parameters == {"agent_name", "conversation"}, q.parameters
164
- results = await q.run_async(
165
- index=index,
166
- conversation=conversation,
167
- conversation_index=self.conversation_index,
168
- agent_name=speaker.name,
169
- agent=speaker,
170
- just_answer=False,
171
- cache=self.cache,
202
+ # assert q.parameters == {"agent_name", "conversation"}, q.parameters
203
+ from edsl import Scenario
204
+
205
+ if self.per_round_message_template is None:
206
+ round_message = None
207
+ else:
208
+ round_message = Template(self.per_round_message_template).render(
209
+ {"max_turns": self.max_turns, "current_turn": index}
210
+ )
211
+
212
+ s = Scenario(
213
+ {
214
+ "agent_name": speaker.name,
215
+ "conversation": conversation,
216
+ "conversation_index": self.conversation_index,
217
+ "index": index,
218
+ "round_message": round_message,
219
+ }
220
+ )
221
+ jobs = q.by(s).by(speaker).by(speaker.model)
222
+ jobs.show_prompts()
223
+ results = await jobs.run_async(
224
+ cache=self.cache, disable_remote_inference=self.disable_remote_inference
172
225
  )
173
226
  return results[0]
174
227
 
@@ -179,7 +232,6 @@ class Conversation:
179
232
  i = 0
180
233
  while await self.continue_conversation():
181
234
  speaker = self.next_speaker()
182
- # breakpoint()
183
235
 
184
236
  next_statement = AgentStatement(
185
237
  statement=await self.get_next_statement(
@@ -0,0 +1,95 @@
1
+ from typing import Optional
2
+
3
+ from edsl import Agent, AgentList, QuestionFreeText
4
+ from edsl import Cache
5
+ from edsl import QuestionList
6
+ from edsl import Model
7
+
8
+ from edsl.conversation.Conversation import Conversation, ConversationList
9
+
10
+ m = Model("gemini-1.5-flash")
11
+
12
+
13
+ class ChipLover(Agent):
14
+ def __init__(self, name, chip_values, initial_chips, model: Optional[Model] = None):
15
+ self.chip_values = chip_values
16
+ self.initial_chips = initial_chips
17
+ self.current_chips = initial_chips
18
+ self.model = model or Model()
19
+ super().__init__(
20
+ name=name,
21
+ traits={
22
+ "motivation": f"""
23
+ You are {name}. You are negotiating the trading of colored 'chips' with other players. You want to maximize your score.
24
+ When you want to accept a deal, say "DEAL!"
25
+ Note that different players can have different values for the chips.
26
+ """,
27
+ "chip_values": chip_values,
28
+ "initial_chips": initial_chips,
29
+ },
30
+ )
31
+
32
+ def trade(self, chips_given_dict, chips_received_dict):
33
+ for color, amount in chips_given_dict.items():
34
+ self.current_chips[color] -= amount
35
+ for color, amount in chips_received_dict.items():
36
+ self.current_chips[color] += amount
37
+
38
+ def get_score(self):
39
+ return sum(
40
+ self.chip_values[color] * self.current_chips[color]
41
+ for color in self.chip_values
42
+ )
43
+
44
+
45
+ a1 = ChipLover(
46
+ name="Alice",
47
+ chip_values={"Green": 7, "Blue": 1, "Red": 0},
48
+ model=Model("gemini-1.5-flash"),
49
+ initial_chips={"Green": 1, "Blue": 2, "Red": 3},
50
+ )
51
+ a2 = ChipLover(
52
+ name="Bob",
53
+ chip_values={"Green": 7, "Blue": 1, "Red": 0},
54
+ initial_chips={"Green": 1, "Blue": 2, "Red": 3},
55
+ )
56
+
57
+ c1 = Conversation(agent_list=AgentList([a1, a2]), max_turns=10, verbose=True)
58
+ c2 = Conversation(agent_list=AgentList([a1, a2]), max_turns=10, verbose=True)
59
+
60
+ with Cache() as c:
61
+ combo = ConversationList([c1, c2], cache=c)
62
+ combo.run()
63
+ results = combo.to_results()
64
+ results.select("conversation_index", "index", "agent_name", "dialogue").print(
65
+ format="rich"
66
+ )
67
+
68
+ q = QuestionFreeText(
69
+ question_text="""This was a conversation/negotiation: {{ transcript }}.
70
+ What trades occurred in the conversation?
71
+ """,
72
+ question_name="trades",
73
+ )
74
+
75
+ q_actors = QuestionList(
76
+ question_text="""Here is a transcript: {{ transcript }}.
77
+ Who were the actors in the conversation?
78
+ """,
79
+ question_name="actors",
80
+ )
81
+
82
+ from edsl import QuestionList
83
+
84
+ q_transfers = QuestionList(
85
+ question_text="""This was a conversation/negotiation: {{ transcript }}.
86
+ Extract all offers and their outcomes.
87
+ Use this format: {'proposing_agent':"Alice": 'receiving_agent': "Bob", 'gives':{"Green": 1, "Blue": 2}, 'receives':{"Green": 2, "Blue": 1}, 'accepted':True}
88
+ """,
89
+ question_name="transfers",
90
+ )
91
+
92
+ transcript_analysis = (
93
+ q.add_question(q_actors).add_question(q_transfers).by(combo.summarize()).run()
94
+ )
95
+ transcript_analysis.select("trades", "actors", "transfers").print(format="rich")