edsl 0.1.33.dev3__py3-none-any.whl → 0.1.34.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. edsl/Base.py +15 -11
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Invigilator.py +22 -3
  4. edsl/agents/PromptConstructor.py +79 -183
  5. edsl/agents/prompt_helpers.py +129 -0
  6. edsl/coop/coop.py +3 -2
  7. edsl/data_transfer_models.py +0 -1
  8. edsl/inference_services/AnthropicService.py +5 -2
  9. edsl/inference_services/AwsBedrock.py +5 -2
  10. edsl/inference_services/AzureAI.py +5 -2
  11. edsl/inference_services/GoogleService.py +108 -33
  12. edsl/inference_services/MistralAIService.py +5 -2
  13. edsl/inference_services/OpenAIService.py +3 -2
  14. edsl/inference_services/TestService.py +11 -2
  15. edsl/inference_services/TogetherAIService.py +1 -1
  16. edsl/jobs/interviews/Interview.py +19 -9
  17. edsl/jobs/runners/JobsRunnerAsyncio.py +37 -16
  18. edsl/jobs/runners/JobsRunnerStatus.py +4 -3
  19. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  20. edsl/language_models/LanguageModel.py +12 -9
  21. edsl/language_models/utilities.py +3 -2
  22. edsl/questions/QuestionBase.py +11 -2
  23. edsl/questions/QuestionBaseGenMixin.py +28 -0
  24. edsl/questions/QuestionCheckBox.py +1 -1
  25. edsl/questions/QuestionMultipleChoice.py +5 -1
  26. edsl/questions/ResponseValidatorABC.py +5 -1
  27. edsl/questions/descriptors.py +12 -11
  28. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  29. edsl/scenarios/FileStore.py +159 -71
  30. edsl/scenarios/Scenario.py +23 -49
  31. edsl/scenarios/ScenarioList.py +6 -2
  32. edsl/surveys/DAG.py +62 -0
  33. edsl/surveys/MemoryPlan.py +26 -0
  34. edsl/surveys/Rule.py +24 -0
  35. edsl/surveys/RuleCollection.py +36 -2
  36. edsl/surveys/Survey.py +182 -10
  37. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/METADATA +2 -1
  38. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/RECORD +40 -40
  39. edsl/scenarios/ScenarioImageMixin.py +0 -100
  40. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/LICENSE +0 -0
  41. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/WHEEL +0 -0
edsl/inference_services/GoogleService.py
@@ -1,25 +1,54 @@
 import os
-import aiohttp
-import json
-from typing import Any
+from typing import Any, Dict, List, Optional
+import google
+import google.generativeai as genai
+from google.generativeai.types import GenerationConfig
+from google.api_core.exceptions import InvalidArgument
+
 
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.LanguageModel import LanguageModel
-
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 
+safety_settings = [
+    {
+        "category": "HARM_CATEGORY_HARASSMENT",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_HATE_SPEECH",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+        "threshold": "BLOCK_NONE",
+    },
+]
+
 
 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
     key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
-    usage_sequence = ["usageMetadata"]
-    input_token_name = "promptTokenCount"
-    output_token_name = "candidatesTokenCount"
+    usage_sequence = ["usage_metadata"]
+    input_token_name = "prompt_token_count"
+    output_token_name = "candidates_token_count"
 
     model_exclude_list = []
 
+    # @classmethod
+    # def available(cls) -> List[str]:
+    #     return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+
     @classmethod
-    def available(cls):
-        return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+    def available(cls) -> List[str]:
+        model_list = []
+        for m in genai.list_models():
+            if "generateContent" in m.supported_generation_methods:
+                model_list.append(m.name.split("/")[-1])
+        return model_list
 
     @classmethod
     def create_model(
@@ -47,33 +76,79 @@ class GoogleService(InferenceServiceABC):
                 "stopSequences": [],
            }
 
+        api_token = None
+        model = None
+
+        @classmethod
+        def initialize(cls):
+            if cls.api_token is None:
+                cls.api_token = os.getenv("GOOGLE_API_KEY")
+                if not cls.api_token:
+                    raise MissingAPIKeyError(
+                        "GOOGLE_API_KEY environment variable is not set"
+                    )
+                genai.configure(api_key=cls.api_token)
+                cls.generative_model = genai.GenerativeModel(
+                    cls._model_, safety_settings=safety_settings
+                )
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.initialize()
+
+        def get_generation_config(self) -> GenerationConfig:
+            return GenerationConfig(
+                temperature=self.temperature,
+                top_p=self.topP,
+                top_k=self.topK,
+                max_output_tokens=self.maxOutputTokens,
+                stop_sequences=self.stopSequences,
+            )
+
         async def async_execute_model_call(
-            self, user_prompt: str, system_prompt: str = ""
-        ) -> dict[str, Any]:
-            # self.api_token = os.getenv("GOOGLE_API_KEY")
-            combined_prompt = user_prompt + system_prompt
-            url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent?key={self.api_token}"
-            headers = {"Content-Type": "application/json"}
-            data = {
-                "contents": [{"parts": [{"text": combined_prompt}]}],
-                "generationConfig": {
-                    "temperature": self.temperature,
-                    "topK": self.topK,
-                    "topP": self.topP,
-                    "maxOutputTokens": self.maxOutputTokens,
-                    "stopSequences": self.stopSequences,
-                },
-            }
-            # print(combined_prompt)
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    url, headers=headers, data=json.dumps(data)
-                ) as response:
-                    raw_response_text = await response.text()
-                    return json.loads(raw_response_text)
+            self,
+            user_prompt: str,
+            system_prompt: str = "",
+            files_list: Optional["Files"] = None,
+        ) -> Dict[str, Any]:
+            generation_config = self.get_generation_config()
 
-        LLM.__name__ = model_name
+            if files_list is None:
+                files_list = []
+
+            if (
+                system_prompt is not None
+                and system_prompt != ""
+                and self._model_ != "gemini-pro"
+            ):
+                try:
+                    self.generative_model = genai.GenerativeModel(
+                        self._model_,
+                        safety_settings=safety_settings,
+                        system_instruction=system_prompt,
+                    )
+                except InvalidArgument as e:
+                    print(
+                        f"This model, {self._model_}, does not support system_instruction"
+                    )
+                    print("Will add system_prompt to user_prompt")
+                    user_prompt = f"{system_prompt}\n{user_prompt}"
 
+            combined_prompt = [user_prompt]
+            for file in files_list:
+                if "google" not in file.external_locations:
+                    _ = file.upload_google()
+                gen_ai_file = google.generativeai.types.file_types.File(
+                    file.external_locations["google"]
+                )
+                combined_prompt.append(gen_ai_file)
+
+            response = await self.generative_model.generate_content_async(
+                combined_prompt, generation_config=generation_config
+            )
+            return response.to_dict()
+
+        LLM.__name__ = model_name
         return LLM
 
 
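The net effect of this file's changes: the hand-rolled aiohttp call against the REST endpoint is replaced by the official google-generativeai SDK, with snake_case usage keys, model discovery via genai.list_models(), permissive safety settings, and native system_instruction support (falling back to prepending the system prompt for models that reject it). A minimal sketch of the new call path, assuming GOOGLE_API_KEY is set and the google-generativeai package is installed (model name and prompt are illustrative):

    import os
    import asyncio
    import google.generativeai as genai
    from google.generativeai.types import GenerationConfig

    async def main():
        # Configure the SDK once with the API key, as initialize() now does.
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
        model = genai.GenerativeModel(
            "gemini-1.5-flash",  # illustrative model name
            system_instruction="You are a helpful assistant.",
        )
        # GenerationConfig mirrors get_generation_config() above.
        config = GenerationConfig(temperature=0.5, max_output_tokens=256)
        response = await model.generate_content_async(
            ["Say hello."], generation_config=config
        )
        print(response.to_dict())  # dict form, as async_execute_model_call returns

    asyncio.run(main())

Returning response.to_dict() keeps the downstream dict-based parsing (key_sequence, usage_sequence) working as before, which is why usage_sequence and the token names switch to the SDK's snake_case keys.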
edsl/inference_services/MistralAIService.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, List
+from typing import Any, List, Optional
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models.LanguageModel import LanguageModel
 import asyncio
@@ -95,7 +95,10 @@ class MistralAIService(InferenceServiceABC):
            return cls.async_client()
 
        async def async_execute_model_call(
-            self, user_prompt: str, system_prompt: str = ""
+            self,
+            user_prompt: str,
+            system_prompt: str = "",
+            files_list: Optional[List["FileStore"]] = None,
        ) -> dict[str, Any]:
            """Calls the Mistral API and returns the API response."""
            s = self.async_client()
edsl/inference_services/OpenAIService.py
@@ -168,13 +168,14 @@ class OpenAIService(InferenceServiceABC):
            self,
            user_prompt: str,
            system_prompt: str = "",
-            encoded_image=None,
+            files_list: Optional[List["Files"]] = None,
            invigilator: Optional[
                "InvigilatorAI"
            ] = None,  # TBD - can eventually be used for function-calling
        ) -> dict[str, Any]:
            """Calls the OpenAI API and returns the API response."""
-            if encoded_image:
+            if files_list:
+                encoded_image = files_list[0].base64_string
                content = [{"type": "text", "text": user_prompt}]
                content.append(
                    {
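Across the services in this release, the raw encoded_image parameter gives way to files_list; here the service pulls the base64 payload from the first file itself. A stand-in sketch of the new contract (FakeFile is hypothetical; the diff only requires a .base64_string attribute, which FileStore provides):

    from dataclasses import dataclass

    @dataclass
    class FakeFile:
        # Stand-in for FileStore: only the attribute OpenAIService now reads.
        base64_string: str

    def first_image_payload(files_list):
        # Mirrors the new branch: take the first file's base64 payload.
        return files_list[0].base64_string if files_list else None

    print(first_image_payload([FakeFile("aGVsbG8=")]))  # -> aGVsbG8=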
edsl/inference_services/TestService.py
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any, List, Optional
 import os
 import asyncio
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -59,11 +59,20 @@ class TestService(InferenceServiceABC):
            self,
            user_prompt: str,
            system_prompt: str,
-            encoded_image=None,
+            # func: Optional[callable] = None,
+            files_list: Optional[List["File"]] = None,
        ) -> dict[str, Any]:
            await asyncio.sleep(0.1)
            # return {"message": """{"answer": "Hello, world"}"""}
 
+            if hasattr(self, "func"):
+                return {
+                    "message": [
+                        {"text": self.func(user_prompt, system_prompt, files_list)}
+                    ],
+                    "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+                }
+
            if hasattr(self, "throw_exception") and self.throw_exception:
                if hasattr(self, "exception_probability"):
                    p = self.exception_probability
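The new branch lets a test model compute its canned answer from the inputs instead of returning a constant: if the instance carries a func attribute, the reply is func(user_prompt, system_prompt, files_list) wrapped in the usual message/usage envelope. A self-contained sketch (class and attribute wiring are illustrative):

    import asyncio

    class FakeTestModel:
        # Hypothetical stand-in for a test-service model with a `func` hook.
        func = staticmethod(
            lambda user_prompt, system_prompt, files_list: f"echo: {user_prompt}"
        )

        async def async_execute_model_call(self, user_prompt, system_prompt, files_list=None):
            if hasattr(self, "func"):
                return {
                    "message": [{"text": self.func(user_prompt, system_prompt, files_list)}],
                    "usage": {"prompt_tokens": 1, "completion_tokens": 1},
                }

    print(asyncio.run(FakeTestModel().async_execute_model_call("hi", "")))
    # -> {'message': [{'text': 'echo: hi'}], 'usage': {...}}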
edsl/inference_services/TogetherAIService.py
@@ -1,7 +1,7 @@
 import aiohttp
 import json
 import requests
-from typing import Any, List
+from typing import Any, List, Optional
 
 # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
edsl/jobs/interviews/Interview.py
@@ -105,9 +105,9 @@ class Interview(InterviewStatusMixin):
        self.debug = debug
        self.iteration = iteration
        self.cache = cache
-        self.answers: dict[
-            str, str
-        ] = Answers()  # will get filled in as interview progresses
+        self.answers: dict[str, str] = (
+            Answers()
+        )  # will get filled in as interview progresses
        self.sidecar_model = sidecar_model
 
        # self.stop_on_exception = False
@@ -248,17 +248,24 @@
 
    def _get_estimated_request_tokens(self, question) -> float:
        """Estimate the number of tokens that will be required to run the focal task."""
+        from edsl.scenarios.FileStore import FileStore
+
        invigilator = self._get_invigilator(question=question)
        # TODO: There should be a way to get a more accurate estimate.
        combined_text = ""
+        file_tokens = 0
        for prompt in invigilator.get_prompts().values():
            if hasattr(prompt, "text"):
                combined_text += prompt.text
            elif isinstance(prompt, str):
                combined_text += prompt
+            elif isinstance(prompt, list):
+                for file in prompt:
+                    if isinstance(file, FileStore):
+                        file_tokens += file.size * 0.25
            else:
                raise ValueError(f"Prompt is of type {type(prompt)}")
-        return len(combined_text) / 4.0
+        return len(combined_text) / 4.0 + file_tokens
 
    async def _answer_question_and_record_task(
        self,
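The estimator keeps the rough one-token-per-four-characters heuristic for prompt text and now adds 0.25 tokens per byte of attached file. A worked example (sizes illustrative):

    # Worked example of the new estimate:
    combined_text_len = 2_000   # characters across all prompt texts
    file_size_bytes = 100_000   # FileStore.size for one attached file
    estimate = combined_text_len / 4.0 + file_size_bytes * 0.25
    assert estimate == 25_500.0  # 500 text tokens + 25,000 file tokens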
@@ -296,6 +303,9 @@
            self.answers.add_answer(response=response, question=question)
            self._cancel_skipped_questions(question)
        else:
+            # When a question is not validated, it is not added to the answers.
+            # this should also cancel and dependent children questions.
+            # Is that happening now?
            if (
                hasattr(response, "exception_occurred")
                and response.exception_occurred
@@ -418,11 +428,11 @@
        """
        current_question_index: int = self.to_index[current_question.question_name]
 
-        next_question: Union[
-            int, EndOfSurvey
-        ] = self.survey.rule_collection.next_question(
-            q_now=current_question_index,
-            answers=self.answers | self.scenario | self.agent["traits"],
+        next_question: Union[int, EndOfSurvey] = (
+            self.survey.rule_collection.next_question(
+                q_now=current_question_index,
+                answers=self.answers | self.scenario | self.agent["traits"],
+            )
        )
 
        next_question_index = next_question.next_q
edsl/jobs/runners/JobsRunnerAsyncio.py
@@ -8,10 +8,10 @@ from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
 from contextlib import contextmanager
 from collections import UserList
 
-from edsl.results.Results import Results
 from rich.live import Live
 from rich.console import Console
 
+from edsl.results.Results import Results
 from edsl import shared_globals
 from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
@@ -22,6 +22,8 @@ from edsl.utilities.decorators import jupyter_nb_handler
 from edsl.data.Cache import Cache
 from edsl.results.Result import Result
 from edsl.results.Results import Results
+from edsl.language_models.LanguageModel import LanguageModel
+from edsl.data.Cache import Cache
 
 
 class StatusTracker(UserList):
@@ -50,10 +52,10 @@ class JobsRunnerAsyncio:
 
    async def run_async_generator(
        self,
-        cache: "Cache",
+        cache: Cache,
        n: int = 1,
        stop_on_exception: bool = False,
-        sidecar_model: Optional["LanguageModel"] = None,
+        sidecar_model: Optional[LanguageModel] = None,
        total_interviews: Optional[List["Interview"]] = None,
        raise_validation_errors: bool = False,
    ) -> AsyncGenerator["Result", None]:
@@ -104,7 +106,7 @@ class JobsRunnerAsyncio:
            interview.cache = self.cache
            yield interview
 
-    async def run_async(self, cache: Optional["Cache"] = None, n: int = 1) -> Results:
+    async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
        """Used for some other modules that have a non-standard way of running interviews."""
        self.jobs_runner_status = JobsRunnerStatus(self, n=n)
        self.cache = Cache() if cache is None else cache
@@ -291,6 +293,8 @@ class JobsRunnerAsyncio:
 
        self.jobs_runner_status = JobsRunnerStatus(self, n=n)
 
+        stop_event = threading.Event()
+
        async def process_results(cache):
            """Processes results from interviews."""
            async for result in self.run_async_generator(
@@ -303,20 +307,37 @@ class JobsRunnerAsyncio:
                self.results.append(result)
            self.completed = True
 
-        def run_progress_bar():
+        def run_progress_bar(stop_event):
            """Runs the progress bar in a separate thread."""
-            self.jobs_runner_status.update_progress()
+            self.jobs_runner_status.update_progress(stop_event)
 
        if progress_bar:
-            progress_thread = threading.Thread(target=run_progress_bar)
+            progress_thread = threading.Thread(
+                target=run_progress_bar, args=(stop_event,)
+            )
            progress_thread.start()
 
-        with cache as c:
-            await process_results(cache=c)
-
-        if progress_bar:
-            progress_thread.join()
-
-        return self.process_results(
-            raw_results=self.results, cache=cache, print_exceptions=print_exceptions
-        )
+        exception_to_raise = None
+        try:
+            with cache as c:
+                await process_results(cache=c)
+        except KeyboardInterrupt:
+            print("Keyboard interrupt received. Stopping gracefully...")
+            stop_event.set()
+        except Exception as e:
+            if stop_on_exception:
+                exception_to_raise = e
+            stop_event.set()
+        finally:
+            stop_event.set()
+            if progress_bar:
+                # self.jobs_runner_status.stop_event.set()
+                if progress_thread:
+                    progress_thread.join()
+
+        if exception_to_raise:
+            raise exception_to_raise
+
+        return self.process_results(
+            raw_results=self.results, cache=cache, print_exceptions=print_exceptions
+        )
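These changes implement a standard cooperative-shutdown pattern: the progress-bar thread polls a shared threading.Event (see the JobsRunnerStatus hunk below), and the coroutine sets it on interrupt or failure so the thread can exit and be joined instead of hanging. A minimal standalone sketch of the same pattern (names illustrative):

    import threading
    import time

    def progress_loop(stop_event: threading.Event) -> None:
        # Poll until the work finishes or the main thread asks us to stop.
        while not stop_event.is_set():
            time.sleep(0.1)  # refresh the display here

    stop_event = threading.Event()
    thread = threading.Thread(target=progress_loop, args=(stop_event,))
    thread.start()
    try:
        time.sleep(0.5)  # stand-in for awaiting the interviews
    except KeyboardInterrupt:
        pass
    finally:
        stop_event.set()  # always signal, so join() cannot hang
        thread.join()

Setting the event in finally as well as in each except branch is what guarantees the join terminates even on an unexpected exception path.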
edsl/jobs/runners/JobsRunnerStatus.py
@@ -265,14 +265,15 @@
            table.add_row(pretty_name, value)
        return table
 
-    def update_progress(self):
+    def update_progress(self, stop_event):
        layout, progress, task_ids = self.generate_layout()
 
        with Live(
            layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
        ) as live:
-            while len(self.completed_interviews) < len(
-                self.jobs_runner.total_interviews
+            while (
+                len(self.completed_interviews) < len(self.jobs_runner.total_interviews)
+                and not stop_event.is_set()
            ):
                completed_tasks = len(self.completed_interviews)
                total_tasks = len(self.jobs_runner.total_interviews)
edsl/jobs/tasks/QuestionTaskCreator.py
@@ -156,19 +156,6 @@ class QuestionTaskCreator(UserList):
        self.tokens_bucket.turbo_mode_off()
        self.requests_bucket.turbo_mode_off()
 
-        # breakpoint()
-        # _ = results.pop("cached_response", None)
-
-        # tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
-
-        # TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
-        # usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
-        # prompt_tokens = usage.get("prompt_tokens", 0)
-        # completion_tokens = usage.get("completion_tokens", 0)
-        # tracker.add_tokens(
-        #     prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
-        # )
-
        return results
 
    @classmethod
@@ -249,6 +236,7 @@ class QuestionTaskCreator(UserList):
                f"Required tasks failed for {self.question.question_name}"
            ) from e
 
+        # this only runs if all the dependencies are successful
        return await self._run_focal_task()
 
 
edsl/language_models/LanguageModel.py
@@ -440,7 +440,7 @@ class LanguageModel(
        system_prompt: str,
        cache: "Cache",
        iteration: int = 0,
-        encoded_image=None,
+        files_list=None,
    ) -> ModelResponse:
        """Handle caching of responses.
 
@@ -462,16 +462,18 @@
        >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
        ModelResponse(...)"""
 
-        if encoded_image:
-            # the image has is appended to the user_prompt for hash-lookup purposes
-            image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
-            user_prompt += f" {image_hash}"
+        if files_list:
+            files_hash = "+".join([str(hash(file)) for file in files_list])
+            # print(f"Files hash: {files_hash}")
+            user_prompt_with_hashes = user_prompt + f" {files_hash}"
+        else:
+            user_prompt_with_hashes = user_prompt
 
        cache_call_params = {
            "model": str(self.model),
            "parameters": self.parameters,
            "system_prompt": system_prompt,
-            "user_prompt": user_prompt,
+            "user_prompt": user_prompt_with_hashes,
            "iteration": iteration,
        }
        cached_response, cache_key = cache.fetch(**cache_call_params)
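Cache lookups now fold attached files into the key by appending a joined per-file hash to the user prompt, replacing the old md5-of-base64 scheme. That step in isolation (assumes each file object has a stable, content-based __hash__; if it fell back to Python's default object hash, the key would change between runs):

    def cache_user_prompt(user_prompt: str, files_list) -> str:
        # Mirror of the hash-lookup step above: fold file identity into the
        # prompt string used for the cache key.
        if not files_list:
            return user_prompt
        files_hash = "+".join(str(hash(f)) for f in files_list)
        return f"{user_prompt} {files_hash}"

    # Tuple of ints as a stand-in for a hashable file object:
    print(cache_user_prompt("Describe the image.", [(1, 2, 3)]))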
@@ -487,7 +489,8 @@
            params = {
                "user_prompt": user_prompt,
                "system_prompt": system_prompt,
-                **({"encoded_image": encoded_image} if encoded_image else {}),
+                "files_list": files_list
+                #**({"encoded_image": encoded_image} if encoded_image else {}),
            }
            # response = await f(**params)
            response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
@@ -531,7 +534,7 @@
        system_prompt: str,
        cache: "Cache",
        iteration: int = 1,
-        encoded_image=None,
+        files_list: Optional[List['File']] = None,
    ) -> dict:
        """Get response, parse, and return as string.
 
@@ -547,7 +550,7 @@
            "system_prompt": system_prompt,
            "iteration": iteration,
            "cache": cache,
-            **({"encoded_image": encoded_image} if encoded_image else {}),
+            "files_list": files_list,
        }
        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
        model_outputs = await self._async_get_intended_model_call_outcome(**params)
edsl/language_models/utilities.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any
+from typing import Any, Optional, List
 from edsl import Survey
 from edsl.config import CONFIG
 from edsl.enums import InferenceServiceType
@@ -40,7 +40,8 @@ def create_language_model(
    _tpm = 1000000000000
 
    async def async_execute_model_call(
-        self, user_prompt: str, system_prompt: str
+        self, user_prompt: str, system_prompt: str,
+        files_list: Optional[List[Any]] = None
    ) -> dict[str, Any]:
        question_number = int(
            user_prompt.split("XX")[1]
edsl/questions/QuestionBase.py
@@ -44,6 +44,13 @@ class QuestionBase(
    _answering_instructions = None
    _question_presentation = None
 
+    @property
+    def response_model(self) -> type["BaseModel"]:
+        if self._response_model is not None:
+            return self._response_model
+        else:
+            return self.create_response_model()
+
    # region: Validation and simulation methods
    @property
    def response_validator(self) -> "ResponseValidatorBase":
@@ -98,7 +105,9 @@ class QuestionBase(
        comment: Optional[str]
        generated_tokens: Optional[str]
 
-    def _validate_answer(self, answer: dict) -> ValidatedAnswer:
+    def _validate_answer(
+        self, answer: dict, replacement_dict: dict = None
+    ) -> ValidatedAnswer:
        """Validate the answer.
        >>> from edsl.exceptions import QuestionAnswerValidationError
        >>> from edsl import QuestionFreeText as Q
@@ -106,7 +115,7 @@ class QuestionBase(
        {'answer': 'Hello', 'generated_tokens': 'Hello'}
        """
 
-        return self.response_validator.validate(answer)
+        return self.response_validator.validate(answer, replacement_dict)
 
    # endregion
 
edsl/questions/QuestionBaseGenMixin.py
@@ -95,6 +95,34 @@ class QuestionBaseGenMixin:
            questions.append(QuestionBase.from_dict(new_data))
        return questions
 
+    def render(self, replacement_dict: dict) -> "QuestionBase":
+        """Render the question components as jinja2 templates with the replacement dictionary."""
+        from jinja2 import Environment
+        from edsl import Scenario
+
+        strings_only_replacement_dict = {
+            k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
+        }
+
+        def render_string(value: str) -> str:
+            if value is None or not isinstance(value, str):
+                return value
+            else:
+                try:
+                    return (
+                        Environment()
+                        .from_string(value)
+                        .render(strings_only_replacement_dict)
+                    )
+                except Exception as e:
+                    import warnings
+
+                    warnings.warn("Failed to render string: " + value)
+                    # breakpoint()
+                    return value
+
+        return self.apply_function(render_string)
+
    def apply_function(self, func: Callable, exclude_components=None) -> QuestionBase:
        """Apply a function to the question parts
edsl/questions/QuestionCheckBox.py
@@ -245,7 +245,7 @@ class QuestionCheckBox(QuestionBase):
 
        scenario = scenario or Scenario()
        translated_options = [
-            Template(option).render(scenario) for option in self.question_options
+            Template(str(option)).render(scenario) for option in self.question_options
        ]
        translated_codes = []
        for answer_code in answer_codes:
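The str() wrapper matters because question_options need not be strings (integer codes are common), and jinja2 expects a string template source. For example (options illustrative):

    from jinja2 import Template

    options = [1, 2, "{{ n }} or more"]  # mixed option types
    print([Template(str(o)).render(n=3) for o in options])
    # -> ['1', '2', '3 or more']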
edsl/questions/QuestionMultipleChoice.py
@@ -163,7 +163,11 @@ class QuestionMultipleChoice(QuestionBase):
    # Answer methods
    ################
 
-    def create_response_model(self):
+    def create_response_model(self, replacement_dict: dict = None):
+        if replacement_dict is None:
+            replacement_dict = {}
+        # The replacement dict that could be from scenario, current answers, etc. to populate the response model
+
        if self.use_code:
            return create_response_model(
                list(range(len(self.question_options))), self.permissive
edsl/questions/ResponseValidatorABC.py
@@ -92,7 +92,11 @@ class ResponseValidatorABC(ABC):
        generated_tokens: Optional[str]
 
    def validate(
-        self, raw_edsl_answer_dict: RawEdslAnswerDict, fix=False, verbose=False
+        self,
+        raw_edsl_answer_dict: RawEdslAnswerDict,
+        fix=False,
+        verbose=False,
+        replacement_dict: dict = None,
    ) -> EdslAnswerDict:
        """This is the main validation function.