edsl 0.1.33.dev3__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. edsl/Base.py +15 -11
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Invigilator.py +22 -3
  4. edsl/agents/PromptConstructor.py +80 -184
  5. edsl/agents/prompt_helpers.py +129 -0
  6. edsl/coop/coop.py +3 -2
  7. edsl/data_transfer_models.py +0 -1
  8. edsl/inference_services/AnthropicService.py +5 -2
  9. edsl/inference_services/AwsBedrock.py +5 -2
  10. edsl/inference_services/AzureAI.py +5 -2
  11. edsl/inference_services/GoogleService.py +108 -33
  12. edsl/inference_services/MistralAIService.py +5 -2
  13. edsl/inference_services/OpenAIService.py +3 -2
  14. edsl/inference_services/TestService.py +11 -2
  15. edsl/inference_services/TogetherAIService.py +1 -1
  16. edsl/jobs/Jobs.py +91 -10
  17. edsl/jobs/interviews/Interview.py +15 -2
  18. edsl/jobs/runners/JobsRunnerAsyncio.py +46 -25
  19. edsl/jobs/runners/JobsRunnerStatus.py +4 -3
  20. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  21. edsl/language_models/LanguageModel.py +12 -9
  22. edsl/language_models/utilities.py +5 -2
  23. edsl/questions/QuestionBase.py +13 -3
  24. edsl/questions/QuestionBaseGenMixin.py +28 -0
  25. edsl/questions/QuestionCheckBox.py +1 -1
  26. edsl/questions/QuestionMultipleChoice.py +8 -4
  27. edsl/questions/ResponseValidatorABC.py +5 -1
  28. edsl/questions/descriptors.py +12 -11
  29. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  30. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  31. edsl/scenarios/FileStore.py +159 -76
  32. edsl/scenarios/Scenario.py +23 -49
  33. edsl/scenarios/ScenarioList.py +6 -2
  34. edsl/surveys/DAG.py +62 -0
  35. edsl/surveys/MemoryPlan.py +26 -0
  36. edsl/surveys/Rule.py +24 -0
  37. edsl/surveys/RuleCollection.py +36 -2
  38. edsl/surveys/Survey.py +182 -10
  39. edsl/surveys/base.py +4 -0
  40. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/METADATA +2 -1
  41. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/RECORD +43 -43
  42. edsl/scenarios/ScenarioImageMixin.py +0 -100
  43. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  44. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/Base.py CHANGED
@@ -115,23 +115,27 @@ class PersistenceMixin:
115
115
  if filename.endswith("json.gz"):
116
116
  import warnings
117
117
 
118
- warnings.warn(
119
- "Do not apply the file extensions. The filename should not end with 'json.gz'."
120
- )
118
+ # warnings.warn(
119
+ # "Do not apply the file extensions. The filename should not end with 'json.gz'."
120
+ # )
121
121
  filename = filename[:-7]
122
122
  if filename.endswith("json"):
123
123
  filename = filename[:-4]
124
- warnings.warn(
125
- "Do not apply the file extensions. The filename should not end with 'json'."
126
- )
124
+ # warnings.warn(
125
+ # "Do not apply the file extensions. The filename should not end with 'json'."
126
+ # )
127
127
 
128
128
  if compress:
129
- with gzip.open(filename + ".json.gz", "wb") as f:
129
+ full_file_name = filename + ".json.gz"
130
+ with gzip.open(full_file_name, "wb") as f:
130
131
  f.write(json.dumps(self.to_dict()).encode("utf-8"))
131
132
  else:
133
+ full_file_name = filename + ".json"
132
134
  with open(filename + ".json", "w") as f:
133
135
  f.write(json.dumps(self.to_dict()))
134
136
 
137
+ print("Saved to", full_file_name)
138
+
135
139
  @staticmethod
136
140
  def open_compressed_file(filename):
137
141
  with gzip.open(filename, "rb") as f:
@@ -160,11 +164,11 @@ class PersistenceMixin:
160
164
  d = cls.open_regular_file(filename)
161
165
  else:
162
166
  try:
163
- d = cls.open_compressed_file(filename)
167
+ d = cls.open_compressed_file(filename + ".json.gz")
164
168
  except:
165
- d = cls.open_regular_file(filename)
166
- finally:
167
- raise ValueError("File must be a json or json.gz file")
169
+ d = cls.open_regular_file(filename + ".json")
170
+ # finally:
171
+ # raise ValueError("File must be a json or json.gz file")
168
172
 
