edsl 0.1.51__py3-none-any.whl → 0.1.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edsl/__init__.py CHANGED
@@ -15,39 +15,42 @@ from edsl import logger
15
15
  # Set up logger with configuration from environment/config
16
16
  # (We'll configure the logger after CONFIG is initialized below)
17
17
 
18
- __all__ = ['logger']
18
+ __all__ = ["logger"]
19
19
 
20
20
  # Define modules to import
21
21
  modules_to_import = [
22
- 'dataset',
23
- 'agents',
24
- 'surveys',
25
- 'questions',
26
- 'scenarios',
27
- 'language_models',
28
- 'results',
29
- 'caching',
30
- 'notebooks',
31
- 'coop',
32
- 'instructions',
33
- 'jobs'
22
+ "dataset",
23
+ "agents",
24
+ "surveys",
25
+ "questions",
26
+ "scenarios",
27
+ "language_models",
28
+ "results",
29
+ "caching",
30
+ "notebooks",
31
+ "coop",
32
+ "instructions",
33
+ "jobs",
34
+ "conversation",
34
35
  ]
35
36
 
36
37
  # Dynamically import modules and extend __all__
37
38
  for module_name in modules_to_import:
38
39
  try:
39
40
  # Import the module
40
- module = importlib.import_module(f'.{module_name}', package='edsl')
41
-
41
+ module = importlib.import_module(f".{module_name}", package="edsl")
42
+
42
43
  # Get the module's __all__ attribute
43
- module_all = getattr(module, '__all__', [])
44
-
44
+ module_all = getattr(module, "__all__", [])
45
+
45
46
  # Import all names from the module
46
47
  exec(f"from .{module_name} import *")
47
-
48
+
48
49
  # Extend __all__ with the module's __all__
49
50
  if module_all:
50
- logger.debug(f"Adding {len(module_all)} items from {module_name} to __all__")
51
+ logger.debug(
52
+ f"Adding {len(module_all)} items from {module_name} to __all__"
53
+ )
51
54
  __all__.extend(module_all)
52
55
  else:
53
56
  logger.warning(f"Module {module_name} does not have __all__ defined")
@@ -61,39 +64,43 @@ for module_name in modules_to_import:
61
64
  try:
62
65
  from edsl.load_plugins import load_plugins
63
66
  from edsl.plugins import get_plugin_manager, get_exports
64
-
67
+
65
68
  # Load all plugins
66
69
  plugins = load_plugins()
67
70
  logger.info(f"Loaded {len(plugins)} plugins")
68
-
71
+
69
72
  # Add plugins to globals and __all__
70
73
  for plugin_name, plugin in plugins.items():
71
74
  globals()[plugin_name] = plugin
72
75
  __all__.append(plugin_name)
73
76
  logger.info(f"Registered plugin {plugin_name} in global namespace")
74
-
77
+
75
78
  # Get exports from plugins and add them to globals
76
79
  exports = get_exports()
77
80
  logger.info(f"Found {len(exports)} exported objects from plugins")
78
-
81
+
79
82
  for name, obj in exports.items():
80
83
  globals()[name] = obj
81
84
  __all__.append(name)
82
85
  logger.info(f"Added plugin export: {name}")
83
-
86
+
84
87
  # Add placeholders for expected exports that are missing
85
88
  # This maintains backward compatibility for common plugins
86
89
  PLUGIN_PLACEHOLDERS = {
87
90
  # No placeholders - removed Conjure for cleaner namespace
88
91
  }
89
-
92
+
90
93
  for placeholder_name, github_url in PLUGIN_PLACEHOLDERS.items():
91
94
  if placeholder_name not in globals():
92
95
  # Create a placeholder class
