edsl 0.1.42__py3-none-any.whl → 0.1.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42):
  1. edsl/Base.py +15 -6
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Invigilator.py +1 -1
  4. edsl/agents/PromptConstructor.py +92 -21
  5. edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
  6. edsl/agents/prompt_helpers.py +2 -2
  7. edsl/coop/coop.py +100 -22
  8. edsl/enums.py +3 -1
  9. edsl/exceptions/coop.py +4 -0
  10. edsl/inference_services/AnthropicService.py +2 -0
  11. edsl/inference_services/AvailableModelFetcher.py +4 -1
  12. edsl/inference_services/GoogleService.py +2 -0
  13. edsl/inference_services/GrokService.py +11 -0
  14. edsl/inference_services/InferenceServiceABC.py +1 -0
  15. edsl/inference_services/OpenAIService.py +1 -0
  16. edsl/inference_services/TestService.py +1 -0
  17. edsl/inference_services/registry.py +2 -0
  18. edsl/jobs/Jobs.py +54 -35
  19. edsl/jobs/JobsChecks.py +7 -7
  20. edsl/jobs/JobsPrompts.py +57 -6
  21. edsl/jobs/JobsRemoteInferenceHandler.py +41 -25
  22. edsl/jobs/buckets/BucketCollection.py +30 -0
  23. edsl/jobs/data_structures.py +1 -0
  24. edsl/language_models/LanguageModel.py +5 -2
  25. edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
  26. edsl/language_models/key_management/models.py +10 -4
  27. edsl/language_models/model.py +43 -11
  28. edsl/prompts/Prompt.py +124 -61
  29. edsl/questions/descriptors.py +32 -18
  30. edsl/questions/question_base_gen_mixin.py +1 -0
  31. edsl/results/DatasetExportMixin.py +35 -6
  32. edsl/results/Results.py +180 -1
  33. edsl/results/ResultsGGMixin.py +117 -60
  34. edsl/scenarios/FileStore.py +19 -8
  35. edsl/scenarios/Scenario.py +33 -0
  36. edsl/scenarios/ScenarioList.py +22 -3
  37. edsl/scenarios/ScenarioListPdfMixin.py +9 -3
  38. edsl/surveys/Survey.py +27 -6
  39. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/METADATA +3 -4
  40. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/RECORD +42 -41
  41. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/LICENSE +0 -0
  42. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/WHEEL +0 -0
@@ -1,6 +1,6 @@
1
1
  import textwrap
2
2
  from random import random
3
- from typing import Optional, TYPE_CHECKING, List
3
+ from typing import Optional, TYPE_CHECKING, List, Callable
4
4
 
5
5
  from edsl.utilities.PrettyList import PrettyList
6
6
  from edsl.config import CONFIG