169
173
  return cls.from_dict(d)
170
174
 
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.33.dev3"
1
+ __version__ = "0.1.34"
@@ -39,6 +39,8 @@ class InvigilatorAI(InvigilatorBase):
39
39
  }
40
40
  if "encoded_image" in prompts:
41
41
  params["encoded_image"] = prompts["encoded_image"]
42
+ if "files_list" in prompts:
43
+ params["files_list"] = prompts["files_list"]
42
44
 
43
45
  params.update({"iteration": self.iteration, "cache": self.cache})
44
46
 
@@ -80,15 +82,32 @@ class InvigilatorAI(InvigilatorBase):
80
82
  exception_occurred = None
81
83
  validated = False
82
84
  try:
83
- validated_edsl_dict = self.question._validate_answer(edsl_dict)
85
+ # if the question has jinja parameters, it might be easier to make a new question
86
+ # with those all filled in & then validate that
87
+ # breakpoint()
88
+ if self.question.parameters:
89
+ prior_answers_dict = self.prompt_constructor.prior_answers_dict()
90
+ question_with_validators = self.question.render(
91
+ self.scenario | prior_answers_dict
92
+ )
93
+ question_with_validators.use_code = self.question.use_code
94
+ # if question_with_validators.parameters:
95
+ # raise ValueError(
96
+ # f"The question still has parameters after rendering: {question_with_validators}"
97
+ # )
98
+ else:
99
+ question_with_validators = self.question
100
+
101
+ # breakpoint()
102
+ validated_edsl_dict = question_with_validators._validate_answer(edsl_dict)
84
103
  answer = self.determine_answer(validated_edsl_dict["answer"])
85
104
  comment = validated_edsl_dict.get("comment", "")
86
105
  validated = True
87
106
  except QuestionAnswerValidationError as e:
88
107
  answer = None
89
108
  comment = "The response was not valid."
90
- if self.raise_validation_errors:
91
- exception_occurred = e
109
+ # if self.raise_validation_errors:
110
+ exception_occurred = e
92
111
  except Exception as non_validation_error:
93
112
  answer = None
94
113
  comment = "Some other error occurred."
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
  from typing import Dict, Any, Optional, Set
3
3
  from collections import UserList
4
- import enum
4
+ import pdb
5
5
 
6
6
  from jinja2 import Environment, meta
7
7
 
@@ -10,12 +10,7 @@ from edsl.data_transfer_models import ImageInfo
10
10
  from edsl.prompts.registry import get_classes as prompt_lookup
11
11
  from edsl.exceptions import QuestionScenarioRenderError
12
12
 
13
-
14
- class PromptComponent(enum.Enum):
15
- AGENT_INSTRUCTIONS = "agent_instructions"
16
- AGENT_PERSONA = "agent_persona"
17
- QUESTION_INSTRUCTIONS = "question_instructions"
18
- PRIOR_QUESTION_MEMORY = "prior_question_memory"
13
+ from edsl.agents.prompt_helpers import PromptComponent, PromptList, PromptPlan
19
14
 
20
15
 
21
16
  def get_jinja2_variables(template_str: str) -> Set[str]:
@@ -33,127 +28,8 @@ def get_jinja2_variables(template_str: str) -> Set[str]:
33
28
  return meta.find_undeclared_variables(ast)
34
29
 
35
30
 
