edsl 0.1.51__py3-none-any.whl → 0.1.53__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
edsl/__init__.py CHANGED
@@ -15,39 +15,42 @@ from edsl import logger
  # Set up logger with configuration from environment/config
  # (We'll configure the logger after CONFIG is initialized below)

- __all__ = ['logger']
+ __all__ = ["logger"]

  # Define modules to import
  modules_to_import = [
- 'dataset',
- 'agents',
- 'surveys',
- 'questions',
- 'scenarios',
- 'language_models',
- 'results',
- 'caching',
- 'notebooks',
- 'coop',
- 'instructions',
- 'jobs'
+ "dataset",
+ "agents",
+ "surveys",
+ "questions",
+ "scenarios",
+ "language_models",
+ "results",
+ "caching",
+ "notebooks",
+ "coop",
+ "instructions",
+ "jobs",
+ "conversation",
  ]

  # Dynamically import modules and extend __all__
  for module_name in modules_to_import:
  try:
  # Import the module
- module = importlib.import_module(f'.{module_name}', package='edsl')
-
+ module = importlib.import_module(f".{module_name}", package="edsl")
+
  # Get the module's __all__ attribute
- module_all = getattr(module, '__all__', [])
-
+ module_all = getattr(module, "__all__", [])
+
  # Import all names from the module
  exec(f"from .{module_name} import *")
-
+
  # Extend __all__ with the module's __all__
  if module_all:
- logger.debug(f"Adding {len(module_all)} items from {module_name} to __all__")
+ logger.debug(
+ f"Adding {len(module_all)} items from {module_name} to __all__"
+ )
  __all__.extend(module_all)
  else:
  logger.warning(f"Module {module_name} does not have __all__ defined")
@@ -61,39 +64,43 @@ for module_name in modules_to_import:
  try:
  from edsl.load_plugins import load_plugins
  from edsl.plugins import get_plugin_manager, get_exports
-
+
  # Load all plugins
  plugins = load_plugins()
  logger.info(f"Loaded {len(plugins)} plugins")
-
+
  # Add plugins to globals and __all__
  for plugin_name, plugin in plugins.items():
  globals()[plugin_name] = plugin
  __all__.append(plugin_name)
  logger.info(f"Registered plugin {plugin_name} in global namespace")
-
+
  # Get exports from plugins and add them to globals
  exports = get_exports()
  logger.info(f"Found {len(exports)} exported objects from plugins")
-
+
  for name, obj in exports.items():
  globals()[name] = obj
  __all__.append(name)
  logger.info(f"Added plugin export: {name}")
-
+
  # Add placeholders for expected exports that are missing
  # This maintains backward compatibility for common plugins
  PLUGIN_PLACEHOLDERS = {
  # No placeholders - removed Conjure for cleaner namespace
  }
-
+
  for placeholder_name, github_url in PLUGIN_PLACEHOLDERS.items():
  if placeholder_name not in globals():
  # Create a placeholder class
- placeholder_class = type(placeholder_name, (), {
- "__getattr__": lambda self, name: self._not_installed(name),
- "_not_installed": lambda self, name: self._raise_import_error(),
- "_raise_import_error": lambda self: exec(f"""
+ placeholder_class = type(
+ placeholder_name,
+ (),
+ {
+ "__getattr__": lambda self, name: self._not_installed(name),
+ "_not_installed": lambda self, name: self._raise_import_error(),
+ "_raise_import_error": lambda self: exec(
+ f"""
  msg = (
  "The {placeholder_name} plugin is not installed. "
  "To use {placeholder_name} with EDSL, install it using:\\n"
@@ -104,13 +111,17 @@ msg = (
  )
  logger.warning(msg)
  raise ImportError(msg)
- """)
- })
-
+ """
+ ),
+ },
+ )
+
  # Register the placeholder
  globals()[placeholder_name] = placeholder_class()
  __all__.append(placeholder_name)
- logger.info(f"Added placeholder for {placeholder_name} with installation instructions")
+ logger.info(
+ f"Added placeholder for {placeholder_name} with installation instructions"
+ )

  except ImportError as e:
  # Modules not available
@@ -127,8 +138,8 @@ logger.configure_from_config()

  # Installs a custom exception handling routine for edsl exceptions
  from .base.base_exception import BaseException
+
  BaseException.install_exception_hook()

  # Log the total number of items in __all__ for debugging
  logger.debug(f"EDSL initialization complete with {len(__all__)} items in __all__")
-
edsl/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.51"
+ __version__ = "0.1.53"
@@ -4,7 +4,7 @@ import inspect
  from typing import Optional, Callable, TYPE_CHECKING
  from .. import QuestionFreeText, Results, AgentList, ScenarioList, Scenario, Model
  from ..questions import QuestionBase
- from ..results.Result import Result
+ from ..results.result import Result
  from jinja2 import Template
  from ..caching import Cache

@@ -124,6 +124,7 @@ What do you say next?"""
  and "{{ round_message }}" not in next_statement_question.question_text
  ):
  from .exceptions import ConversationValueError
+
  raise ConversationValueError(
  "If you pass in a per_round_message_template, you must include {{ round_message }} in the question_text."
  )
edsl/coop/coop.py CHANGED
@@ -666,6 +666,8 @@ class Coop(CoopFunctionsMixin):
  )
  edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
  object = edsl_class.from_dict(json.loads(json_string))
+ if object_type == "results":
+ object.initialize_cache_from_results()
  return object

  def get_all(self, object_type: ObjectType) -> list[dict[str, Any]]:
@@ -26,7 +26,6 @@ class RetryConfig:


  class SkipHandler:
-
  def __init__(self, interview: "Interview"):
  self.interview = interview
  self.question_index = self.interview.to_index
@@ -47,7 +46,7 @@ class SkipHandler:

  def _current_info_env(self) -> dict[str, Any]:
  """
- - The current answers are "generated_tokens" and "comment"
+ - The current answers are "generated_tokens" and "comment"
  - The scenario should have "scenario." added to the keys
  - The agent traits should have "agent." added to the keys
  """
@@ -65,10 +64,14 @@ class SkipHandler:
  processed_answers[f"{key}.answer"] = value

  # Process scenario dictionary
- processed_scenario = {f"scenario.{k}": v for k, v in self.interview.scenario.items()}
+ processed_scenario = {
+ f"scenario.{k}": v for k, v in self.interview.scenario.items()
+ }

  # Process agent traits
- processed_agent = {f"agent.{k}": v for k, v in self.interview.agent["traits"].items()}
+ processed_agent = {
+ f"agent.{k}": v for k, v in self.interview.agent["traits"].items()
+ }

  return processed_answers | processed_scenario | processed_agent

@@ -85,21 +88,22 @@ class SkipHandler:
  # )

  # Get the index of the next question, which could also be the end of the survey
- next_question: Union[int, EndOfSurvey] = (
- self.interview.survey.rule_collection.next_question(
- q_now=current_question_index,
- answers=answers,
- )
+ next_question: Union[
+ int, EndOfSurvey
+ ] = self.interview.survey.rule_collection.next_question(
+ q_now=current_question_index,
+ answers=answers,
  )

-
  def cancel_between(start, end):
  """Cancel the tasks for questions between the start and end indices."""
  for i in range(start, end):
- #print(f"Cancelling task {i}")
- #self.interview.tasks[i].cancel()
- #self.interview.tasks[i].set_result("skipped")
- self.interview.skip_flags[self.interview.survey.questions[i].question_name] = True
+ # print(f"Cancelling task {i}")
+ # self.interview.tasks[i].cancel()
+ # self.interview.tasks[i].set_result("skipped")
+ self.interview.skip_flags[
+ self.interview.survey.questions[i].question_name
+ ] = True

  if (next_question_index := next_question.next_q) == EndOfSurvey:
  cancel_between(
@@ -111,8 +115,6 @@ class SkipHandler:
  cancel_between(current_question_index + 1, next_question_index)


-
-
  class AnswerQuestionFunctionConstructor:
  """Constructs a function that answers a question and records the answer."""

@@ -137,7 +139,6 @@ class AnswerQuestionFunctionConstructor:
  ):
  """Handle an exception that occurred while answering a question."""

-
  answers = copy.copy(
  self.interview.answers
  ) # copy to freeze the answers here for logging
@@ -171,7 +172,6 @@ class AnswerQuestionFunctionConstructor:
  question: "QuestionBase",
  task=None,
  ) -> "EDSLResultObjectInput":
-
  from tenacity import (
  RetryError,
  retry,
@@ -196,7 +196,6 @@ class AnswerQuestionFunctionConstructor:
  return invigilator.get_failed_task_result(
  failure_reason="Question skipped."
  )
-
  if self.skip_handler.should_skip(question):
  return invigilator.get_failed_task_result(
  failure_reason="Question skipped."
@@ -240,7 +239,6 @@ class AnswerQuestionFunctionConstructor:
  raise LanguageModelNoResponseError(
  f"Language model did not return a response for question '{question.question_name}.'"
  )
-
  if (
  question.question_name in self.interview.exceptions
  and had_language_model_no_response_error
@@ -250,7 +248,8 @@ class AnswerQuestionFunctionConstructor:
  return response

  try:
- return await attempt_answer()
+ out = await attempt_answer()
+ return out
  except RetryError as retry_error:
  original_error = retry_error.last_attempt.exception()
  self._handle_exception(
@@ -81,6 +81,7 @@ class InterviewExceptionEntry:
  raise_validation_errors=True,
  disable_remote_cache=True,
  disable_remote_inference=True,
+ cache=False,
  )
  return results.task_history.exceptions[0]["how_are_you"][0]

@@ -92,13 +93,13 @@ class InterviewExceptionEntry:
  def code(self, run=True):
  """Return the code to reproduce the exception."""
  lines = []
- lines.append("from .. import Question, Model, Scenario, Agent")
+ lines.append("from edsl import Question, Model, Scenario, Agent")

  lines.append(f"q = {repr(self.invigilator.question)}")
  lines.append(f"scenario = {repr(self.invigilator.scenario)}")
  lines.append(f"agent = {repr(self.invigilator.agent)}")
- lines.append(f"m = Model('{self.invigilator.model.model}')")
- lines.append("results = q.by(m).by(agent).by(scenario).run()")
+ lines.append(f"model = {repr(self.invigilator.model)}")
+ lines.append("results = q.by(model).by(agent).by(scenario).run()")
  code_str = "\n".join(lines)

  if run:
@@ -24,6 +24,7 @@ class InterviewTaskManager:
  for index, question_name in enumerate(self.survey.question_names)
  }
  self._task_status_log_dict = InterviewStatusLog()
+ self.survey_dag = None

  def build_question_tasks(
  self, answer_func, token_estimator, model_buckets
@@ -46,8 +47,9 @@ class InterviewTaskManager:
  self, existing_tasks: list[asyncio.Task], question: "QuestionBase"
  ) -> list[asyncio.Task]:
  """Get tasks that must be completed before the given question."""
- dag = self.survey.dag(textify=True)
- parents = dag.get(question.question_name, [])
+ if self.survey_dag is None:
+ self.survey_dag = self.survey.dag(textify=True)
+ parents = self.survey_dag.get(question.question_name, [])
  return [existing_tasks[self.to_index[parent_name]] for parent_name in parents]

  def _create_single_task(
@@ -100,4 +102,5 @@ class InterviewTaskManager:

  if __name__ == "__main__":
  import doctest
+
  doctest.testmod()
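
Note on the InterviewTaskManager change above: the survey DAG is now built once and cached on self.survey_dag instead of recomputing self.survey.dag(textify=True) for every question. A minimal sketch of the same lazy-caching pattern is below; it is illustrative only (the class and method names here are hypothetical, not part of the package), with survey.dag(textify=True) assumed to be the expensive call, as in the diff.

    class DagCacheSketch:
        """Illustrative only: build an expensive DAG once, reuse it per question."""

        def __init__(self, survey):
            self.survey = survey
            self.survey_dag = None  # populated on first use

        def prerequisites(self, question_name: str) -> list:
            if self.survey_dag is None:
                # Runs once per interview rather than once per question.
                self.survey_dag = self.survey.dag(textify=True)
            return self.survey_dag.get(question_name, [])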
@@ -1,6 +1,101 @@
  from ..jobs.fetch_invigilator import FetchInvigilator
  from ..scenarios import FileStore

+ import math
+
+ # Model configs: base tokens and tile tokens only
+ VISION_MODELS = {
+ "gpt-4o": {
+ "base_tokens": 85,
+ "tile_tokens": 170,
+ },
+ "gpt-4o-mini": {
+ "base_tokens": 2833,
+ "tile_tokens": 5667,
+ },
+ "o1": {
+ "base_tokens": 75,
+ "tile_tokens": 150,
+ },
+ }
+
+
+ def approximate_image_tokens_google(width: int, height: int) -> int:
+ """
+ Approximates the token usage for an image based on its dimensions.
+
+ This calculation is based on the rules described for Gemini 2.0 models
+ in the provided text:
+ - Images with both dimensions <= 384px cost 258 tokens.
+ - Larger images are processed in 768x768 tiles, each costing 258 tokens.
+
+ Note: This is an *approximation*. The exact cropping, scaling, and tiling
+ strategy used by the actual Gemini API might differ slightly.
+
+ Args:
+ width: The width of the image in pixels.
+ height: The height of the image in pixels.
+
+ Returns:
+ An estimated integer token count for the image.
+
+ Raises:
+ ValueError: If width or height are not positive integers.
+ """
+ SMALL_IMAGE_THRESHOLD = 384 # Max dimension for fixed token count
+ FIXED_TOKEN_COST_SMALL = 258 # Token cost for small images (<= 384x384)
+ TILE_SIZE = 768 # Dimension of tiles for larger images
+ TOKEN_COST_PER_TILE = 258 # Token cost per 768x768 tile
+ if (
+ not isinstance(width, int)
+ or not isinstance(height, int)
+ or width <= 0
+ or height <= 0
+ ):
+ raise ValueError("Image width and height must be positive integers.")
+
+ # Case 1: Small image (both dimensions <= threshold)
+ if width <= SMALL_IMAGE_THRESHOLD and height <= SMALL_IMAGE_THRESHOLD:
+ return FIXED_TOKEN_COST_SMALL
+
+ # Case 2: Larger image (at least one dimension > threshold)
+ else:
+ # Calculate how many tiles are needed to cover the width and height
+ # Use ceiling division to ensure full coverage
+ tiles_wide = math.ceil(width / TILE_SIZE)
+ tiles_high = math.ceil(height / TILE_SIZE)
+
+ # Total number of tiles is the product of tiles needed in each dimension
+ total_tiles = tiles_wide * tiles_high
+
+ # Total token cost is the number of tiles times the cost per tile
+ estimated_tokens = total_tiles * TOKEN_COST_PER_TILE
+ return estimated_tokens
+
+
+ def estimate_tokens(model_name, width, height):
+ if model_name == "test":
+ return 10 # for testing purposes
+ if "gemini" in model_name:
+ out = approximate_image_tokens_google(width, height)
+ return out
+ if "claude" in model_name:
+ total_tokens = width * height / 750
+ return total_tokens
+ if model_name not in VISION_MODELS:
+ total_tokens = width * height / 750
+ return total_tokens
+
+ config = VISION_MODELS[model_name]
+ TILE_SIZE = 512
+
+ tiles_x = math.ceil(width / TILE_SIZE)
+ tiles_y = math.ceil(height / TILE_SIZE)
+ total_tiles = tiles_x * tiles_y
+
+ total_tokens = config["base_tokens"] + config["tile_tokens"] * total_tiles
+ return total_tokens
+

  class RequestTokenEstimator:
  """Estimate the number of tokens that will be required to run the focal task."""
@@ -24,15 +119,22 @@ class RequestTokenEstimator:
  elif isinstance(prompt, list):
  for file in prompt:
  if isinstance(file, FileStore):
- file_tokens += file.size * 0.25
+ if file.is_image():
+ model_name = self.interview.model.model
+ width, height = file.get_image_dimensions()
+ token_usage = estimate_tokens(model_name, width, height)
+ file_tokens += token_usage
+ else:
+ file_tokens += file.size * 0.25
  else:
  from .exceptions import InterviewTokenError
+
  raise InterviewTokenError(f"Prompt is of type {type(prompt)}")
  result: float = len(combined_text) / 4.0 + file_tokens
  return result


-
  if __name__ == "__main__":
  import doctest
+
  doctest.testmod(optionflags=doctest.ELLIPSIS)
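
For orientation, here is what the estimator added above returns for a hypothetical 1024x768 image. The specific model-name strings are illustrative only; the dispatch in estimate_tokens checks for the substrings "gemini" or "claude", an exact key in VISION_MODELS, or falls back to the area-based rule.

    estimate_tokens("gpt-4o", 1024, 768)             # 2x2 tiles of 512px: 85 + 170 * 4 = 765
    estimate_tokens("gemini-1.5-flash", 1024, 768)   # 2x1 tiles of 768px: 2 * 258 = 516
    estimate_tokens("claude-3-5-sonnet", 1024, 768)  # area rule: 1024 * 768 / 750 ≈ 1048.6
    estimate_tokens("some-other-model", 1024, 768)   # unknown models use the same area rule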
@@ -24,11 +24,11 @@ if TYPE_CHECKING:
  from ..key_management import KeyLookup


-
  PromptType = Literal["user_prompt", "system_prompt", "encoded_image", "files_list"]

  NA = "Not Applicable"

+
  class InvigilatorBase(ABC):
  """An invigiator (someone who administers an exam) is a class that is responsible for administering a question to an agent.

@@ -261,13 +261,14 @@ class InvigilatorBase(ABC):
  current_answers=current_answers,
  )

+
  class InvigilatorAI(InvigilatorBase):
  """An invigilator that uses an AI model to answer questions."""

  def get_prompts(self) -> Dict[PromptType, "Prompt"]:
  """Return the prompts used."""
  return self.prompt_constructor.get_prompts()
-
+
  def get_captured_variables(self) -> dict:
  """Get the captured variables."""
  return self.prompt_constructor.get_captured_variables()
@@ -281,6 +282,7 @@ class InvigilatorAI(InvigilatorBase):
  if "encoded_image" in prompts:
  params["encoded_image"] = prompts["encoded_image"]
  from .exceptions import InvigilatorNotImplementedError
+
  raise InvigilatorNotImplementedError("encoded_image not implemented")

  if "files_list" in prompts:
@@ -307,7 +309,8 @@ class InvigilatorAI(InvigilatorBase):
  """
  agent_response_dict: AgentResponseDict = await self.async_get_agent_response()
  self.store_response(agent_response_dict)
- return self._extract_edsl_result_entry_and_validate(agent_response_dict)
+ out = self._extract_edsl_result_entry_and_validate(agent_response_dict)
+ return out

  def _remove_from_cache(self, cache_key) -> None:
  """Remove an entry from the cache."""
@@ -389,6 +392,35 @@ class InvigilatorAI(InvigilatorBase):
  edsl_dict = agent_response_dict.edsl_dict._asdict()
  exception_occurred = None
  validated = False
+
+ if agent_response_dict.model_outputs.cache_used:
+ data = {
+ "answer": agent_response_dict.edsl_dict.answer
+ if type(agent_response_dict.edsl_dict.answer) is str
+ or type(agent_response_dict.edsl_dict.answer) is dict
+ or type(agent_response_dict.edsl_dict.answer) is list
+ or type(agent_response_dict.edsl_dict.answer) is int
+ or type(agent_response_dict.edsl_dict.answer) is float
+ or type(agent_response_dict.edsl_dict.answer) is bool
+ else "",
+ "comment": agent_response_dict.edsl_dict.comment
+ if agent_response_dict.edsl_dict.comment
+ else "",
+ "generated_tokens": agent_response_dict.edsl_dict.generated_tokens,
+ "question_name": self.question.question_name,
+ "prompts": self.get_prompts(),
+ "cached_response": agent_response_dict.model_outputs.cached_response,
+ "raw_model_response": agent_response_dict.model_outputs.response,
+ "cache_used": agent_response_dict.model_outputs.cache_used,
+ "cache_key": agent_response_dict.model_outputs.cache_key,
+ "validated": True,
+ "exception_occurred": exception_occurred,
+ "cost": agent_response_dict.model_outputs.cost,
+ }
+
+ result = EDSLResultObjectInput(**data)
+ return result
+
  try:
  # if the question has jinja parameters, it is easier to make a new question with the parameters
  if self.question.parameters:
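
A side note on the cache short-circuit added above: the chained type(...) is ... checks whitelist plain JSON-like answer types before a cached response is reused without re-validation. A roughly equivalent, more idiomatic check is sketched below; it is illustrative only and not the package's code (note that isinstance also accepts subclasses, e.g. a str subclass, which the type(...) is form rejects).

    # Illustrative sketch: keep a cached answer only if it is a JSON-like value, else use "".
    ALLOWED_ANSWER_TYPES = (str, dict, list, int, float, bool)

    def normalize_cached_answer(answer):
        # isinstance-based whitelist; unlike `type(x) is T`, this also admits subclasses.
        return answer if isinstance(answer, ALLOWED_ANSWER_TYPES) else ""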
@@ -405,7 +437,7 @@ class InvigilatorAI(InvigilatorBase):
  self.question.question_options = new_question_options

  question_with_validators = self.question.render(
- self.scenario | prior_answers_dict | {'agent':self.agent.traits}
+ self.scenario | prior_answers_dict | {"agent": self.agent.traits}
  )
  question_with_validators.use_code = self.question.use_code
  else:
@@ -426,6 +458,7 @@ class InvigilatorAI(InvigilatorBase):
  exception_occurred = non_validation_error
  finally:
  # even if validation failes, we still return the result
+
  data = {
  "answer": answer,
  "comment": comment,