@@ -11,17 +11,21 @@ from edsl.inference_services.InferenceServicesCollection import (
11
11
  from edsl.inference_services.data_structures import AvailableModels
12
12
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
13
13
  from edsl.enums import InferenceServiceLiteral
14
+ from edsl.exceptions.inference_services import InferenceServiceError
14
15
 
15
16
  if TYPE_CHECKING:
16
17
  from edsl.results.Dataset import Dataset
17
18
 
18
19
 
19
- def get_model_class(model_name, registry: Optional[InferenceServicesCollection] = None):
20
+ def get_model_class(model_name, registry: Optional[InferenceServicesCollection] = None, service_name: Optional[InferenceServiceLiteral] = None):
20
21
  from edsl.inference_services.registry import default
21
22
 
22
23
  registry = registry or default
23
- factory = registry.create_model_factory(model_name)
24
- return factory
24
+ try:
25
+ factory = registry.create_model_factory(model_name, service_name=service_name)
26
+ return factory
27
+ except (InferenceServiceError, Exception) as e:
28
+ return Model._handle_model_error(model_name, e)
25
29
 
26
30
 
27
31
  class Meta(type):
@@ -58,6 +62,33 @@ class Model(metaclass=Meta):
58
62
  """Set a new registry"""
59
63
  cls._registry = registry
60
64
 
65
+ @classmethod
66
+ def _handle_model_error(cls, model_name: str, error: Exception):
67
+ """Handle errors from model creation and execution with notebook-aware behavior."""
68
+ if isinstance(error, InferenceServiceError):
69
+ services = [s._inference_service_ for s in cls.get_registry().services]
70
+ message = (
71
+ f"Model '{model_name}' not found in any services.\n"
72
+ "It is likely that our registry is just out of date.\n"
73
+ "Simply adding the service name to your model call should fix this.\n"
74
+ f"Available services are: {services}\n"
75
+ f"To specify a model with a service, use:\n"
76
+ f'Model("{model_name}", service_name="<service_name>")'
77
+ )
78
+ else:
79
+ message = f"An error occurred: {str(error)}"
80
+
81
+ # Check if we're in a notebook environment
82
+ try:
83
+ get_ipython()
84
+ print(message)
85
+ return None
86
+ except NameError:
87
+ # Not in a notebook, raise the exception
88
+ if isinstance(error, InferenceServiceError):
89
+ raise InferenceServiceError(message)
90
+ raise error
91
+
61
92
  def __new__(
62
93
  cls,
63
94
  model_name: Optional[str] = None,
@@ -69,9 +100,7 @@ class Model(metaclass=Meta):
69
100
  "Instantiate a new language model."
70
101
  # Map index to the respective subclass
71
102
  if model_name is None:
72
- model_name = (
73
- cls.default_model
74
- ) # when model_name is None, use the default model, set in the config file
103
+ model_name = cls.default_model
75
104
 
76
105
  if registry is not None:
77
106
  cls.set_registry(registry)
@@ -79,10 +108,13 @@ class Model(metaclass=Meta):
79
108
  if isinstance(model_name, int): # can refer to a model by index
80
109
  model_name = cls.available(name_only=True)[model_name]
81
110
 
82
- factory = cls.get_registry().create_model_factory(
83
- model_name, service_name=service_name
84
- )
85
- return factory(*args, **kwargs)
111
+ try:
112
+ factory = cls.get_registry().create_model_factory(
113
+ model_name, service_name=service_name
114
+ )
115
+ return factory(*args, **kwargs)
116
+ except (InferenceServiceError, Exception) as e:
117
+ return cls._handle_model_error(model_name, e)
86
118
 
87
119
  @classmethod
88
120
  def add_model(cls, service_name, model_name) -> None:
edsl/prompts/Prompt.py CHANGED
@@ -10,6 +10,48 @@ from edsl.Base import PersistenceMixin, RepresentationMixin
10
10
 
11
11
  MAX_NESTING = 100
12
12
 
13
+ from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
14
+ from functools import lru_cache
15
+
16
+ class PreserveUndefined(Undefined):
17
+ def __str__(self):
18
+ return "{{ " + str(self._undefined_name) + " }}"
19
+
20
+ # Create environment once at module level
21
+ _env = Environment(undefined=PreserveUndefined)
22
+
23
+ @lru_cache(maxsize=1024)
24
+ def _compile_template(text: str):
25
+ return _env.from_string(text)
26
+
27
+ @lru_cache(maxsize=1024)
28
+ def _find_template_variables(template: str) -> list[str]:
29
+ """Find and return the template variables."""
30
+ ast = _env.parse(template)
31
+ return list(meta.find_undeclared_variables(ast))
32
+
33
+ def _make_hashable(value):
34
+ """Convert unhashable types to hashable ones."""
35
+ if isinstance(value, list):
36
+ return tuple(_make_hashable(item) for item in value)
37
+ if isinstance(value, dict):
38
+ return frozenset((k, _make_hashable(v)) for k, v in value.items())
39
+ return value
40
+
41
+ @lru_cache(maxsize=1024)
42
+ def _cached_render(text: str, frozen_replacements: frozenset) -> str:
43
+ """Cached version of template rendering with frozen replacements."""
44
+ # Print cache info on every call
45
+ cache_info = _cached_render.cache_info()
46
+ print(f"\t\t\t\t\t Cache status - hits: {cache_info.hits}, misses: {cache_info.misses}, current size: {cache_info.currsize}")
47
+
48
+ # Convert back to dict with original types for rendering
49
+ replacements = {k: v for k, v in frozen_replacements}
50
+
51
+ template = _compile_template(text)
52
+ result = template.render(replacements)
53
+
54
+ return result
13
55
 
14
56
  class Prompt(PersistenceMixin, RepresentationMixin):
15
57
  """Class for creating a prompt to be used in a survey."""
@@ -145,33 +187,8 @@ class Prompt(PersistenceMixin, RepresentationMixin):
145
187
  return f'Prompt(text="""{self.text}""")'
146
188
 
147
189
  def template_variables(self) -> list[str]:
148
- """Return the the variables in the template.
149
-
150
- Example:
151
-
152
- >>> p = Prompt("Hello, {{person}}")
153
- >>> p.template_variables()
154
- ['person']
155
-
156
- """
157
- return self._template_variables(self.text)
158
-
159
- @staticmethod
160
- def _template_variables(template: str) -> list[str]:
161
- """Find and return the template variables.
162
-
163
- :param template: The template to find the variables in.
164
-
165
- """
166
- from jinja2 import Environment, meta, Undefined
167
-
168
- class PreserveUndefined(Undefined):
169
- def __str__(self):
170
- return "{{ " + str(self._undefined_name) + " }}"
171
-
172
- env = Environment(undefined=PreserveUndefined)
173
- ast = env.parse(template)
174
- return list(meta.find_undeclared_variables(ast))
190
+ """Return the variables in the template."""
191
+ return _find_template_variables(self.text)
175
192
 
176
193
  def undefined_template_variables(self, replacement_dict: dict):
177
194
  """Return the variables in the template that are not in the replacement_dict.
@@ -239,45 +256,39 @@ class Prompt(PersistenceMixin, RepresentationMixin):
239
256
  return self
240
257
 
241
258
  @staticmethod
242
- def _render(
243
- text: str, primary_replacement, **additional_replacements
244
- ) -> "PromptBase":
245
- """Render the template text with variables replaced from the provided named dictionaries.
246
-
247
- :param text: The text to render.
248
- :param primary_replacement: The primary replacement dictionary.
249
- :param additional_replacements: Additional replacement dictionaries.
250
-
251
- Allows for nested variable resolution up to a specified maximum nesting depth.
252
-
253
- Example:
254
-
255
- >>> codebook = {"age": "Age"}
256
- >>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
257
- >>> p.render({"name": "John", "age": 44}, codebook=codebook)
258
- Prompt(text=\"""You are an agent named John. Age: 44\""")
259
- """
260
- from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
261
-
262
- class PreserveUndefined(Undefined):
263
- def __str__(self):
264
- return "{{ " + str(self._undefined_name) + " }}"
265
-
266
- env = Environment(undefined=PreserveUndefined)
259
+ def _render(text: str, primary_replacement, **additional_replacements) -> "PromptBase":
260
+ """Render the template text with variables replaced."""
261
+ import time
262
+
263
+ # if there are no replacements, return the text
264
+ if not primary_replacement and not additional_replacements:
265
+ return text
266
+
267
267
  try:
268
+ variables = _find_template_variables(text)
269
+
270
+ if not variables: # if there are no variables, return the text
271
+ return text
272
+
273
+ # Combine all replacements
274
+ all_replacements = {**primary_replacement, **additional_replacements}
275
+
268
276
  previous_text = None
277
+ current_text = text
278
+ iteration = 0
279
+
269
280
  for _ in range(MAX_NESTING):
270
- # breakpoint()
271
- rendered_text = env.from_string(text).render(
272
- primary_replacement, **additional_replacements
273
- )
274
- if rendered_text == previous_text:
275
- # No more changes, so return the rendered text
281
+ iteration += 1
282
+
283
+ template = _compile_template(current_text)
284
+ rendered_text = template.render(all_replacements)
285
+
286
+ if rendered_text == current_text:
276
287
  return rendered_text
277
- previous_text = text
278
- text = rendered_text
288
+
289
+ previous_text = current_text
290
+ current_text = rendered_text
279
291
 
280
- # If the loop exits without returning, it indicates too much nesting
281
292
  raise TemplateRenderError(
282
293
  "Too much nesting - you created an infinite loop here, pal"
283
294
  )
@@ -331,6 +342,58 @@ class Prompt(PersistenceMixin, RepresentationMixin):
331
342
  """Return an example of the prompt."""
332
343
  return cls(cls.default_instructions)
333
344
 
345
+ def get_prompts(self) -> Dict[str, Any]:
346
+ """Get the prompts for the question."""
347
+ start = time.time()
348
+
349
+ # Build all the components
350
+ instr_start = time.time()
351
+ agent_instructions = self.agent_instructions_prompt
352
+ instr_end = time.time()
353
+ logger.debug(f"Time taken for agent instructions: {instr_end - instr_start:.4f}s")
354
+
355
+ persona_start = time.time()
356
+ agent_persona = self.agent_persona_prompt
357
+ persona_end = time.time()
358
+ logger.debug(f"Time taken for agent persona: {persona_end - persona_start:.4f}s")
359
+
360
+ q_instr_start = time.time()
361
+ question_instructions = self.question_instructions_prompt
362
+ q_instr_end = time.time()
363
+ logger.debug(f"Time taken for question instructions: {q_instr_end - q_instr_start:.4f}s")
364
+
365
+ memory_start = time.time()
366
+ prior_question_memory = self.prior_question_memory_prompt
367
+ memory_end = time.time()
368
+ logger.debug(f"Time taken for prior question memory: {memory_end - memory_start:.4f}s")
369
+
370
+ # Get components dict
371
+ components = {
372
+ "agent_instructions": agent_instructions.text,
373
+ "agent_persona": agent_persona.text,
374
+ "question_instructions": question_instructions.text,
375
+ "prior_question_memory": prior_question_memory.text,
376
+ }
377
+
378
+ # Use PromptPlan's get_prompts method
379
+ plan_start = time.time()
380
+ prompts = self.prompt_plan.get_prompts(**components)
381
+ plan_end = time.time()
382
+ logger.debug(f"Time taken for prompt processing: {plan_end - plan_start:.4f}s")
383
+
384
+ # Handle file keys if present
385
+ if hasattr(self, 'question_file_keys') and self.question_file_keys:
386
+ files_start = time.time()
387
+ files_list = []
388
+ for key in self.question_file_keys:
389
+ files_list.append(self.scenario[key])
390
+ prompts["files_list"] = files_list
391
+ files_end = time.time()
392
+ logger.debug(f"Time taken for file key processing: {files_end - files_start:.4f}s")
393
+
394
+ end = time.time()
395
+ logger.debug(f"Total time in get_prompts: {end - start:.4f}s")
396
+ return prompts
334
397
 
335
398
  if __name__ == "__main__":
336
399
  print("Running doctests...")
@@ -249,7 +249,28 @@ class QuestionNameDescriptor(BaseDescriptor):
249
249
 
250
250
 
251
251
  class QuestionOptionsDescriptor(BaseDescriptor):
252
- """Validate that `question_options` is a list, does not exceed the min/max lengths, and has unique items."""
252
+ """Validate that `question_options` is a list, does not exceed the min/max lengths, and has unique items.
253
+
254
+ >>> import warnings
255
+ >>> q_class = QuestionOptionsDescriptor.example()
256
+ >>> with warnings.catch_warnings(record=True) as w:
257
+ ... _ = q_class(["a ", "b", "c"]) # Has trailing space
258
+ ... assert len(w) == 1
259
+ ... assert "trailing whitespace" in str(w[0].message)
260
+
261
+ >>> _ = q_class(["a", "b", "c", "d", "d"])
262
+ Traceback (most recent call last):
263
+ ...
264
+ edsl.exceptions.questions.QuestionCreationValidationError: Question options must be unique (got ['a', 'b', 'c', 'd', 'd']).
265
+
266
+ We allow dynamic question options, which are strings of the form '{{ question_options }}'.
267
+
268
+ >>> _ = q_class("{{dynamic_options}}")
269
+ >>> _ = q_class("dynamic_options")
270
+ Traceback (most recent call last):
271
+ ...
272
+ edsl.exceptions.questions.QuestionCreationValidationError: ...
273
+ """
253
274
 
254
275
  @classmethod
255
276
  def example(cls):
@@ -273,23 +294,7 @@ class QuestionOptionsDescriptor(BaseDescriptor):
273
294
  self.q_budget = q_budget
274
295
 
275
296
  def validate(self, value: Any, instance) -> None:
276
- """Validate the question options.
277
-
278
- >>> q_class = QuestionOptionsDescriptor.example()
279
- >>> _ = q_class(["a", "b", "c"])
280
- >>> _ = q_class(["a", "b", "c", "d", "d"])
281
- Traceback (most recent call last):
282
- ...
283
- edsl.exceptions.questions.QuestionCreationValidationError: Question options must be unique (got ['a', 'b', 'c', 'd', 'd']).
284
-
285
- We allow dynamic question options, which are strings of the form '{{ question_options }}'.
286
-
287
- >>> _ = q_class("{{dynamic_options}}")
288
- >>> _ = q_class("dynamic_options")
289
- Traceback (most recent call last):
290
- ...
291
- edsl.exceptions.questions.QuestionCreationValidationError: ...
292
- """
297
+ """Validate the question options."""
293
298
  if isinstance(value, str):
294
299
  # Check if the string is a dynamic question option
295
300
  if "{{" in value and "}}" in value:
@@ -343,6 +348,15 @@ class QuestionOptionsDescriptor(BaseDescriptor):
343
348
  f"All question options must be at least 1 character long but less than {Settings.MAX_OPTION_LENGTH} characters long (got {value})."
344
349
  )
345
350
 
351
+ # Check for trailing whitespace in string options
352
+ if any(isinstance(x, str) and (x != x.strip()) for x in value):
353
+ import warnings
354
+
355
+ warnings.warn(
356
+ "Some question options contain trailing whitespace. This may cause unexpected behavior.",
357
+ UserWarning,
358
+ )
359
+
346
360
  if hasattr(instance, "min_selections") and instance.min_selections != None:
347
361
  if instance.min_selections > len(value):
348
362
  raise QuestionCreationValidationError(
@@ -114,6 +114,7 @@ class QuestionBaseGenMixin:
114
114
  .render(strings_only_replacement_dict)
115
115
  )
116
116
  except Exception as e:
117
+ #breakpoint()
117
118
  import warnings
118
119
 
119
120
  warnings.warn("Failed to render string: " + value)
@@ -7,7 +7,6 @@ from typing import Optional, Tuple, Union, List
7
7
 
8
8
  from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
9
9
 
10
-
11
10
  class DatasetExportMixin:
12
11
  """Mixin class for exporting Dataset objects."""
13
12
 
@@ -220,23 +219,45 @@ class DatasetExportMixin:
220
219
  )
221
220
  return exporter.export()
222
221
 
223
- def _db(self, remove_prefix: bool = True):
222
+ def _db(self, remove_prefix: bool = True, shape: str = "wide") -> "sqlalchemy.engine.Engine":
224
223
  """Create a SQLite database in memory and return the connection.
225
224
 
226
225
  Args:
227
- shape: The shape of the data in the database (wide or long)
228
226
  remove_prefix: Whether to remove the prefix from the column names
227
+ shape: The shape of the data in the database ("wide" or "long")
229
228
 
230
229
  Returns:
231
230
  A database connection
231
+ >>> from sqlalchemy import text
232
+ >>> from edsl import Results
233
+ >>> engine = Results.example()._db()
234
+ >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
235
+ 4
236
+ >>> engine = Results.example()._db(shape = "long")
237
+ >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
238
+ 172
232
239
  """
233
- from sqlalchemy import create_engine
240
+ from sqlalchemy import create_engine, text
234
241
 
235
242
  engine = create_engine("sqlite:///:memory:")
236
- if remove_prefix:
243
+ if remove_prefix and shape == "wide":
237
244
  df = self.remove_prefix().to_pandas(lists_as_strings=True)
238
245
  else:
239
246
  df = self.to_pandas(lists_as_strings=True)
247
+
248
+ if shape == "long":
249
+ # Melt the dataframe to convert it to long format
250
+ df = df.melt(
251
+ var_name='key',
252
+ value_name='value'
253
+ )
254
+ # Add a row number column for reference
255
+ df.insert(0, 'row_number', range(1, len(df) + 1))
256
+
257
+ # Split the key into data_type and key
258
+ df['data_type'] = df['key'].apply(lambda x: x.split('.')[0] if '.' in x else None)
259
+ df['key'] = df['key'].apply(lambda x: '.'.join(x.split('.')[1:]) if '.' in x else x)
260
+
240
261
  df.to_sql(
241
262
  "self",
242
263
  engine,
@@ -251,6 +272,7 @@ class DatasetExportMixin:
251
272
  transpose: bool = None,
252
273
  transpose_by: str = None,
253
274
  remove_prefix: bool = True,
275
+ shape: str = "wide",
254
276
  ) -> Union["pd.DataFrame", str]:
255
277
  """Execute a SQL query and return the results as a DataFrame.
256
278
 
@@ -268,10 +290,17 @@ class DatasetExportMixin:
268
290
  Returns:
269
291
  DataFrame, CSV string, list, or LaTeX string depending on parameters
270
292
 
293
+ Examples:
294
+ >>> from edsl import Results
295
+ >>> r = Results.example();
296
+ >>> len(r.sql("SELECT * FROM self", shape = "wide"))
297
+ 4
298
+ >>> len(r.sql("SELECT * FROM self", shape = "long"))
299
+ 172
271
300
  """
272
301
  import pandas as pd
273
302
 
274
- conn = self._db(remove_prefix=remove_prefix)
303
+ conn = self._db(remove_prefix=remove_prefix, shape=shape)
275
304
  df = pd.read_sql_query(query, conn)
276
305
 
277
306
  # Transpose the DataFrame if transpose is True