93
- placeholder_class = type(placeholder_name, (), {
94
- "__getattr__": lambda self, name: self._not_installed(name),
95
- "_not_installed": lambda self, name: self._raise_import_error(),
96
- "_raise_import_error": lambda self: exec(f"""
96
+ placeholder_class = type(
97
+ placeholder_name,
98
+ (),
99
+ {
100
+ "__getattr__": lambda self, name: self._not_installed(name),
101
+ "_not_installed": lambda self, name: self._raise_import_error(),
102
+ "_raise_import_error": lambda self: exec(
103
+ f"""
97
104
  msg = (
98
105
  "The {placeholder_name} plugin is not installed. "
99
106
  "To use {placeholder_name} with EDSL, install it using:\\n"
@@ -104,13 +111,17 @@ msg = (
104
111
  )
105
112
  logger.warning(msg)
106
113
  raise ImportError(msg)
107
- """)
108
- })
109
-
114
+ """
115
+ ),
116
+ },
117
+ )
118
+
110
119
  # Register the placeholder
111
120
  globals()[placeholder_name] = placeholder_class()
112
121
  __all__.append(placeholder_name)
113
- logger.info(f"Added placeholder for {placeholder_name} with installation instructions")
122
+ logger.info(
123
+ f"Added placeholder for {placeholder_name} with installation instructions"
124
+ )
114
125
 
115
126
  except ImportError as e:
116
127
  # Modules not available
@@ -127,8 +138,8 @@ logger.configure_from_config()
127
138
 
128
139
  # Installs a custom exception handling routine for edsl exceptions
129
140
  from .base.base_exception import BaseException
141
+
130
142
  BaseException.install_exception_hook()
131
143
 
132
144
  # Log the total number of items in __all__ for debugging
133
145
  logger.debug(f"EDSL initialization complete with {len(__all__)} items in __all__")
134
-
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.51"
1
+ __version__ = "0.1.52"
@@ -4,7 +4,7 @@ import inspect
4
4
  from typing import Optional, Callable, TYPE_CHECKING
5
5
  from .. import QuestionFreeText, Results, AgentList, ScenarioList, Scenario, Model
6
6
  from ..questions import QuestionBase
7
- from ..results.Result import Result
7
+ from ..results.result import Result
8
8
  from jinja2 import Template
9
9
  from ..caching import Cache
10
10
 
@@ -124,6 +124,7 @@ What do you say next?"""
124
124
  and "{{ round_message }}" not in next_statement_question.question_text
125
125
  ):
126
126
  from .exceptions import ConversationValueError
127
+
127
128
  raise ConversationValueError(
128
129
  "If you pass in a per_round_message_template, you must include {{ round_message }} in the question_text."
129
130
  )
edsl/coop/coop.py CHANGED
@@ -666,6 +666,8 @@ class Coop(CoopFunctionsMixin):
666
666
  )
667
667
  edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
668
668
  object = edsl_class.from_dict(json.loads(json_string))
669
+ if object_type == "results":
670
+ object.initialize_cache_from_results()
669
671
  return object
670
672
 
671
673
  def get_all(self, object_type: ObjectType) -> list[dict[str, Any]]:
@@ -26,7 +26,6 @@ class RetryConfig:
26
26
 
27
27
 
28
28
  class SkipHandler:
29
-
30
29
  def __init__(self, interview: "Interview"):
31
30
  self.interview = interview
32
31
  self.question_index = self.interview.to_index
@@ -47,7 +46,7 @@ class SkipHandler:
47
46
 
48
47
  def _current_info_env(self) -> dict[str, Any]:
49
48
  """
50
- - The current answers are "generated_tokens" and "comment"
49
+ - The current answers are "generated_tokens" and "comment"
51
50
  - The scenario should have "scenario." added to the keys
52
51
  - The agent traits should have "agent." added to the keys
53
52
  """
@@ -65,10 +64,14 @@ class SkipHandler:
65
64
  processed_answers[f"{key}.answer"] = value
66
65
 
67
66
  # Process scenario dictionary
68
- processed_scenario = {f"scenario.{k}": v for k, v in self.interview.scenario.items()}
67
+ processed_scenario = {
68
+ f"scenario.{k}": v for k, v in self.interview.scenario.items()
69
+ }
69
70
 
70
71
  # Process agent traits
71
- processed_agent = {f"agent.{k}": v for k, v in self.interview.agent["traits"].items()}
72
+ processed_agent = {
73
+ f"agent.{k}": v for k, v in self.interview.agent["traits"].items()
74
+ }
72
75
 
73
76
  return processed_answers | processed_scenario | processed_agent
74
77
 
@@ -85,21 +88,22 @@ class SkipHandler:
85
88
  # )
86
89
 
87
90
  # Get the index of the next question, which could also be the end of the survey
88
- next_question: Union[int, EndOfSurvey] = (
89
- self.interview.survey.rule_collection.next_question(
90
- q_now=current_question_index,
91
- answers=answers,
92
- )
91
+ next_question: Union[
92
+ int, EndOfSurvey
93
+ ] = self.interview.survey.rule_collection.next_question(
94
+ q_now=current_question_index,
95
+ answers=answers,
93
96
  )
94
97
 
95
-
96
98
  def cancel_between(start, end):
97
99
  """Cancel the tasks for questions between the start and end indices."""
98
100
  for i in range(start, end):
99
- #print(f"Cancelling task {i}")
100
- #self.interview.tasks[i].cancel()
101
- #self.interview.tasks[i].set_result("skipped")
102
- self.interview.skip_flags[self.interview.survey.questions[i].question_name] = True
101
+ # print(f"Cancelling task {i}")
102
+ # self.interview.tasks[i].cancel()
103
+ # self.interview.tasks[i].set_result("skipped")
104
+ self.interview.skip_flags[
105
+ self.interview.survey.questions[i].question_name
106
+ ] = True
103
107
 
104
108
  if (next_question_index := next_question.next_q) == EndOfSurvey:
105
109
  cancel_between(
@@ -111,8 +115,6 @@ class SkipHandler:
111
115
  cancel_between(current_question_index + 1, next_question_index)
112
116
 
113
117
 
114
-
115
-
116
118
  class AnswerQuestionFunctionConstructor:
117
119
  """Constructs a function that answers a question and records the answer."""
118
120
 
@@ -137,7 +139,6 @@ class AnswerQuestionFunctionConstructor:
137
139
  ):
138
140
  """Handle an exception that occurred while answering a question."""
139
141
 
140
-
141
142
  answers = copy.copy(
142
143
  self.interview.answers
143
144
  ) # copy to freeze the answers here for logging
@@ -171,7 +172,6 @@ class AnswerQuestionFunctionConstructor:
171
172
  question: "QuestionBase",
172
173
  task=None,
173
174
  ) -> "EDSLResultObjectInput":
174
-
175
175
  from tenacity import (
176
176
  RetryError,
177
177
  retry,
@@ -196,7 +196,6 @@ class AnswerQuestionFunctionConstructor:
196
196
  return invigilator.get_failed_task_result(
197
197
  failure_reason="Question skipped."
198
198
  )
199
-
200
199
  if self.skip_handler.should_skip(question):
201
200
  return invigilator.get_failed_task_result(
202
201
  failure_reason="Question skipped."
@@ -240,7 +239,6 @@ class AnswerQuestionFunctionConstructor:
240
239
  raise LanguageModelNoResponseError(
241
240
  f"Language model did not return a response for question '{question.question_name}.'"
242
241
  )
243
-
244
242
  if (
245
243
  question.question_name in self.interview.exceptions
246
244
  and had_language_model_no_response_error
@@ -250,7 +248,8 @@ class AnswerQuestionFunctionConstructor:
250
248
  return response
251
249
 
252
250
  try:
253
- return await attempt_answer()
251
+ out = await attempt_answer()
252
+ return out
254
253
  except RetryError as retry_error:
255
254
  original_error = retry_error.last_attempt.exception()
256
255
  self._handle_exception(
@@ -81,6 +81,7 @@ class InterviewExceptionEntry:
81
81
  raise_validation_errors=True,
82
82
  disable_remote_cache=True,
83
83
  disable_remote_inference=True,
84
+ cache=False,
84
85
  )
85
86
  return results.task_history.exceptions[0]["how_are_you"][0]
86
87
 
@@ -92,13 +93,13 @@ class InterviewExceptionEntry:
92
93
  def code(self, run=True):
93
94
  """Return the code to reproduce the exception."""
94
95
  lines = []
95
- lines.append("from .. import Question, Model, Scenario, Agent")
96
+ lines.append("from edsl import Question, Model, Scenario, Agent")
96
97
 
97
98
  lines.append(f"q = {repr(self.invigilator.question)}")
98
99
  lines.append(f"scenario = {repr(self.invigilator.scenario)}")
99
100
  lines.append(f"agent = {repr(self.invigilator.agent)}")
100
- lines.append(f"m = Model('{self.invigilator.model.model}')")
101
- lines.append("results = q.by(m).by(agent).by(scenario).run()")
101
+ lines.append(f"model = {repr(self.invigilator.model)}")
102
+ lines.append("results = q.by(model).by(agent).by(scenario).run()")
102
103
  code_str = "\n".join(lines)
103
104
 
104
105
  if run:
@@ -24,6 +24,7 @@ class InterviewTaskManager:
24
24
  for index, question_name in enumerate(self.survey.question_names)
25
25
  }
26
26
  self._task_status_log_dict = InterviewStatusLog()
27
+ self.survey_dag = None
27
28
 
28
29
  def build_question_tasks(
29
30
  self, answer_func, token_estimator, model_buckets
@@ -46,8 +47,9 @@ class InterviewTaskManager:
46
47
  self, existing_tasks: list[asyncio.Task], question: "QuestionBase"
47
48
  ) -> list[asyncio.Task]:
48
49
  """Get tasks that must be completed before the given question."""
49
- dag = self.survey.dag(textify=True)
50
- parents = dag.get(question.question_name, [])
50
+ if self.survey_dag is None:
51
+ self.survey_dag = self.survey.dag(textify=True)
52
+ parents = self.survey_dag.get(question.question_name, [])
51
53
  return [existing_tasks[self.to_index[parent_name]] for parent_name in parents]
52
54
 
53
55
  def _create_single_task(
@@ -100,4 +102,5 @@ class InterviewTaskManager:
100
102
 
101
103
  if __name__ == "__main__":
102
104
  import doctest
105
+
103
106
  doctest.testmod()
@@ -24,11 +24,11 @@ if TYPE_CHECKING:
24
24
  from ..key_management import KeyLookup
25
25
 
26
26
 
27
-
28
27
  PromptType = Literal["user_prompt", "system_prompt", "encoded_image", "files_list"]
29
28
 
30
29
  NA = "Not Applicable"
31
30
 
31
+
32
32
  class InvigilatorBase(ABC):
33
33
  """An invigiator (someone who administers an exam) is a class that is responsible for administering a question to an agent.
34
34
 
@@ -261,13 +261,14 @@ class InvigilatorBase(ABC):
261
261
  current_answers=current_answers,
262
262
  )