36
- class PromptList(UserList):
37
- separator = Prompt(" ")
38
-
39
- def reduce(self):
40
- """Reduce the list of prompts to a single prompt.
41
-
42
- >>> p = PromptList([Prompt("You are a happy-go lucky agent."), Prompt("You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}")])
43
- >>> p.reduce()
44
- Prompt(text=\"""You are a happy-go lucky agent. You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}\""")
45
-
46
- """
47
- p = self[0]
48
- for prompt in self[1:]:
49
- if len(prompt) > 0:
50
- p = p + self.separator + prompt
51
- return p
52
-
53
-
54
- class PromptPlan:
55
- """A plan for constructing prompts for the LLM call.
56
- Every prompt plan has a user prompt order and a system prompt order.
57
- It must contain each of the values in the PromptComponent enum.
58
-
59
-
60
- >>> p = PromptPlan(user_prompt_order=(PromptComponent.AGENT_INSTRUCTIONS, PromptComponent.AGENT_PERSONA),system_prompt_order=(PromptComponent.QUESTION_INSTRUCTIONS, PromptComponent.PRIOR_QUESTION_MEMORY))
61
- >>> p._is_valid_plan()
62
- True
63
-
64
- >>> p.arrange_components(agent_instructions=1, agent_persona=2, question_instructions=3, prior_question_memory=4)
65
- {'user_prompt': ..., 'system_prompt': ...}
66
-
67
- >>> p = PromptPlan(user_prompt_order=("agent_instructions", ), system_prompt_order=("question_instructions", "prior_question_memory"))
68
- Traceback (most recent call last):
69
- ...
70
- ValueError: Invalid plan: must contain each value of PromptComponent exactly once.
71
-
72
- """
73
-
74
- def __init__(
75
- self,
76
- user_prompt_order: Optional[tuple] = None,
77
- system_prompt_order: Optional[tuple] = None,
78
- ):
79
- """Initialize the PromptPlan."""
80
-
81
- if user_prompt_order is None:
82
- user_prompt_order = (
83
- PromptComponent.QUESTION_INSTRUCTIONS,
84
- PromptComponent.PRIOR_QUESTION_MEMORY,
85
- )
86
- if system_prompt_order is None:
87
- system_prompt_order = (
88
- PromptComponent.AGENT_INSTRUCTIONS,
89
- PromptComponent.AGENT_PERSONA,
90
- )
91
-
92
- # very commmon way to screw this up given how python treats single strings as iterables
93
- if isinstance(user_prompt_order, str):
94
- user_prompt_order = (user_prompt_order,)
95
-
96
- if isinstance(system_prompt_order, str):
97
- system_prompt_order = (system_prompt_order,)
98
-
99
- if not isinstance(user_prompt_order, tuple):
100
- raise TypeError(
101
- f"Expected a tuple, but got {type(user_prompt_order).__name__}"
102
- )
103
-
104
- if not isinstance(system_prompt_order, tuple):
105
- raise TypeError(
106
- f"Expected a tuple, but got {type(system_prompt_order).__name__}"
107
- )
108
-
109
- self.user_prompt_order = self._convert_to_enum(user_prompt_order)
110
- self.system_prompt_order = self._convert_to_enum(system_prompt_order)
111
- if not self._is_valid_plan():
112
- raise ValueError(
113
- "Invalid plan: must contain each value of PromptComponent exactly once."
114
- )
115
-
116
- def _convert_to_enum(self, prompt_order: tuple):
117
- """Convert string names to PromptComponent enum values."""
118
- return tuple(
119
- PromptComponent(component) if isinstance(component, str) else component
120
- for component in prompt_order
121
- )
122
-
123
- def _is_valid_plan(self):
124
- """Check if the plan is valid."""
125
- combined = self.user_prompt_order + self.system_prompt_order
126
- return set(combined) == set(PromptComponent)
127
-
128
- def arrange_components(self, **kwargs) -> Dict[PromptComponent, Prompt]:
129
- """Arrange the components in the order specified by the plan."""
130
- # check is valid components passed
131
- component_strings = set([pc.value for pc in PromptComponent])
132
- if not set(kwargs.keys()) == component_strings:
133
- raise ValueError(
134
- f"Invalid components passed: {set(kwargs.keys())} but expected {PromptComponent}"
135
- )
136
-
137
- user_prompt = PromptList(
138
- [kwargs[component.value] for component in self.user_prompt_order]
139
- )
140
- system_prompt = PromptList(
141
- [kwargs[component.value] for component in self.system_prompt_order]
142
- )
143
- return {"user_prompt": user_prompt, "system_prompt": system_prompt}
144
-
145
- def get_prompts(self, **kwargs) -> Dict[str, Prompt]:
146
- """Get both prompts for the LLM call."""
147
- prompts = self.arrange_components(**kwargs)
148
- return {
149
- "user_prompt": prompts["user_prompt"].reduce(),
150
- "system_prompt": prompts["system_prompt"].reduce(),
151
- }
152
-
153
-
154
31
  class PromptConstructor:
155
- """Mixin for constructing prompts for the LLM call.
156
-
32
+ """
157
33
  The pieces of a prompt are:
158
34
  - The agent instructions - "You are answering questions as if you were a human. Do not break character."
