edsl 0.1.33.dev2__py3-none-any.whl → 0.1.33.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. edsl/Base.py +9 -3
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +6 -3
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +101 -29
  8. edsl/config.py +26 -34
  9. edsl/coop/coop.py +11 -2
  10. edsl/data_transfer_models.py +27 -73
  11. edsl/enums.py +2 -0
  12. edsl/inference_services/GoogleService.py +1 -1
  13. edsl/inference_services/InferenceServiceABC.py +44 -13
  14. edsl/inference_services/OpenAIService.py +7 -4
  15. edsl/inference_services/TestService.py +24 -15
  16. edsl/inference_services/TogetherAIService.py +170 -0
  17. edsl/inference_services/registry.py +2 -0
  18. edsl/jobs/Jobs.py +18 -8
  19. edsl/jobs/buckets/BucketCollection.py +24 -15
  20. edsl/jobs/buckets/TokenBucket.py +64 -10
  21. edsl/jobs/interviews/Interview.py +115 -47
  22. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  23. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +86 -161
  25. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  26. edsl/jobs/tasks/TaskHistory.py +17 -0
  27. edsl/language_models/LanguageModel.py +26 -31
  28. edsl/language_models/registry.py +13 -9
  29. edsl/questions/QuestionBase.py +64 -16
  30. edsl/questions/QuestionBudget.py +93 -41
  31. edsl/questions/QuestionFreeText.py +6 -0
  32. edsl/questions/QuestionMultipleChoice.py +11 -26
  33. edsl/questions/QuestionNumerical.py +5 -4
  34. edsl/questions/Quick.py +41 -0
  35. edsl/questions/ResponseValidatorABC.py +6 -5
  36. edsl/questions/derived/QuestionLinearScale.py +4 -1
  37. edsl/questions/derived/QuestionTopK.py +4 -1
  38. edsl/questions/derived/QuestionYesNo.py +8 -2
  39. edsl/questions/templates/budget/__init__.py +0 -0
  40. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  41. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  42. edsl/questions/templates/extract/__init__.py +0 -0
  43. edsl/questions/templates/rank/__init__.py +0 -0
  44. edsl/results/DatasetExportMixin.py +5 -1
  45. edsl/results/Result.py +1 -1
  46. edsl/results/Results.py +4 -1
  47. edsl/scenarios/FileStore.py +71 -10
  48. edsl/scenarios/Scenario.py +86 -21
  49. edsl/scenarios/ScenarioImageMixin.py +2 -2
  50. edsl/scenarios/ScenarioList.py +13 -0
  51. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  52. edsl/study/Study.py +32 -0
  53. edsl/surveys/Rule.py +10 -1
  54. edsl/surveys/RuleCollection.py +19 -3
  55. edsl/surveys/Survey.py +7 -0
  56. edsl/templates/error_reporting/interview_details.html +6 -1
  57. edsl/utilities/utilities.py +9 -1
  58. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/METADATA +2 -1
  59. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/RECORD +61 -55
  60. edsl/jobs/interviews/retry_management.py +0 -39
  61. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  62. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/LICENSE +0 -0
  63. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/WHEEL +0 -0
edsl/jobs/runners/JobsRunnerStatus.py (new file)
@@ -0,0 +1,331 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass, asdict
+
+from typing import List, DefaultDict, Optional, Type, Literal
+from collections import UserDict, defaultdict
+
+from rich.text import Text
+from rich.box import SIMPLE
+from rich.table import Table
+from rich.live import Live
+from rich.panel import Panel
+from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn
+from rich.layout import Layout
+from rich.console import Group
+from rich import box
+
+from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
+from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+from edsl.jobs.tokens.TokenUsage import TokenUsage
+from edsl.enums import get_token_pricing
+from edsl.jobs.tasks.task_status_enum import TaskStatus
+
+InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
+
+from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
+from edsl.jobs.interviews.InterviewStatisticsCollection import (
+    InterviewStatisticsCollection,
+)
+from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+
+
+@dataclass
+class ModelInfo:
+    model_name: str
+    TPM_limit_k: float
+    RPM_limit_k: float
+    num_tasks_waiting: int
+    token_usage_info: dict
+
+
+@dataclass
+class ModelTokenUsageStats:
+    token_usage_type: str
+    details: List[dict]
+    cost: str
+
+
+class Stats:
+    def elapsed_time(self):
+        InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")
+
+
+class JobsRunnerStatus:
+    def __init__(
+        self, jobs_runner: "JobsRunnerAsyncio", n: int, refresh_rate: float = 0.25
+    ):
+        self.jobs_runner = jobs_runner
+        self.start_time = time.time()
+        self.completed_interviews = []
+        self.refresh_rate = refresh_rate
+        self.statistics = [
+            "elapsed_time",
+            "total_interviews_requested",
+            "completed_interviews",
+            # "percent_complete",
+            "average_time_per_interview",
+            # "task_remaining",
+            "estimated_time_remaining",
+            "exceptions",
+            "unfixed_exceptions",
+            "throughput",
+        ]
+        self.num_total_interviews = n * len(self.jobs_runner.interviews)
+
+        self.distinct_models = list(
+            set(i.model.model for i in self.jobs_runner.interviews)
+        )
+
+        self.completed_interview_by_model = defaultdict(list)
+
+    def add_completed_interview(self, result):
+        self.completed_interviews.append(result.interview_hash)
+
+        relevant_model = result.model.model
+        self.completed_interview_by_model[relevant_model].append(result.interview_hash)
+
+    def _compute_statistic(self, stat_name: str):
+        completed_tasks = self.completed_interviews
+        elapsed_time = time.time() - self.start_time
+        interviews = self.jobs_runner.total_interviews
+
+        stat_definitions = {
+            "elapsed_time": lambda: InterviewStatistic(
+                "elapsed_time", value=elapsed_time, digits=1, units="sec."
+            ),
+            "total_interviews_requested": lambda: InterviewStatistic(
+                "total_interviews_requested", value=len(interviews), units=""
+            ),
+            "completed_interviews": lambda: InterviewStatistic(
+                "completed_interviews", value=len(completed_tasks), units=""
+            ),
+            "percent_complete": lambda: InterviewStatistic(
+                "percent_complete",
+                value=(
+                    len(completed_tasks) / len(interviews) * 100
+                    if len(interviews) > 0
+                    else 0
+                ),
+                digits=1,
+                units="%",
+            ),
+            "average_time_per_interview": lambda: InterviewStatistic(
+                "average_time_per_interview",
+                value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
+                digits=2,
+                units="sec.",
+            ),
+            "task_remaining": lambda: InterviewStatistic(
+                "task_remaining", value=len(interviews) - len(completed_tasks), units=""
+            ),
+            "estimated_time_remaining": lambda: InterviewStatistic(
+                "estimated_time_remaining",
+                value=(
+                    (len(interviews) - len(completed_tasks))
+                    * (elapsed_time / len(completed_tasks))
+                    if len(completed_tasks) > 0
+                    else 0
+                ),
+                digits=1,
+                units="sec.",
+            ),
+            "exceptions": lambda: InterviewStatistic(
+                "exceptions",
+                value=sum(len(i.exceptions) for i in interviews),
+                units="",
+            ),
+            "unfixed_exceptions": lambda: InterviewStatistic(
+                "unfixed_exceptions",
+                value=sum(i.exceptions.num_unfixed() for i in interviews),
+                units="",
+            ),
+            "throughput": lambda: InterviewStatistic(
+                "throughput",
+                value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
+                digits=2,
+                units="interviews/sec.",
+            ),
+        }
+        return stat_definitions[stat_name]()
+
+    def create_progress_bar(self):
+        return Progress(
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            TextColumn("{task.completed}/{task.total}"),
+        )
+
+    def generate_model_queues_table(self):
+        table = Table(show_header=False, box=box.SIMPLE)
+        table.add_column("Info", style="cyan")
+        table.add_column("Value", style="magenta")
+        # table.add_row("Bucket collection", str(self.jobs_runner.bucket_collection))
+        for model, bucket in self.jobs_runner.bucket_collection.items():
+            table.add_row(Text(model.model, style="bold blue"), "")
+            bucket_types = ["requests_bucket", "tokens_bucket"]
+            for bucket_type in bucket_types:
+                table.add_row(Text(" " + bucket_type, style="green"), "")
+                # table.add_row(
+                #     f" Current level (capacity = {round(getattr(bucket, bucket_type).capacity, 3)})",
+                #     str(round(getattr(bucket, bucket_type).tokens, 3)),
+                # )
+                num_requests = getattr(bucket, bucket_type).num_requests
+                num_released = getattr(bucket, bucket_type).num_released
+                tokens_returned = getattr(bucket, bucket_type).tokens_returned
+                # table.add_row(
+                #     f" Requested",
+                #     str(num_requests),
+                # )
+                # table.add_row(
+                #     f" Completed",
+                #     str(num_released),
+                # )
+                table.add_row(
+                    " Completed vs. Requested", f"{num_released} vs. {num_requests}"
+                )
+                table.add_row(
+                    " Added tokens (from cache)",
+                    str(tokens_returned),
+                )
+                if bucket_type == "tokens_bucket":
+                    rate_name = "TPM"
+                else:
+                    rate_name = "RPM"
+                target_rate = round(getattr(bucket, bucket_type).target_rate, 1)
+                table.add_row(
+                    f" Empirical {rate_name} (target = {target_rate})",
+                    str(round(getattr(bucket, bucket_type).get_throughput(), 0)),
+                )
+
+        return table
+
+    def generate_layout(self):
+        progress = self.create_progress_bar()
+        task_ids = []
+        for model in self.distinct_models:
+            task_id = progress.add_task(
+                f"[cyan]{model}...",
+                total=int(self.num_total_interviews / len(self.distinct_models)),
+            )
+            task_ids.append((model, task_id))
+
+        progress_height = min(5, 2 + len(self.distinct_models))
+        layout = Layout()
+
+        # Create the top row with only the progress panel
+        layout.split_column(
+            Layout(
+                Panel(
+                    progress,
+                    title="Interview Progress",
+                    border_style="cyan",
+                    box=box.ROUNDED,
+                ),
+                name="progress",
+                size=progress_height,  # Adjusted size
+            ),
+            Layout(name="bottom_row"),  # Adjusted size
+        )
+
+        # Split the bottom row into two columns for metrics and model queues
+        layout["bottom_row"].split_row(
+            Layout(
+                Panel(
+                    self.generate_metrics_table(),
+                    title="Metrics",
+                    border_style="magenta",
+                    box=box.ROUNDED,
+                ),
+                name="metrics",
+            ),
+            Layout(
+                Panel(
+                    self.generate_model_queues_table(),
+                    title="Model Queues",
+                    border_style="yellow",
+                    box=box.ROUNDED,
+                ),
+                name="model_queues",
+            ),
+        )
+
+        return layout, progress, task_ids
+
+    def generate_metrics_table(self):
+        table = Table(show_header=True, header_style="bold magenta", box=box.SIMPLE)
+        table.add_column("Metric", style="cyan", no_wrap=True)
+        table.add_column("Value", justify="right")
+
+        for stat_name in self.statistics:
+            pretty_name, value = list(self._compute_statistic(stat_name).items())[0]
+            # breakpoint()
+            table.add_row(pretty_name, value)
+        return table
+
+    def update_progress(self):
+        layout, progress, task_ids = self.generate_layout()
+
+        with Live(
+            layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
+        ) as live:
+            while len(self.completed_interviews) < len(
+                self.jobs_runner.total_interviews
+            ):
+                completed_tasks = len(self.completed_interviews)
+                total_tasks = len(self.jobs_runner.total_interviews)
+
+                for model, task_id in task_ids:
+                    completed_tasks = len(self.completed_interview_by_model[model])
+                    progress.update(
+                        task_id,
+                        completed=completed_tasks,
+                        description=f"[cyan]Conducting interviews for {model}...",
+                    )
+
+                layout["metrics"].update(
+                    Panel(
+                        self.generate_metrics_table(),
+                        title="Metrics",
+                        border_style="magenta",
+                        box=box.ROUNDED,
+                    )
+                )
+                layout["model_queues"].update(
+                    Panel(
+                        self.generate_model_queues_table(),
+                        title="Final Model Queues",
+                        border_style="yellow",
+                        box=box.ROUNDED,
+                    )
+                )
+
+                time.sleep(self.refresh_rate)
+
+            # Final update
+            for model, task_id in task_ids:
+                completed_tasks = len(self.completed_interview_by_model[model])
+                progress.update(
+                    task_id,
+                    completed=completed_tasks,
+                    description=f"[cyan]Conducting interviews for {model}...",
+                )
+
+            layout["metrics"].update(
+                Panel(
+                    self.generate_metrics_table(),
+                    title="Final Metrics",
+                    border_style="magenta",
+                    box=box.ROUNDED,
+                )
+            )
+            live.update(layout)
+            time.sleep(1)  # Show final state for 1 second
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
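
Aside: the timing statistics above reduce to simple counter arithmetic — throughput is completed interviews over elapsed seconds, and the time-remaining estimate scales the observed per-interview average by the outstanding count. A self-contained sketch of that arithmetic (the names completed, total, and elapsed are illustrative, not part of the diff):

    def eta_stats(completed: int, total: int, elapsed: float) -> dict:
        """Mirror the throughput / estimated_time_remaining formulas in JobsRunnerStatus."""
        throughput = completed / elapsed if elapsed > 0 else 0
        average = elapsed / completed if completed else 0
        return {
            "throughput": throughput,  # interviews/sec.
            "estimated_time_remaining": (total - completed) * average,  # sec.
        }

    # 40 of 100 interviews done after 20s -> 2.0 interviews/sec., ~30s remaining
    print(eta_stats(40, 100, 20.0))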
edsl/jobs/tasks/TaskHistory.py
@@ -50,6 +50,18 @@ class TaskHistory:
         """
         return [i.exceptions for k, i in self._interviews.items() if i.exceptions != {}]
 
+    @property
+    def unfixed_exceptions(self):
+        """
+        >>> len(TaskHistory.example().unfixed_exceptions)
+        4
+        """
+        return [
+            i.exceptions
+            for k, i in self._interviews.items()
+            if i.exceptions.num_unfixed() > 0
+        ]
+
     @property
     def indices(self):
         return [k for k, i in self._interviews.items() if i.exceptions != {}]
@@ -78,6 +90,11 @@ class TaskHistory:
         """
         return len(self.exceptions) > 0
 
+    @property
+    def has_unfixed_exceptions(self) -> bool:
+        """Return True if there are any unfixed exceptions."""
+        return len(self.unfixed_exceptions) > 0
+
     def _repr_html_(self):
         """Return an HTML representation of the TaskHistory."""
         from edsl.utilities.utilities import data_to_html
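
Taken together, unfixed_exceptions and has_unfixed_exceptions let callers distinguish interviews whose exceptions were later fixed from those still outstanding. A minimal check, leaning on the TaskHistory.example() fixture from the doctest above:

    from edsl.jobs.tasks.TaskHistory import TaskHistory

    th = TaskHistory.example()
    # per the doctest, the example fixture carries four unfixed exception sets
    assert len(th.unfixed_exceptions) == 4
    assert th.has_unfixed_exceptions is True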
edsl/language_models/LanguageModel.py
@@ -164,20 +164,20 @@ class LanguageModel(
         None  # This should be something like ["choices", 0, "message", "content"]
     )
     __rate_limits = None
-    __default_rate_limits = {
-        "rpm": 10_000,
-        "tpm": 2_000_000,
-    }  # TODO: Use the OpenAI Teir 1 rate limits
     _safety_factor = 0.8
 
-    def __init__(self, tpm=None, rpm=None, **kwargs):
+    def __init__(
+        self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
+    ):
         """Initialize the LanguageModel."""
         self.model = getattr(self, "_model_", None)
         default_parameters = getattr(self, "_parameters_", None)
         parameters = self._overide_default_parameters(kwargs, default_parameters)
         self.parameters = parameters
         self.remote = False
+        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
 
+        # self._rpm / _tpm comes from the class
         if rpm is not None:
             self._rpm = rpm
 
@@ -286,35 +286,40 @@ class LanguageModel(
         >>> m.RPM
         100
         """
-        self._set_rate_limits(rpm=rpm, tpm=tpm)
+        if rpm is not None:
+            self._rpm = rpm
+        if tpm is not None:
+            self._tpm = tpm
+        return None
+        # self._set_rate_limits(rpm=rpm, tpm=tpm)
 
-    def _set_rate_limits(self, rpm=None, tpm=None) -> None:
-        """Set the rate limits for the model.
+    # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+    #     """Set the rate limits for the model.
 
-        If the model does not have rate limits, use the default rate limits."""
-        if rpm is not None and tpm is not None:
-            self.__rate_limits = {"rpm": rpm, "tpm": tpm}
-            return
+    #     If the model does not have rate limits, use the default rate limits."""
+    #     if rpm is not None and tpm is not None:
+    #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+    #         return
 
-        if self.__rate_limits is None:
-            if hasattr(self, "get_rate_limits"):
-                self.__rate_limits = self.get_rate_limits()
-            else:
-                self.__rate_limits = self.__default_rate_limits
+    #     if self.__rate_limits is None:
+    #         if hasattr(self, "get_rate_limits"):
+    #             self.__rate_limits = self.get_rate_limits()
+    #         else:
+    #             self.__rate_limits = self.__default_rate_limits
 
     @property
     def RPM(self):
         """Model's requests-per-minute limit."""
         # self._set_rate_limits()
         # return self._safety_factor * self.__rate_limits["rpm"]
-        return self.rpm
+        return self._rpm
 
     @property
     def TPM(self):
         """Model's tokens-per-minute limit."""
         # self._set_rate_limits()
         # return self._safety_factor * self.__rate_limits["tpm"]
-        return self.tpm
+        return self._tpm
 
     @property
     def rpm(self):
@@ -332,17 +337,6 @@ class LanguageModel(
     def tpm(self, value):
         self._tpm = value
 
-    @property
-    def TPM(self):
-        """Model's tokens-per-minute limit.
-
-        >>> m = LanguageModel.example()
-        >>> m.TPM > 0
-        True
-        """
-        self._set_rate_limits()
-        return self._safety_factor * self.__rate_limits["tpm"]
-
     @staticmethod
     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
         """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
@@ -471,12 +465,13 @@ class LanguageModel(
         if encoded_image:
             # the image hash is appended to the user_prompt for hash-lookup purposes
             image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
+            user_prompt += f" {image_hash}"
 
         cache_call_params = {
             "model": str(self.model),
             "parameters": self.parameters,
             "system_prompt": system_prompt,
-            "user_prompt": user_prompt + "" if not encoded_image else f" {image_hash}",
+            "user_prompt": user_prompt,
             "iteration": iteration,
         }
         cached_response, cache_key = cache.fetch(**cache_call_params)
edsl/language_models/registry.py
@@ -2,10 +2,10 @@ import textwrap
 from random import random
 from edsl.config import CONFIG
 
-if "EDSL_DEFAULT_MODEL" not in CONFIG:
-    default_model = "test"
-else:
-    default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
+# if "EDSL_DEFAULT_MODEL" not in CONFIG:
+#     default_model = "test"
+# else:
+#     default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
 
 
 def get_model_class(model_name, registry=None):
@@ -33,20 +33,24 @@ class Meta(type):
 
 
 class Model(metaclass=Meta):
-    default_model = default_model
+    default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
 
-    def __new__(cls, model_name=None, registry=None, *args, **kwargs):
+    def __new__(
+        cls, model_name=None, registry=None, service_name=None, *args, **kwargs
+    ):
         # Map index to the respective subclass
         if model_name is None:
-            model_name = cls.default_model
+            model_name = (
+                cls.default_model
+            )  # when model_name is None, use the default model, set in the config file
         from edsl.inference_services.registry import default
 
         registry = registry or default
 
-        if isinstance(model_name, int):
+        if isinstance(model_name, int):  # can refer to a model by index
             model_name = cls.available(name_only=True)[model_name]
 
-        factory = registry.create_model_factory(model_name)
+        factory = registry.create_model_factory(model_name, service_name=service_name)
         return factory(*args, **kwargs)
 
     @classmethod
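
Model.__new__ now threads an optional service_name through to registry.create_model_factory, so a model offered by more than one inference service can be pinned to a specific provider (the new TogetherAIService above is the likely motivation). A hedged usage sketch — the model and service strings are illustrative, not taken from this diff:

    from edsl import Model

    m_default = Model()    # model_name=None falls back to CONFIG's EDSL_DEFAULT_MODEL
    m_by_index = Model(0)  # an int still selects from Model.available(name_only=True)
    m_pinned = Model("some-model", service_name="together")  # hypothetical model/service pair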
edsl/questions/QuestionBase.py
@@ -75,8 +75,7 @@ class QuestionBase(
         if not hasattr(self, "_fake_data_factory"):
             from polyfactory.factories.pydantic_factory import ModelFactory
 
-            class FakeData(ModelFactory[self.response_model]):
-                ...
+            class FakeData(ModelFactory[self.response_model]): ...
 
             self._fake_data_factory = FakeData
         return self._fake_data_factory
@@ -103,13 +102,8 @@ class QuestionBase(
         """Validate the answer.
         >>> from edsl.exceptions import QuestionAnswerValidationError
         >>> from edsl import QuestionFreeText as Q
-        >>> Q.example()._validate_answer({'answer': 'Hello'})
-        {'answer': 'Hello', 'generated_tokens': None}
-        >>> Q.example()._validate_answer({'shmanswer': 1})
-        Traceback (most recent call last):
-        ...
-        edsl.exceptions.questions.QuestionAnswerValidationError:...
-        ...
+        >>> Q.example()._validate_answer({'answer': 'Hello', 'generated_tokens': 'Hello'})
+        {'answer': 'Hello', 'generated_tokens': 'Hello'}
         """
 
         return self.response_validator.validate(answer)
@@ -471,6 +465,7 @@ class QuestionBase(
         self,
         scenario: Optional[dict] = None,
         agent: Optional[dict] = {},
+        answers: Optional[dict] = None,
         include_question_name: bool = False,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -482,6 +477,17 @@ class QuestionBase(
         if scenario is None:
             scenario = {}
 
+        prior_answers_dict = {}
+
+        if isinstance(answers, dict):
+            for key, value in answers.items():
+                if not key.endswith("_comment") and not key.endswith(
+                    "_generated_tokens"
+                ):
+                    prior_answers_dict[key] = {"answer": value}
+
+        # breakpoint()
+
         base_template = """
         <div id="{{ question_name }}" class="survey_question" data-type="{{ question_type }}">
         {% if include_question_name %}
@@ -501,13 +507,40 @@ class QuestionBase(
 
         base_template = Template(base_template)
 
-        params = {
-            "question_name": self.question_name,
-            "question_text": Template(self.question_text).render(scenario, agent=agent),
-            "question_type": self.question_type,
-            "question_content": Template(question_content).render(scenario),
-            "include_question_name": include_question_name,
-        }
+        context = {
+            "scenario": scenario,
+            "agent": agent,
+        } | prior_answers_dict
+
+        # Render the question text
+        try:
+            question_text = Template(self.question_text).render(context)
+        except Exception as e:
+            print(
+                f"Error rendering question: question_text = {self.question_text}, error = {e}"
+            )
+            question_text = self.question_text
+
+        try:
+            question_content = Template(question_content).render(context)
+        except Exception as e:
+            print(
+                f"Error rendering question: question_content = {question_content}, error = {e}"
+            )
+            question_content = question_content
+
+        try:
+            params = {
+                "question_name": self.question_name,
+                "question_text": question_text,
+                "question_type": self.question_type,
+                "question_content": question_content,
+                "include_question_name": include_question_name,
+            }
+        except Exception as e:
+            raise ValueError(
+                f"Error rendering question: params = {params}, error = {e}"
+            )
         rendered_html = base_template.render(**params)
 
         if iframe:
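
With prior answers folded into the Jinja context as {name: {"answer": value}}, question text can now reference earlier answers directly. A minimal illustration of that rendering contract using plain Jinja2, outside edsl (the question name q0 is hypothetical):

    from jinja2 import Template

    context = {"scenario": {}, "agent": {}} | {"q0": {"answer": "blue"}}
    print(Template("You said {{ q0.answer }} earlier. Why?").render(context))
    # -> You said blue earlier. Why?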
@@ -526,6 +559,21 @@ class QuestionBase(
 
         return rendered_html
 
+    @classmethod
+    def example_model(cls):
+        from edsl import Model
+
+        q = cls.example()
+        m = Model("test", canned_response=cls._simulate_answer(q)["answer"])
+
+        return m
+
+    @classmethod
+    def example_results(cls):
+        m = cls.example_model()
+        q = cls.example()
+        return q.by(m).run(cache=False)
+
     def rich_print(self):
         """Print the question in a rich format."""
         from rich.table import Table
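
The new example_model / example_results helpers pair a question's simulated answer with the "test" model's canned_response, so a full run exercises prompt construction and answer validation without any network call. A sketch of the intended round trip, assuming QuestionFreeText as in the doctests above:

    from edsl import QuestionFreeText

    q = QuestionFreeText.example()
    m = QuestionFreeText.example_model()  # test model whose canned response is the simulated answer
    results = q.by(m).run(cache=False)    # what example_results() does under the hood
    print(results)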