263
263
 
264
+
264
265
  class InvigilatorAI(InvigilatorBase):
265
266
  """An invigilator that uses an AI model to answer questions."""
266
267
 
267
268
  def get_prompts(self) -> Dict[PromptType, "Prompt"]:
268
269
  """Return the prompts used."""
269
270
  return self.prompt_constructor.get_prompts()
270
-
271
+
271
272
  def get_captured_variables(self) -> dict:
272
273
  """Get the captured variables."""
273
274
  return self.prompt_constructor.get_captured_variables()
@@ -281,6 +282,7 @@ class InvigilatorAI(InvigilatorBase):
281
282
  if "encoded_image" in prompts:
282
283
  params["encoded_image"] = prompts["encoded_image"]
283
284
  from .exceptions import InvigilatorNotImplementedError
285
+
284
286
  raise InvigilatorNotImplementedError("encoded_image not implemented")
285
287
 
286
288
  if "files_list" in prompts:
@@ -307,7 +309,8 @@ class InvigilatorAI(InvigilatorBase):
307
309
  """
308
310
  agent_response_dict: AgentResponseDict = await self.async_get_agent_response()
309
311
  self.store_response(agent_response_dict)
310
- return self._extract_edsl_result_entry_and_validate(agent_response_dict)
312
+ out = self._extract_edsl_result_entry_and_validate(agent_response_dict)
313
+ return out
311
314
 
312
315
  def _remove_from_cache(self, cache_key) -> None:
313
316
  """Remove an entry from the cache."""
@@ -389,6 +392,30 @@ class InvigilatorAI(InvigilatorBase):
389
392
  edsl_dict = agent_response_dict.edsl_dict._asdict()
390
393
  exception_occurred = None
391
394
  validated = False
395
+
396
+ if agent_response_dict.model_outputs.cache_used:
397
+ data = {
398
+ "answer": agent_response_dict.edsl_dict.answer
399
+ if type(agent_response_dict.edsl_dict.answer) is str
400
+ else "",
401
+ "comment": agent_response_dict.edsl_dict.comment
402
+ if agent_response_dict.edsl_dict.comment
403
+ else "",
404
+ "generated_tokens": agent_response_dict.edsl_dict.generated_tokens,
405
+ "question_name": self.question.question_name,
406
+ "prompts": self.get_prompts(),
407
+ "cached_response": agent_response_dict.model_outputs.cached_response,
408
+ "raw_model_response": agent_response_dict.model_outputs.response,
409
+ "cache_used": agent_response_dict.model_outputs.cache_used,
410
+ "cache_key": agent_response_dict.model_outputs.cache_key,
411
+ "validated": True,
412
+ "exception_occurred": exception_occurred,
413
+ "cost": agent_response_dict.model_outputs.cost,
414
+ }
415
+
416
+ result = EDSLResultObjectInput(**data)
417
+ return result
418
+
392
419
  try:
393
420
  # if the question has jinja parameters, it is easier to make a new question with the parameters
394
421
  if self.question.parameters:
@@ -405,7 +432,7 @@ class InvigilatorAI(InvigilatorBase):
405
432
  self.question.question_options = new_question_options
406
433
 
407
434
  question_with_validators = self.question.render(
408
- self.scenario | prior_answers_dict | {'agent':self.agent.traits}
435
+ self.scenario | prior_answers_dict | {"agent": self.agent.traits}
409
436
  )
410
437
  question_with_validators.use_code = self.question.use_code
411
438
  else:
@@ -426,6 +453,7 @@ class InvigilatorAI(InvigilatorBase):
426
453
  exception_occurred = non_validation_error
427
454
  finally:
428
455
  # even if validation failes, we still return the result
456
+
429
457
  data = {
430
458
  "answer": answer,
431
459
  "comment": comment,