159
35
  - The persona prompt - "You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}"
@@ -172,18 +48,20 @@ class PromptConstructor:
172
48
  self.model = invigilator.model
173
49
  self.current_answers = invigilator.current_answers
174
50
  self.memory_plan = invigilator.memory_plan
175
- self.prompt_plan = PromptPlan() # Assuming PromptPlan is defined elsewhere
176
-
177
- # prompt_plan = PromptPlan()
51
+ self.prompt_plan = PromptPlan()
178
52
 
179
53
  @property
180
- def scenario_image_keys(self):
181
- image_entries = []
54
+ def scenario_file_keys(self) -> list:
55
+ """We need to find all the keys in the scenario that refer to FileStore objects.
56
+ These will be used to append to the prompt a list of files that are part of the scenario.
57
+ """
58
+ from edsl.scenarios.FileStore import FileStore
182
59
 
60
+ file_entries = []
183
61
  for key, value in self.scenario.items():
184
- if isinstance(value, ImageInfo):
185
- image_entries.append(key)
186
- return image_entries
62
+ if isinstance(value, FileStore):
63
+ file_entries.append(key)
64
+ return file_entries
187
65
 
188
66
  @property
189
67
  def agent_instructions_prompt(self) -> Prompt:
@@ -219,47 +97,51 @@ class PromptConstructor:
219
97
  {'age': 22, 'hair': 'brown', 'height': 5.5}\""")
220
98
 
221
99
  """
222
- if not hasattr(self, "_agent_persona_prompt"):
223
- from edsl import Agent
100
+ from edsl import Agent
224
101
 
225
- if self.agent == Agent(): # if agent is empty, then return an empty prompt
226
- return Prompt(text="")
102
+ if hasattr(self, "_agent_persona_prompt"):
103
+ return self._agent_persona_prompt
227
104
 
228
- if not hasattr(self.agent, "agent_persona"):
229
- applicable_prompts = prompt_lookup(
230
- component_type="agent_persona",
231
- model=self.model.model,
232
- )
233
- persona_prompt_template = applicable_prompts[0]()
234
- else:
235
- persona_prompt_template = self.agent.agent_persona
236
-
237
- # TODO: This multiple passing of agent traits - not sure if it is necessary. Not harmful.
238
- if undefined := persona_prompt_template.undefined_template_variables(
239
- self.agent.traits
240
- | {"traits": self.agent.traits}
241
- | {"codebook": self.agent.codebook}
242
- | {"traits": self.agent.traits}
243
- ):
244
- raise QuestionScenarioRenderError(
245
- f"Agent persona still has variables that were not rendered: {undefined}"
246
- )
105
+ if self.agent == Agent(): # if agent is empty, then return an empty prompt
106
+ return Prompt(text="")
247
107
 
