edsl 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. edsl/Base.py +7 -3
  2. edsl/__version__.py +1 -1
  3. edsl/agents/InvigilatorBase.py +3 -1
  4. edsl/agents/PromptConstructor.py +66 -91
  5. edsl/agents/QuestionInstructionPromptBuilder.py +160 -79
  6. edsl/agents/QuestionTemplateReplacementsBuilder.py +80 -17
  7. edsl/agents/question_option_processor.py +15 -6
  8. edsl/coop/CoopFunctionsMixin.py +3 -4
  9. edsl/coop/coop.py +171 -96
  10. edsl/data/RemoteCacheSync.py +10 -9
  11. edsl/enums.py +3 -3
  12. edsl/inference_services/AnthropicService.py +11 -9
  13. edsl/inference_services/AvailableModelFetcher.py +2 -0
  14. edsl/inference_services/AwsBedrock.py +1 -2
  15. edsl/inference_services/AzureAI.py +12 -9
  16. edsl/inference_services/GoogleService.py +9 -4
  17. edsl/inference_services/InferenceServicesCollection.py +2 -2
  18. edsl/inference_services/MistralAIService.py +1 -2
  19. edsl/inference_services/OpenAIService.py +9 -4
  20. edsl/inference_services/PerplexityService.py +2 -1
  21. edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
  22. edsl/inference_services/registry.py +2 -2
  23. edsl/jobs/AnswerQuestionFunctionConstructor.py +12 -1
  24. edsl/jobs/Jobs.py +24 -17
  25. edsl/jobs/JobsChecks.py +10 -13
  26. edsl/jobs/JobsPrompts.py +49 -26
  27. edsl/jobs/JobsRemoteInferenceHandler.py +4 -5
  28. edsl/jobs/async_interview_runner.py +3 -1
  29. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  30. edsl/jobs/data_structures.py +3 -0
  31. edsl/jobs/interviews/Interview.py +6 -3
  32. edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
  33. edsl/jobs/tasks/TaskHistory.py +1 -1
  34. edsl/language_models/LanguageModel.py +6 -3
  35. edsl/language_models/PriceManager.py +45 -5
  36. edsl/language_models/model.py +47 -26
  37. edsl/questions/QuestionBase.py +21 -0
  38. edsl/questions/QuestionBasePromptsMixin.py +103 -0
  39. edsl/questions/QuestionFreeText.py +22 -5
  40. edsl/questions/descriptors.py +4 -0
  41. edsl/questions/question_base_gen_mixin.py +96 -29
  42. edsl/results/Dataset.py +65 -0
  43. edsl/results/DatasetExportMixin.py +320 -32
  44. edsl/results/Result.py +27 -0
  45. edsl/results/Results.py +22 -2
  46. edsl/results/ResultsGGMixin.py +7 -3
  47. edsl/scenarios/DocumentChunker.py +2 -0
  48. edsl/scenarios/FileStore.py +10 -0
  49. edsl/scenarios/PdfExtractor.py +21 -1
  50. edsl/scenarios/Scenario.py +25 -9
  51. edsl/scenarios/ScenarioList.py +226 -24
  52. edsl/scenarios/handlers/__init__.py +1 -0
  53. edsl/scenarios/handlers/docx.py +5 -1
  54. edsl/scenarios/handlers/jpeg.py +39 -0
  55. edsl/surveys/Survey.py +5 -4
  56. edsl/surveys/SurveyFlowVisualization.py +91 -43
  57. edsl/templates/error_reporting/exceptions_table.html +7 -8
  58. edsl/templates/error_reporting/interview_details.html +1 -1
  59. edsl/templates/error_reporting/interviews.html +0 -1
  60. edsl/templates/error_reporting/overview.html +2 -7
  61. edsl/templates/error_reporting/performance_plot.html +1 -1
  62. edsl/templates/error_reporting/report.css +1 -1
  63. edsl/utilities/PrettyList.py +14 -0
  64. edsl-0.1.46.dist-info/METADATA +246 -0
  65. {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/RECORD +67 -66
  66. edsl-0.1.44.dist-info/METADATA +0 -110
  67. {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/LICENSE +0 -0
  68. {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/WHEEL +0 -0
edsl/jobs/JobsRemoteInferenceHandler.py

@@ -24,7 +24,7 @@ from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
 class RemoteJobConstants:
     """Constants for remote job handling."""
 
-    REMOTE_JOB_POLL_INTERVAL = 1
+    REMOTE_JOB_POLL_INTERVAL = 4
     REMOTE_JOB_VERBOSE = False
     DISCORD_URL = "https://discord.com/invite/mxAYkjfy9m"
 
@@ -88,8 +88,8 @@ class JobsRemoteInferenceHandler:
         iterations: int = 1,
         remote_inference_description: Optional[str] = None,
         remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
+        fresh: Optional[bool] = False,
     ) -> RemoteJobInfo:
-
         from edsl.config import CONFIG
         from edsl.coop.coop import Coop
 
@@ -106,6 +106,7 @@
             status="queued",
             iterations=iterations,
             initial_results_visibility=remote_inference_results_visibility,
+            fresh=fresh,
         )
         logger.update(
             "Your survey is running at the Expected Parrot server...",
@@ -277,9 +278,7 @@ class JobsRemoteInferenceHandler:
         job_in_queue = True
         while job_in_queue:
             result = self._attempt_fetch_job(
-                job_info,
-                remote_job_data_fetcher,
-                object_fetcher
+                job_info, remote_job_data_fetcher, object_fetcher
             )
             if result != "continue":
                 return result
edsl/jobs/async_interview_runner.py

@@ -7,6 +7,8 @@ from edsl.data_transfer_models import EDSLResultObjectInput
 
 from edsl.results.Result import Result
 from edsl.jobs.interviews.Interview import Interview
+from edsl.config import Config
+config = Config()
 
 if TYPE_CHECKING:
     from edsl.jobs.Jobs import Jobs
@@ -23,7 +25,7 @@ from edsl.jobs.data_structures import RunConfig
 
 
 class AsyncInterviewRunner:
-    MAX_CONCURRENT = 5
+    MAX_CONCURRENT = int(config.EDSL_MAX_CONCURRENT_TASKS)
 
     def __init__(self, jobs: "Jobs", run_config: RunConfig):
        self.jobs = jobs
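
The concurrency cap on interview tasks is now read from configuration instead of being hard-coded at 5. A minimal sketch, assuming edsl's Config picks EDSL_MAX_CONCURRENT_TASKS up from the environment (as the other EDSL_* settings generally are):

    # Hypothetical: raise the cap before edsl is imported, since
    # MAX_CONCURRENT is evaluated at import time.
    import os
    os.environ["EDSL_MAX_CONCURRENT_TASKS"] = "10"

    from edsl.jobs.async_interview_runner import AsyncInterviewRunner
    print(AsyncInterviewRunner.MAX_CONCURRENT)  # 10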
edsl/jobs/check_survey_scenario_compatibility.py

@@ -72,11 +72,11 @@ class CheckSurveyScenarioCompatibility:
         if warn:
             warnings.warn(message)
 
-        if self.scenarios.has_jinja_braces:
-            warnings.warn(
-                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
-            )
-            self.scenarios = self.scenarios._convert_jinja_braces()
+        # if self.scenarios.has_jinja_braces:
+        #     warnings.warn(
+        #         "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
+        #     )
+        #     self.scenarios = self.scenarios._convert_jinja_braces()
 
 
 if __name__ == "__main__":
edsl/jobs/data_structures.py

@@ -36,6 +36,9 @@ class RunParameters(Base):
     disable_remote_cache: bool = False
     disable_remote_inference: bool = False
     job_uuid: Optional[str] = None
+    fresh: Optional[
+        bool
+    ] = False  # if True, will not use cache and will save new results to cache
 
     def to_dict(self, add_edsl_version=False) -> dict:
         d = asdict(self)
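
The new fresh flag on RunParameters (and plumbed through create_remote_inference_job above) forces regeneration while still writing new results back to the cache. A sketch, assuming run() forwards keyword arguments into RunParameters the way the other run flags are forwarded:

    from edsl import QuestionFreeText

    q = QuestionFreeText.example()
    results = q.run(fresh=True)  # skip cached responses; newly generated answers are re-cached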
edsl/jobs/interviews/Interview.py

@@ -238,9 +238,6 @@ class Interview:
        >>> run_config = RunConfig(parameters = RunParameters(), environment = RunEnvironment())
        >>> run_config.parameters.stop_on_exception = True
        >>> result, _ = asyncio.run(i.async_conduct_interview(run_config))
-        Traceback (most recent call last):
-        ...
-        asyncio.exceptions.CancelledError
        """
        from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment
 
@@ -262,6 +259,8 @@ class Interview:
        if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
            model_buckets = ModelBuckets.infinity_bucket()
 
+        self.skip_flags = {q.question_name: False for q in self.survey.questions}
+
        # was "self.tasks" - is that necessary?
        self.tasks = self.task_manager.build_question_tasks(
            answer_func=AnswerQuestionFunctionConstructor(
@@ -310,6 +309,10 @@
        def handle_task(task, invigilator):
            try:
                result: Answers = task.result()
+                if result == "skipped":
+                    result = invigilator.get_failed_task_result(
+                        failure_reason="Task was skipped."
+                    )
            except asyncio.CancelledError as e:  # task was cancelled
                result = invigilator.get_failed_task_result(
                    failure_reason="Task was cancelled."
edsl/jobs/interviews/InterviewExceptionEntry.py

@@ -166,6 +166,9 @@ class InterviewExceptionEntry:
        >>> entry = InterviewExceptionEntry.example()
        >>> _ = entry.to_dict()
        """
+        import json
+        from edsl.exceptions.questions import QuestionAnswerValidationError
+
        invigilator = (
            self.invigilator.to_dict() if self.invigilator is not None else None
        )
@@ -174,7 +177,16 @@
            "time": self.time,
            "traceback": self.traceback,
            "invigilator": invigilator,
+            "additional_data": {},
        }
+
+        if isinstance(self.exception, QuestionAnswerValidationError):
+            d["additional_data"]["edsl_response"] = json.dumps(self.exception.data)
+            d["additional_data"]["validating_model"] = json.dumps(
+                self.exception.model.model_json_schema()
+            )
+            d["additional_data"]["error_message"] = str(self.exception.message)
+
        return d
 
    @classmethod
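
With this change every serialized entry carries an additional_data dict, which is only populated when the exception is a QuestionAnswerValidationError. A quick sketch of reading it back:

    from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry

    d = InterviewExceptionEntry.example().to_dict()
    # Empty for the generic example; for validation errors it holds
    # "edsl_response", "validating_model", and "error_message".
    print(d["additional_data"])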
edsl/jobs/tasks/TaskHistory.py

@@ -419,7 +419,7 @@ class TaskHistory(RepresentationMixin):
        filename: Optional[str] = None,
        return_link=False,
        css=None,
-        cta="\nClick to open the report in a new tab\n",
+        cta="<br><span style='font-size: 18px; font-weight: medium-bold; text-decoration: underline;'>Click to open the report in a new tab</span><br><br>",
        open_in_browser=False,
    ):
        """Return an HTML report."""
edsl/language_models/LanguageModel.py

@@ -379,8 +379,10 @@ class LanguageModel(
        cached_response, cache_key = cache.fetch(**cache_call_params)
 
        if cache_used := cached_response is not None:
+            # print("cache used")
            response = json.loads(cached_response)
        else:
+            # print("cache not used")
            f = (
                self.remote_async_execute_model_call
                if hasattr(self, "remote") and self.remote
@@ -394,14 +396,16 @@
            from edsl.config import CONFIG
 
            TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
            response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
            new_cache_key = cache.store(
                **cache_call_params, response=response
            )  # store the response in the cache
            assert new_cache_key == cache_key  # should be the same
 
+        #breakpoint()
+
        cost = self.cost(response)
+        #breakpoint()
        return ModelResponse(
            response=response,
            cache_used=cache_used,
@@ -466,6 +470,7 @@
            model_outputs=model_outputs,
            edsl_dict=edsl_dict,
        )
+        #breakpoint()
        return agent_response_dict
 
    get_response = sync_wrapper(async_get_response)
@@ -518,8 +523,6 @@
        """
        from edsl.language_models.model import get_model_class
 
-        # breakpoint()
-
        model_class = get_model_class(
            data["model"], service_name=data.get("inference_service", None)
        )
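
The cached-versus-live branch follows a fetch-or-call pattern in which the cache key is a pure function of the call parameters, which is why storing the fresh response must reproduce the key that fetch computed. A simplified sketch of that invariant (hypothetical cache and call objects, not the actual edsl signatures):

    import asyncio
    import json

    async def fetch_or_call(cache, call, params, timeout):
        cached, key = cache.fetch(**params)      # key derives from params alone
        if cached is not None:
            return json.loads(cached), True      # cache hit
        response = await asyncio.wait_for(call(**params), timeout=timeout)
        new_key = cache.store(**params, response=response)
        assert new_key == key                    # deterministic keys must agree
        return response, False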
edsl/language_models/PriceManager.py

@@ -30,19 +30,22 @@ class PriceManager:
        except Exception as e:
            print(f"Error fetching prices: {str(e)}")
 
-    def get_price(self, inference_service: str, model: str) -> Optional[Dict]:
+    def get_price(self, inference_service: str, model: str) -> Dict:
        """
        Get the price information for a specific service and model combination.
+        If no specific price is found, returns a fallback price.
 
        Args:
            inference_service (str): The name of the inference service
            model (str): The model identifier
 
        Returns:
-            Optional[Dict]: Price information if found, None otherwise
+            Dict: Price information (either actual or fallback prices)
        """
        key = (inference_service, model)
-        return self._price_lookup.get(key)
+        return self._price_lookup.get(key) or self._get_fallback_price(
+            inference_service
+        )
 
    def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
        """
@@ -53,6 +56,45 @@
        """
        return self._price_lookup.copy()
 
+    def _get_fallback_price(self, inference_service: str) -> Dict:
+        """
+        Get fallback prices for a service.
+        - First fallback: The highest input and output prices for that service from the price lookup.
+        - Second fallback: $1.00 per million tokens (for both input and output).
+
+        Args:
+            inference_service (str): The inference service name
+
+        Returns:
+            Dict: Price information
+        """
+        service_prices = [
+            prices
+            for (service, _), prices in self._price_lookup.items()
+            if service == inference_service
+        ]
+
+        input_tokens_per_usd = [
+            float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
+        ]
+        if input_tokens_per_usd:
+            min_input_tokens = min(input_tokens_per_usd)
+        else:
+            min_input_tokens = 1_000_000
+
+        output_tokens_per_usd = [
+            float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
+        ]
+        if output_tokens_per_usd:
+            min_output_tokens = min(output_tokens_per_usd)
+        else:
+            min_output_tokens = 1_000_000
+
+        return {
+            "input": {"one_usd_buys": min_input_tokens},
+            "output": {"one_usd_buys": min_output_tokens},
+        }
+
    def calculate_cost(
        self,
@@ -75,8 +117,6 @@
            Union[float, str]: Total cost if calculation successful, error message string if not
        """
        relevant_prices = self.get_price(inference_service, model)
-        if relevant_prices is None:
-            return f"Could not find price for model {model} in the price lookup."
 
        # Extract token counts
        try:
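
Note the unit in the fallback: "one_usd_buys" counts tokens per dollar, so the most expensive model for a service is the one with the smallest value, and min() is what implements "highest price". A worked toy example with a hypothetical price table:

    # gpt-4o is pricier than gpt-4o-mini, so one dollar buys fewer tokens.
    lookup = {
        ("openai", "gpt-4o"): {"input": {"one_usd_buys": 400_000}},
        ("openai", "gpt-4o-mini"): {"input": {"one_usd_buys": 6_666_667}},
    }
    service_prices = [p for (s, _), p in lookup.items() if s == "openai"]
    fallback = min(float(p["input"]["one_usd_buys"]) for p in service_prices)
    assert fallback == 400_000  # unknown openai models are costed at the priciest rate
    # A service with no entries at all falls back to 1_000_000 tokens/USD,
    # i.e. $1.00 per million tokens, for both input and output.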
edsl/language_models/model.py

@@ -17,7 +17,11 @@ if TYPE_CHECKING:
    from edsl.results.Dataset import Dataset
 
 
-def get_model_class(model_name, registry: Optional[InferenceServicesCollection] = None, service_name: Optional[InferenceServiceLiteral] = None):
+def get_model_class(
+    model_name,
+    registry: Optional[InferenceServicesCollection] = None,
+    service_name: Optional[InferenceServiceLiteral] = None,
+):
    from edsl.inference_services.registry import default
 
    registry = registry or default
@@ -40,6 +44,9 @@ class Meta(type):
    To get the default model, you can leave out the model name.
    To see the available models, you can do:
    >>> Model.available()
+
+    Or to see the models for a specific service, you can do:
+    >>> Model.available(service='openai')
    """
    )
 
@@ -97,7 +104,10 @@ class Model(metaclass=Meta):
        *args,
        **kwargs,
    ):
-        "Instantiate a new language model."
+        """Instantiate a new language model.
+        >>> Model()
+        Model(...)
+        """
        # Map index to the respective subclass
        if model_name is None:
            model_name = cls.default_model
@@ -127,28 +137,25 @@
        >>> Model.service_classes()
        [...]
        """
-        return [r for r in cls.services(name_only=True)]
+        return [r for r in cls.services()]
 
    @classmethod
    def services(cls, name_only: bool = False) -> List[str]:
-        """Returns a list of services, annotated with whether the user has local keys for them."""
-        services_with_local_keys = set(cls.key_info().select("service").to_list())
-        f = lambda service_name: (
-            "yes" if service_name in services_with_local_keys else " "
-        )
-        if name_only:
-            return PrettyList(
-                [r._inference_service_ for r in cls.get_registry().services],
-                columns=["Service Name"],
-            )
-        else:
-            return PrettyList(
+        """Returns a list of services excluding 'test', sorted alphabetically.
+
+        >>> Model.services()
+        [...]
+        """
+        return PrettyList(
+            sorted(
                [
-                    (r._inference_service_, f(r._inference_service_))
+                    [r._inference_service_]
                    for r in cls.get_registry().services
-                ],
-                columns=["Service Name", "Local key?"],
-            )
+                    if r._inference_service_.lower() != "test"
+                ]
+            ),
+            columns=["Service Name"],
+        )
 
    @classmethod
    def services_with_local_keys(cls) -> set:
@@ -198,7 +205,15 @@
        search_term: str = None,
        name_only: bool = False,
        service: Optional[str] = None,
+        force_refresh: bool = False,
    ):
+        """Get available models
+
+        >>> Model.available()
+        [...]
+        >>> Model.available(service='openai')
+        [...]
+        """
        # if search_term is None and service is None:
        #     print("Getting available models...")
        #     print("You have local keys for the following services:")
@@ -209,13 +224,16 @@
        #     return None
 
        if service is not None:
-            if service not in cls.services(name_only=True):
+            known_services = [x[0] for x in cls.services(name_only=True)]
+            if service not in known_services:
                raise ValueError(
                    f"Service {service} not found in available services.",
-                    f"Available services are: {cls.services()}",
+                    f"Available services are: {known_services}",
                )
 
-        full_list = cls.get_registry().available(service=service)
+        full_list = cls.get_registry().available(
+            service=service, force_refresh=force_refresh
+        )
 
        if search_term is None:
            if name_only:
@@ -319,6 +337,9 @@
        """
        Returns an example Model instance.
 
+        >>> Model.example()
+        Model(...)
+
        :param randomize: If True, the temperature is set to a random decimal between 0 and 1.
        """
        temperature = 0.5 if not randomize else round(random(), 2)
@@ -331,7 +352,7 @@ if __name__ == "__main__":
 
    doctest.testmod(optionflags=doctest.ELLIPSIS)
 
-    available = Model.available()
-    m = Model("gpt-4-1106-preview")
-    results = m.execute_model_call("Hello world")
-    print(results)
+    # available = Model.available()
+    # m = Model("gpt-4-1106-preview")
+    # results = m.execute_model_call("Hello world")
+    # print(results)
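
Taken together these changes simplify service listing and add a cache bypass for model discovery; usage follows the doctests above:

    from edsl import Model

    Model.services()                    # sorted service names, 'test' excluded
    Model.available(service="openai")   # models for a single service
    Model.available(service="openai", force_refresh=True)  # re-fetch instead of using the cached list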
edsl/questions/QuestionBase.py

@@ -85,6 +85,9 @@ class QuestionBase(
        >>> Q.example()._simulate_answer()
        {'answer': '...', 'generated_tokens': ...}
        """
+        if self.question_type == "free_text":
+            return {"answer": "Hello, how are you?", 'generated_tokens': "Hello, how are you?"}
+
        simulated_answer = self.fake_data_factory.build().dict()
        if human_readable and hasattr(self, "question_options") and self.use_code:
            simulated_answer["answer"] = [
@@ -432,6 +435,24 @@
 
        return Survey([self])
 
+    def humanize(
+        self,
+        project_name: str = "Project",
+        survey_description: Optional[str] = None,
+        survey_alias: Optional[str] = None,
+        survey_visibility: Optional["VisibilityType"] = "unlisted",
+    ) -> dict:
+        """
+        Turn a single question into a survey and send the survey to Coop.
+
+        Then, create a project on Coop so you can share the survey with human respondents.
+        """
+        s = self.to_survey()
+        project_details = s.humanize(
+            project_name, survey_description, survey_alias, survey_visibility
+        )
+        return project_details
+
    def by(self, *args) -> "Jobs":
        """Turn a single question into a survey and then a Job."""
        from edsl.surveys.Survey import Survey
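
QuestionBase.humanize is a thin wrapper: the question is wrapped in a Survey and Survey.humanize does the Coop work. A sketch (project name hypothetical; a Coop account and API key are required):

    from edsl import QuestionFreeText

    q = QuestionFreeText(question_name="feedback", question_text="Any feedback for us?")
    details = q.humanize(project_name="Feedback pilot")  # returns project details from Coop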
edsl/questions/QuestionBasePromptsMixin.py

@@ -187,6 +187,73 @@ class QuestionBasePromptsMixin:
        from edsl.prompts import Prompt
 
        return Prompt(self.question_presentation) + Prompt(self.answering_instructions)
+
+
+    def detailed_parameters_by_key(self) -> dict[str, set[tuple[str, ...]]]:
+        """
+        Return a dictionary of parameters by key.
+
+        >>> from edsl import QuestionMultipleChoice
+        >>> QuestionMultipleChoice.example().detailed_parameters_by_key()
+        {'question_name': set(), 'question_text': set()}
+
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_name = "example", question_text = "What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> r = q.detailed_parameters_by_key()
+        >>> r == {'question_name': set(), 'question_text': {('q0', 'answer'), ('nickname',)}}
+        True
+        """
+        params_by_key = {}
+        for key, value in self.data.items():
+            if isinstance(value, str):
+                params_by_key[key] = self.extract_parameters(value)
+        return params_by_key
+
+    @staticmethod
+    def extract_parameters(txt: str) -> set[tuple[str, ...]]:
+        """Return all parameters of the question as tuples representing their full paths.
+
+        :param txt: The text to extract parameters from.
+        :return: A set of tuples representing the parameters.
+
+        >>> from edsl.questions import QuestionMultipleChoice
+        >>> d = QuestionMultipleChoice.example().extract_parameters("What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> d == {('nickname',), ('q0', 'answer')}
+        True
+        """
+        from jinja2 import Environment, nodes
+
+        env = Environment()
+        # txt = self._all_text()
+        ast = env.parse(txt)
+
+        variables = set()
+        processed_nodes = set()  # Keep track of nodes we've processed
+
+        def visit_node(node, path=()):
+            if id(node) in processed_nodes:
+                return
+            processed_nodes.add(id(node))
+
+            if isinstance(node, nodes.Name):
+                # Only add the name if we're not in the middle of building a longer path
+                if not path:
+                    variables.add((node.name,))
+                else:
+                    variables.add((node.name,) + path)
+            elif isinstance(node, nodes.Getattr):
+                # Build path from bottom up
+                new_path = (node.attr,) + path
+                visit_node(node.node, new_path)
+
+        for node in ast.find_all((nodes.Name, nodes.Getattr)):
+            visit_node(node)
+
+        return variables
+
+    @property
+    def detailed_parameters(self):
+        return [".".join(p) for p in self.extract_parameters(self._all_text())]
 
    @property
    def parameters(self) -> set[str]:
@@ -219,3 +286,39 @@
            return self.new_default_instructions
        else:
            return self.applicable_prompts(model)[0]()
+
+    @staticmethod
+    def sequence_in_dict(d: dict, path: tuple[str, ...]) -> tuple[bool, any]:
+        """Check if a sequence of nested keys exists in a dictionary and return the value.
+
+        Args:
+            d: The dictionary to check
+            path: Tuple of keys representing the nested path
+
+        Returns:
+            tuple[bool, any]: (True, value) if the path exists, (False, None) otherwise
+
+        Example:
+            >>> sequence_in_dict = QuestionBasePromptsMixin.sequence_in_dict
+            >>> d = {'a': {'b': {'c': 1}}}
+            >>> sequence_in_dict(d, ('a', 'b', 'c'))
+            (True, 1)
+            >>> sequence_in_dict(d, ('a', 'b', 'd'))
+            (False, None)
+            >>> sequence_in_dict(d, ('x',))
+            (False, None)
+        """
+        try:
+            current = d
+            for key in path:
+                current = current.get(key)
+                if current is None:
+                    return (False, None)
+            return (True, current)
+        except (AttributeError, TypeError):
+            return (False, None)
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
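
The custom AST walk in extract_parameters exists because jinja2's built-in helper reports only top-level names, losing the attribute path that distinguishes {{ q0.answer }} from {{ q0 }}. A standalone comparison:

    from jinja2 import Environment, meta

    tmpl = "What is your name, {{ nickname }}, based on {{ q0.answer }}?"
    ast = Environment().parse(tmpl)
    print(meta.find_undeclared_variables(ast))  # {'nickname', 'q0'} - '.answer' is lost
    # extract_parameters walks Name/Getattr nodes instead and recovers
    # {('nickname',), ('q0', 'answer')}.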
edsl/questions/QuestionFreeText.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Any, Optional
 from uuid import uuid4
 
-from pydantic import field_validator
+from pydantic import field_validator, model_validator
 
 from edsl.questions.QuestionBase import QuestionBase
 from edsl.questions.response_validator_abc import ResponseValidatorABC
@@ -24,6 +24,17 @@ class FreeTextResponse(BaseModel):
    answer: str
    generated_tokens: Optional[str] = None
 
+    @model_validator(mode='after')
+    def validate_tokens_match_answer(self):
+        if self.generated_tokens is not None:  # If generated_tokens exists
+            # Ensure exact string equality
+            if self.answer.strip() != self.generated_tokens.strip():  # They MUST match exactly
+                raise ValueError(
+                    f"answer '{self.answer}' must exactly match generated_tokens '{self.generated_tokens}'. "
+                    f"Type of answer: {type(self.answer)}, Type of tokens: {type(self.generated_tokens)}"
+                )
+        return self
+
 
 class FreeTextResponseValidator(ResponseValidatorABC):
    required_params = []
@@ -37,10 +48,16 @@ class FreeTextResponseValidator(ResponseValidatorABC):
    ]
 
    def fix(self, response, verbose=False):
-        return {
-            "answer": str(response.get("generated_tokens")),
-            "generated_tokens": str(response.get("generated_tokens")),
-        }
+        if response.get("generated_tokens") != response.get("answer"):
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }
+        else:
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }
 
 
 class QuestionFreeText(QuestionBase):
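
The new model_validator requires answer and generated_tokens to agree up to surrounding whitespace, and pydantic surfaces the ValueError as a ValidationError. (As written, both branches of the revised fix() return the same dict.) A quick check against the FreeTextResponse model defined above:

    from pydantic import ValidationError
    from edsl.questions.QuestionFreeText import FreeTextResponse

    FreeTextResponse(answer="hi", generated_tokens="hi")   # passes
    try:
        FreeTextResponse(answer="hi", generated_tokens="bye")
    except ValidationError as e:
        print("rejected:", e.errors()[0]["msg"])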
edsl/questions/descriptors.py

@@ -2,6 +2,7 @@
 
 from abc import ABC, abstractmethod
 import re
+import textwrap
 from typing import Any, Callable, List, Optional
 from edsl.exceptions.questions import (
    QuestionCreationValidationError,
@@ -404,6 +405,9 @@ class QuestionTextDescriptor(BaseDescriptor):
            raise Exception("Question is too short!")
        if not isinstance(value, str):
            raise Exception("Question must be a string!")
+
+        #value = textwrap.dedent(value).strip()
+
        if contains_single_braced_substring(value):
            import warnings