248
- persona_prompt = persona_prompt_template.render(
249
- self.agent.traits | {"traits": self.agent.traits},
250
- codebook=self.agent.codebook,
251
- traits=self.agent.traits,
108
+ if not hasattr(self.agent, "agent_persona"):
109
+ applicable_prompts = prompt_lookup(
110
+ component_type="agent_persona",
111
+ model=self.model.model,
252
112
  )
253
- if persona_prompt.has_variables:
254
- raise QuestionScenarioRenderError(
255
- "Agent persona still has variables that were not rendered."
256
- )
257
- self._agent_persona_prompt = persona_prompt
113
+ persona_prompt_template = applicable_prompts[0]()
114
+ else:
115
+ persona_prompt_template = self.agent.agent_persona
116
+
117
+ # TODO: This multiple passing of agent traits - not sure if it is necessary. Not harmful.
118
+ template_parameter_dictionary = (
119
+ self.agent.traits
120
+ | {"traits": self.agent.traits}
121
+ | {"codebook": self.agent.codebook}
122
+ | {"traits": self.agent.traits}
123
+ )
124
+
125
+ if undefined := persona_prompt_template.undefined_template_variables(
126
+ template_parameter_dictionary
127
+ ):
128
+ raise QuestionScenarioRenderError(
129
+ f"Agent persona still has variables that were not rendered: {undefined}"
130
+ )
131
+
132
+ persona_prompt = persona_prompt_template.render(template_parameter_dictionary)
133
+ if persona_prompt.has_variables:
134
+ raise QuestionScenarioRenderError(
135
+ "Agent persona still has variables that were not rendered."
136
+ )
137
+
138
+ self._agent_persona_prompt = persona_prompt
258
139
 
259
140
  return self._agent_persona_prompt
260
141
 
261
142
  def prior_answers_dict(self) -> dict:
262
143
  d = self.survey.question_names_to_questions()
144
+ # This attaches the answer to the question
263
145
  for question, answer in self.current_answers.items():
264
146
  if question in d:
265
147
  d[question].answer = answer
@@ -270,14 +152,14 @@ class PromptConstructor:
270
152
  return d
271
153
 
272
154
  @property
273
- def question_image_keys(self):
155
+ def question_file_keys(self):
274
156
  raw_question_text = self.question.question_text
275
157
  variables = get_jinja2_variables(raw_question_text)
276
- question_image_keys = []
158
+ question_file_keys = []
277
159
  for var in variables:
278
- if var in self.scenario_image_keys:
279
- question_image_keys.append(var)
280
- return question_image_keys
160
+ if var in self.scenario_file_keys:
161
+ question_file_keys.append(var)
162
+ return question_file_keys
281
163
 
282
164
  @property
283
165
  def question_instructions_prompt(self) -> Prompt:
@@ -288,15 +170,17 @@ class PromptConstructor:
288
170
  Prompt(text=\"""...
289
171
  ...
290
172
  """
173
+ # The user might have passed a custom prompt, which would be stored in _question_instructions_prompt
291
174
  if not hasattr(self, "_question_instructions_prompt"):
175
+ # Gets the instructions for the question - this is how the question should be answered
292
176
  question_prompt = self.question.get_instructions(model=self.model.model)
293
177
 
294
- # Are any of the scenario values ImageInfo
295
-
178
+ # Get the data for the question - this is a dictionary of the question data
179
+ # e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
296
180
  question_data = self.question.data.copy()
297
181
 
298
182
  # check to see if the question_options is actually a string
299
- # This is used when the user is using the question_options as a variable from a sceario
183
+ # This is used when the user is using the question_options as a variable from a scenario
300
184
  # if "question_options" in question_data:
301
185
  if isinstance(self.question.data.get("question_options", None), str):
302
186
  env = Environment()
@@ -305,19 +189,31 @@ class PromptConstructor:
305
189
  meta.find_undeclared_variables(parsed_content)
306
190
  )[0]
307
191
 
192
+ # look to see if the question_option_key is in the scenario
308
193
  if isinstance(
309
194
  question_options := self.scenario.get(question_option_key), list
310
195
  ):
311
196
  question_data["question_options"] = question_options
312
197
  self.question.question_options = question_options
313
198
 
199
+ # might be getting it from the prior answers
200
+ if self.prior_answers_dict().get(question_option_key) is not None:
201
+ if isinstance(
202
+ question_options := self.prior_answers_dict()
203
+ .get(question_option_key)
204
+ .answer,
205
+ list,
206
+ ):
207
+ question_data["question_options"] = question_options
208
+ self.question.question_options = question_options
209
+
314
210
  replacement_dict = (
315
- {key: "<see image>" for key in self.scenario_image_keys}
211
+ {key: f"<see file {key}>" for key in self.scenario_file_keys}
316
212
  | question_data
317
213
  | {
318
214
  k: v
319
215
  for k, v in self.scenario.items()
320
- if k not in self.scenario_image_keys
216
+ if k not in self.scenario_file_keys
321
217
  } # don't include images in the replacement dict
322
218
  | self.prior_answers_dict()
323
219
  | {"agent": self.agent}
@@ -345,6 +241,7 @@ class PromptConstructor:
345
241
  )
346
242
 
347
243
  if undefined_template_variables:
244
+ # breakpoint()
348
245
  raise QuestionScenarioRenderError(
349
246
  f"Question instructions still has variables: {undefined_template_variables}."
350
247
  )
@@ -425,19 +322,18 @@ class PromptConstructor:
425
322
  >>> i.get_prompts()
426
323
  {'user_prompt': ..., 'system_prompt': ...}
427
324
  """
325
+ # breakpoint()
428
326
  prompts = self.prompt_plan.get_prompts(
429
327
  agent_instructions=self.agent_instructions_prompt,
430
328
  agent_persona=self.agent_persona_prompt,
431
329
  question_instructions=self.question_instructions_prompt,
432
330
  prior_question_memory=self.prior_question_memory_prompt,
433
331
  )
434
- if len(self.question_image_keys) > 1:
435
- raise ValueError("We can only handle one image per question.")
436
- elif len(self.question_image_keys) == 1:
437
- prompts["encoded_image"] = self.scenario[
438
- self.question_image_keys[0]
439
- ].encoded_image
440
-
332
+ if self.question_file_keys:
333
+ files_list = []
334
+ for key in self.question_file_keys:
335
+ files_list.append(self.scenario[key])
336
+ prompts["files_list"] = files_list
441
337
  return prompts
442
338
 
443
339
  def _get_scenario_with_image(self) -> Scenario:
@@ -0,0 +1,129 @@
1
+ import enum
2
+ from typing import Dict, Optional
3
+ from collections import UserList
4
+ from edsl.prompts import Prompt
5
+
6
+
7
+ class PromptComponent(enum.Enum):
8
+ AGENT_INSTRUCTIONS = "agent_instructions"
9
+ AGENT_PERSONA = "agent_persona"
10
+ QUESTION_INSTRUCTIONS = "question_instructions"
11
+ PRIOR_QUESTION_MEMORY = "prior_question_memory"
12
+
13
+
14
+ class PromptList(UserList):
15
+ separator = Prompt(" ")
16
+
17
+ def reduce(self):
18
+ """Reduce the list of prompts to a single prompt.
19
+
20
+ >>> p = PromptList([Prompt("You are a happy-go lucky agent."), Prompt("You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}")])
21
+ >>> p.reduce()
22
+ Prompt(text=\"""You are a happy-go lucky agent. You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}\""")
23
+
24
+ """
25
+ p = self[0]
26
+ for prompt in self[1:]:
27
+ if len(prompt) > 0:
28
+ p = p + self.separator + prompt
29
+ return p
30
+
31
+
32
+ class PromptPlan:
33
+ """A plan for constructing prompts for the LLM call.
34
+ Every prompt plan has a user prompt order and a system prompt order.
35
+ It must contain each of the values in the PromptComponent enum.
36
+
37
+
38
+ >>> p = PromptPlan(user_prompt_order=(PromptComponent.AGENT_INSTRUCTIONS, PromptComponent.AGENT_PERSONA),system_prompt_order=(PromptComponent.QUESTION_INSTRUCTIONS, PromptComponent.PRIOR_QUESTION_MEMORY))
39
+ >>> p._is_valid_plan()
40
+ True
41
+
42
+ >>> p.arrange_components(agent_instructions=1, agent_persona=2, question_instructions=3, prior_question_memory=4)
43
+ {'user_prompt': ..., 'system_prompt': ...}
44
+
45
+ >>> p = PromptPlan(user_prompt_order=("agent_instructions", ), system_prompt_order=("question_instructions", "prior_question_memory"))
46
+ Traceback (most recent call last):
47
+ ...
48
+ ValueError: Invalid plan: must contain each value of PromptComponent exactly once.
49
+
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ user_prompt_order: Optional[tuple] = None,
55
+ system_prompt_order: Optional[tuple] = None,
56
+ ):
57
+ """Initialize the PromptPlan."""
58
+
59
+ if user_prompt_order is None:
60
+ user_prompt_order = (
61
+ PromptComponent.QUESTION_INSTRUCTIONS,
62
+ PromptComponent.PRIOR_QUESTION_MEMORY,
63
+ )
64
+ if system_prompt_order is None:
65
+ system_prompt_order = (
66
+ PromptComponent.AGENT_INSTRUCTIONS,
67
+ PromptComponent.AGENT_PERSONA,
68
+ )
69
+
70
+ # very commmon way to screw this up given how python treats single strings as iterables
71
+ if isinstance(user_prompt_order, str):
72
+ user_prompt_order = (user_prompt_order,)
73
+
74
+ if isinstance(system_prompt_order, str):
75
+ system_prompt_order = (system_prompt_order,)
76
+
77
+ if not isinstance(user_prompt_order, tuple):
78
+ raise TypeError(
79
+ f"Expected a tuple, but got {type(user_prompt_order).__name__}"
80
+ )
81
+
82
+ if not isinstance(system_prompt_order, tuple):
83
+ raise TypeError(
84
+ f"Expected a tuple, but got {type(system_prompt_order).__name__}"
85
+ )
86
+
87
+ self.user_prompt_order = self._convert_to_enum(user_prompt_order)
88
+ self.system_prompt_order = self._convert_to_enum(system_prompt_order)
89
+ if not self._is_valid_plan():
90
+ raise ValueError(
91
+ "Invalid plan: must contain each value of PromptComponent exactly once."
92
+ )
93
+
94
+ def _convert_to_enum(self, prompt_order: tuple):
95
+ """Convert string names to PromptComponent enum values."""
96
+ return tuple(
97
+ PromptComponent(component) if isinstance(component, str) else component
98
+ for component in prompt_order
99
+ )
100
+
101
+ def _is_valid_plan(self):
102
+ """Check if the plan is valid."""
103
+ combined = self.user_prompt_order + self.system_prompt_order
104
+ return set(combined) == set(PromptComponent)
105
+
106
+ def arrange_components(self, **kwargs) -> Dict[PromptComponent, Prompt]:
107
+ """Arrange the components in the order specified by the plan."""
108
+ # check is valid components passed
109
+ component_strings = set([pc.value for pc in PromptComponent])
110
+ if not set(kwargs.keys()) == component_strings:
111
+ raise ValueError(
112
+ f"Invalid components passed: {set(kwargs.keys())} but expected {PromptComponent}"
113
+ )
114
+
115
+ user_prompt = PromptList(
116
+ [kwargs[component.value] for component in self.user_prompt_order]
117
+ )
118
+ system_prompt = PromptList(
119
+ [kwargs[component.value] for component in self.system_prompt_order]
120
+ )
121
+ return {"user_prompt": user_prompt, "system_prompt": system_prompt}
122
+
123
+ def get_prompts(self, **kwargs) -> Dict[str, Prompt]:
124
+ """Get both prompts for the LLM call."""
125
+ prompts = self.arrange_components(**kwargs)
126
+ return {
127
+ "user_prompt": prompts["user_prompt"].reduce(),
128
+ "system_prompt": prompts["system_prompt"].reduce(),
129
+ }
edsl/coop/coop.py CHANGED
@@ -803,8 +803,9 @@ def main():
803
803
  ##############
804
804
  job = Jobs.example()
805
805
  coop.remote_inference_cost(job)
806
- results = coop.remote_inference_create(job)
807
- coop.remote_inference_get(results.get("uuid"))
806
+ job_coop_object = coop.remote_inference_create(job)
807
+ job_coop_results = coop.remote_inference_get(job_coop_object.get("uuid"))
808
+ coop.get(uuid=job_coop_results.get("results_uuid"))
808
809
 
809
810
  ##############
810
811
  # E. Errors
@@ -3,7 +3,6 @@ from dataclasses import dataclass, fields
3
3
  import reprlib
4
4
 
5
5
 
6
-
7
6
  class ModelInputs(NamedTuple):
8
7
  "This is what was send by the agent to the model"
9
8
  user_prompt: str
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import Any
2
+ from typing import Any, Optional, List
3
3
  import re
4
4
  from anthropic import AsyncAnthropic
5
5
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -60,7 +60,10 @@ class AnthropicService(InferenceServiceABC):
60
60
  _rpm = cls.get_rpm(cls)
61
61
 
62
62
  async def async_execute_model_call(
63
- self, user_prompt: str, system_prompt: str = ""
63
+ self,
64
+ user_prompt: str,
65
+ system_prompt: str = "",
66
+ files_list: Optional[List["Files"]] = None,
64
67
  ) -> dict[str, Any]:
65
68
  """Calls the OpenAI API and returns the API response."""
66
69
 
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import Any
2
+ from typing import Any, List, Optional
3
3
  import re
4
4
  import boto3
5
5
  from botocore.exceptions import ClientError
@@ -69,7 +69,10 @@ class AwsBedrockService(InferenceServiceABC):
69
69
  _tpm = cls.get_tpm(cls)
70
70
 
71
71
  async def async_execute_model_call(
72
- self, user_prompt: str, system_prompt: str = ""
72
+ self,
73
+ user_prompt: str,
74
+ system_prompt: str = "",
75
+ files_list: Optional[List["FileStore"]] = None,
73
76
  ) -> dict[str, Any]:
74
77
  """Calls the AWS Bedrock API and returns the API response."""
75
78