edsl 0.1.33__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
Files changed (63)
  1. edsl/Base.py +3 -9
  2. edsl/__init__.py +0 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +3 -6
  6. edsl/agents/InvigilatorBase.py +27 -8
  7. edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +29 -101
  8. edsl/config.py +34 -26
  9. edsl/coop/coop.py +2 -11
  10. edsl/data_transfer_models.py +73 -26
  11. edsl/enums.py +0 -2
  12. edsl/inference_services/GoogleService.py +1 -1
  13. edsl/inference_services/InferenceServiceABC.py +13 -44
  14. edsl/inference_services/OpenAIService.py +4 -7
  15. edsl/inference_services/TestService.py +15 -24
  16. edsl/inference_services/registry.py +0 -2
  17. edsl/jobs/Jobs.py +8 -18
  18. edsl/jobs/buckets/BucketCollection.py +15 -24
  19. edsl/jobs/buckets/TokenBucket.py +10 -64
  20. edsl/jobs/interviews/Interview.py +47 -115
  21. edsl/jobs/interviews/InterviewExceptionEntry.py +0 -2
  22. edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +0 -16
  23. edsl/jobs/interviews/retry_management.py +39 -0
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +170 -95
  25. edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
  26. edsl/jobs/tasks/TaskHistory.py +0 -17
  27. edsl/language_models/LanguageModel.py +31 -26
  28. edsl/language_models/registry.py +9 -13
  29. edsl/questions/QuestionBase.py +14 -63
  30. edsl/questions/QuestionBudget.py +41 -93
  31. edsl/questions/QuestionFreeText.py +0 -6
  32. edsl/questions/QuestionMultipleChoice.py +23 -8
  33. edsl/questions/QuestionNumerical.py +4 -5
  34. edsl/questions/ResponseValidatorABC.py +5 -6
  35. edsl/questions/derived/QuestionLinearScale.py +1 -4
  36. edsl/questions/derived/QuestionTopK.py +1 -4
  37. edsl/questions/derived/QuestionYesNo.py +2 -8
  38. edsl/results/DatasetExportMixin.py +1 -5
  39. edsl/results/Result.py +1 -1
  40. edsl/results/Results.py +1 -4
  41. edsl/scenarios/FileStore.py +10 -71
  42. edsl/scenarios/Scenario.py +21 -86
  43. edsl/scenarios/ScenarioImageMixin.py +2 -2
  44. edsl/scenarios/ScenarioList.py +0 -13
  45. edsl/scenarios/ScenarioListPdfMixin.py +4 -150
  46. edsl/study/Study.py +0 -32
  47. edsl/surveys/Rule.py +1 -10
  48. edsl/surveys/RuleCollection.py +3 -19
  49. edsl/surveys/Survey.py +0 -7
  50. edsl/templates/error_reporting/interview_details.html +1 -6
  51. edsl/utilities/utilities.py +1 -9
  52. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +1 -2
  53. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/RECORD +55 -61
  54. edsl/inference_services/TogetherAIService.py +0 -170
  55. edsl/jobs/runners/JobsRunnerStatus.py +0 -331
  56. edsl/questions/Quick.py +0 -41
  57. edsl/questions/templates/budget/__init__.py +0 -0
  58. edsl/questions/templates/budget/answering_instructions.jinja +0 -7
  59. edsl/questions/templates/budget/question_presentation.jinja +0 -7
  60. edsl/questions/templates/extract/__init__.py +0 -0
  61. edsl/questions/templates/rank/__init__.py +0 -0
  62. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
  63. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
edsl/jobs/runners/JobsRunnerStatusMixin.py
@@ -0,0 +1,333 @@
+ from __future__ import annotations
+ from typing import List, DefaultDict
+ import asyncio
+ from typing import Type
+ from collections import defaultdict
+
+ from typing import Literal, List, Type, DefaultDict
+ from collections import UserDict, defaultdict
+
+ from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+ from edsl.jobs.tokens.TokenUsage import TokenUsage
+ from edsl.enums import get_token_pricing
+ from edsl.jobs.tasks.task_status_enum import TaskStatus
+
+ InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
+
+ from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
+ from edsl.jobs.interviews.InterviewStatisticsCollection import (
+     InterviewStatisticsCollection,
+ )
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+
+
+ # return {"cache_status": token_usage_type, "details": details, "cost": f"${token_usage.cost(prices):.5f}"}
+
+ from dataclasses import dataclass, asdict
+
+ from rich.text import Text
+ from rich.box import SIMPLE
+ from rich.table import Table
+
+
+ @dataclass
+ class ModelInfo:
+     model_name: str
+     TPM_limit_k: float
+     RPM_limit_k: float
+     num_tasks_waiting: int
+     token_usage_info: dict
+
+
+ @dataclass
+ class ModelTokenUsageStats:
+     token_usage_type: str
+     details: List[dict]
+     cost: str
+
+
+ class Stats:
+     def elapsed_time(self):
+         InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")
+
+
+ class JobsRunnerStatusMixin:
+     # @staticmethod
+     # def status_dict(interviews: List[Type["Interview"]]) -> List[Type[InterviewStatusDictionary]]:
+     #     """
+     #     >>> from edsl.jobs.interviews.Interview import Interview
+     #     >>> interviews = [Interview.example()]
+     #     >>> JobsRunnerStatusMixin().status_dict(interviews)
+     #     [InterviewStatusDictionary({<TaskStatus.NOT_STARTED: 1>: 0, <TaskStatus.WAITING_FOR_DEPENDENCIES: 2>: 0, <TaskStatus.CANCELLED: 3>: 0, <TaskStatus.PARENT_FAILED: 4>: 0, <TaskStatus.WAITING_FOR_REQUEST_CAPACITY: 5>: 0, <TaskStatus.WAITING_FOR_TOKEN_CAPACITY: 6>: 0, <TaskStatus.API_CALL_IN_PROGRESS: 7>: 0, <TaskStatus.SUCCESS: 8>: 0, <TaskStatus.FAILED: 9>: 0, 'number_from_cache': 0})]
+     #     """
+     #     return [interview.interview_status for interview in interviews]
+
+     def _compute_statistic(stat_name: str, completed_tasks, elapsed_time, interviews):
+         stat_definitions = {
+             "elapsed_time": lambda: InterviewStatistic(
+                 "elapsed_time", value=elapsed_time, digits=1, units="sec."
+             ),
+             "total_interviews_requested": lambda: InterviewStatistic(
+                 "total_interviews_requested", value=len(interviews), units=""
+             ),
+             "completed_interviews": lambda: InterviewStatistic(
+                 "completed_interviews", value=len(completed_tasks), units=""
+             ),
+             "percent_complete": lambda: InterviewStatistic(
+                 "percent_complete",
+                 value=(
+                     len(completed_tasks) / len(interviews) * 100
+                     if len(interviews) > 0
+                     else "NA"
+                 ),
+                 digits=0,
+                 units="%",
+             ),
+             "average_time_per_interview": lambda: InterviewStatistic(
+                 "average_time_per_interview",
+                 value=elapsed_time / len(completed_tasks) if completed_tasks else "NA",
+                 digits=1,
+                 units="sec.",
+             ),
+             "task_remaining": lambda: InterviewStatistic(
+                 "task_remaining", value=len(interviews) - len(completed_tasks), units=""
+             ),
+             "estimated_time_remaining": lambda: InterviewStatistic(
+                 "estimated_time_remaining",
+                 value=(
+                     (len(interviews) - len(completed_tasks))
+                     * (elapsed_time / len(completed_tasks))
+                     if len(completed_tasks) > 0
+                     else "NA"
+                 ),
+                 digits=1,
+                 units="sec.",
+             ),
+         }
+         if stat_name not in stat_definitions:
+             raise ValueError(
+                 f"Invalid stat_name: {stat_name}. The valid stat_names are: {list(stat_definitions.keys())}"
+             )
+         return stat_definitions[stat_name]()
+
+     @staticmethod
+     def _job_level_info(
+         completed_tasks: List[Type[asyncio.Task]],
+         elapsed_time: float,
+         interviews: List[Type["Interview"]],
+     ) -> InterviewStatisticsCollection:
+         interview_statistics = InterviewStatisticsCollection()
+
+         default_statistics = [
+             "elapsed_time",
+             "total_interviews_requested",
+             "completed_interviews",
+             "percent_complete",
+             "average_time_per_interview",
+             "task_remaining",
+             "estimated_time_remaining",
+         ]
+         for stat_name in default_statistics:
+             interview_statistics.add_stat(
+                 JobsRunnerStatusMixin._compute_statistic(
+                     stat_name, completed_tasks, elapsed_time, interviews
+                 )
+             )
+
+         return interview_statistics
+
+     @staticmethod
+     def _get_model_queues_info(interviews):
+         models_to_tokens = defaultdict(InterviewTokenUsage)
+         model_to_status = defaultdict(InterviewStatusDictionary)
+         waiting_dict = defaultdict(int)
+
+         for interview in interviews:
+             models_to_tokens[interview.model] += interview.token_usage
+             model_to_status[interview.model] += interview.interview_status
+             waiting_dict[interview.model] += interview.interview_status.waiting
+
+         for model, num_waiting in waiting_dict.items():
+             yield JobsRunnerStatusMixin._get_model_info(
+                 model, num_waiting, models_to_tokens
+             )
+
+     @staticmethod
+     def generate_status_summary(
+         completed_tasks: List[Type[asyncio.Task]],
+         elapsed_time: float,
+         interviews: List[Type["Interview"]],
+         include_model_queues=False,
+     ) -> InterviewStatisticsCollection:
+         """Generate a summary of the status of the job runner.
+
+         :param completed_tasks: list of completed tasks
+         :param elapsed_time: time elapsed since the start of the job
+         :param interviews: list of interviews to be conducted
+
+         >>> from edsl.jobs.interviews.Interview import Interview
+         >>> interviews = [Interview.example()]
+         >>> completed_tasks = []
+         >>> elapsed_time = 0
+         >>> JobsRunnerStatusMixin().generate_status_summary(completed_tasks, elapsed_time, interviews)
+         {'Elapsed time': '0.0 sec.', 'Total interviews requested': '1 ', 'Completed interviews': '0 ', 'Percent complete': '0 %', 'Average time per interview': 'NA', 'Task remaining': '1 ', 'Estimated time remaining': 'NA'}
+         """
+
+         interview_status_summary: InterviewStatisticsCollection = (
+             JobsRunnerStatusMixin._job_level_info(
+                 completed_tasks=completed_tasks,
+                 elapsed_time=elapsed_time,
+                 interviews=interviews,
+             )
+         )
+         if include_model_queues:
+             interview_status_summary.model_queues = list(
+                 JobsRunnerStatusMixin._get_model_queues_info(interviews)
+             )
+         else:
+             interview_status_summary.model_queues = None
+
+         return interview_status_summary
+
+     @staticmethod
+     def _get_model_info(
+         model: str,
+         num_waiting: int,
+         models_to_tokens: InterviewTokenUsageMapping,
+     ) -> dict:
+         """Get the status of a model.
+
+         :param model: the model name
+         :param num_waiting: the number of tasks waiting for capacity
+         :param models_to_tokens: a mapping of models to token usage
+
+         >>> from edsl.jobs.interviews.Interview import Interview
+         >>> interviews = [Interview.example()]
+         >>> models_to_tokens = defaultdict(InterviewTokenUsage)
+         >>> model = interviews[0].model
+         >>> num_waiting = 0
+         >>> JobsRunnerStatusMixin()._get_model_info(model, num_waiting, models_to_tokens)
+         ModelInfo(model_name='...', TPM_limit_k=..., RPM_limit_k=..., num_tasks_waiting=0, token_usage_info=[ModelTokenUsageStats(token_usage_type='new_token_usage', details=[{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], cost='$0.00000'), ModelTokenUsageStats(token_usage_type='cached_token_usage', details=[{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], cost='$0.00000')])
+         """
+
+         ## TODO: This should probably be a coop method
+         prices = get_token_pricing(model.model)
+
+         token_usage_info = []
+         for token_usage_type in ["new_token_usage", "cached_token_usage"]:
+             token_usage_info.append(
+                 JobsRunnerStatusMixin._get_token_usage_info(
+                     token_usage_type, models_to_tokens, model, prices
+                 )
+             )
+
+         return ModelInfo(
+             **{
+                 "model_name": model.model,
+                 "TPM_limit_k": model.TPM / 1000,
+                 "RPM_limit_k": model.RPM / 1000,
+                 "num_tasks_waiting": num_waiting,
+                 "token_usage_info": token_usage_info,
+             }
+         )
+
+     @staticmethod
+     def _get_token_usage_info(
+         token_usage_type: Literal["new_token_usage", "cached_token_usage"],
+         models_to_tokens: InterviewTokenUsageMapping,
+         model: str,
+         prices: "TokenPricing",
+     ) -> ModelTokenUsageStats:
+         """Get the token usage info for a model.
+
+         >>> from edsl.jobs.interviews.Interview import Interview
+         >>> interviews = [Interview.example()]
+         >>> models_to_tokens = defaultdict(InterviewTokenUsage)
+         >>> model = interviews[0].model
+         >>> prices = get_token_pricing(model.model)
+         >>> cache_status = "new_token_usage"
+         >>> JobsRunnerStatusMixin()._get_token_usage_info(cache_status, models_to_tokens, model, prices)
+         ModelTokenUsageStats(token_usage_type='new_token_usage', details=[{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], cost='$0.00000')
+
+         """
+         all_token_usage: InterviewTokenUsage = models_to_tokens[model]
+         token_usage: TokenUsage = getattr(all_token_usage, token_usage_type)
+
+         details = [
+             {"type": token_type, "tokens": getattr(token_usage, token_type)}
+             for token_type in ["prompt_tokens", "completion_tokens"]
+         ]
+
+         return ModelTokenUsageStats(
+             token_usage_type=token_usage_type,
+             details=details,
+             cost=f"${token_usage.cost(prices):.5f}",
+         )
+
+     @staticmethod
+     def _add_statistics_to_table(table, status_summary):
+         table.add_column("Statistic", style="dim", no_wrap=True, width=50)
+         table.add_column("Value", width=10)
+
+         for key, value in status_summary.items():
+             if key != "model_queues":
+                 table.add_row(key, value)
+
+     @staticmethod
+     def display_status_table(status_summary: InterviewStatisticsCollection) -> "Table":
+         table = Table(
+             title="Job Status",
+             show_header=True,
+             header_style="bold magenta",
+             box=SIMPLE,
+         )
+
+         ### Job-level statistics
+         JobsRunnerStatusMixin._add_statistics_to_table(table, status_summary)
+
+         ## Model-level statistics
+         spacing = " "
+
+         if status_summary.model_queues is not None:
+             table.add_row(Text("Model Queues", style="bold red"), "")
+             for model_info in status_summary.model_queues:
+                 model_name = model_info.model_name
+                 tpm = f"TPM (k)={model_info.TPM_limit_k}"
+                 rpm = f"RPM (k)= {model_info.RPM_limit_k}"
+                 pretty_model_name = model_name + ";" + tpm + ";" + rpm
+                 table.add_row(Text(pretty_model_name, style="blue"), "")
+                 table.add_row(
+                     "Number question tasks waiting for capacity",
+                     str(model_info.num_tasks_waiting),
+                 )
+                 # Token usage and cost info
+                 for token_usage_info in model_info.token_usage_info:
+                     token_usage_type = token_usage_info.token_usage_type
+                     table.add_row(
+                         Text(
+                             spacing + token_usage_type.replace("_", " "), style="bold"
+                         ),
+                         "",
+                     )
+                     for detail in token_usage_info.details:
+                         token_type = detail["type"]
+                         tokens = detail["tokens"]
+                         table.add_row(spacing + f"{token_type}", f"{tokens:,}")
+                     # table.add_row(spacing + "cost", cache_info["cost"])
+
+         return table
+
+     def status_table(self, completed_tasks: List[asyncio.Task], elapsed_time: float):
+         summary_data = JobsRunnerStatusMixin.generate_status_summary(
+             completed_tasks=completed_tasks,
+             elapsed_time=elapsed_time,
+             interviews=self.total_interviews,
+         )
+         return self.display_status_table(summary_data)
+
+
+ if __name__ == "__main__":
+     import doctest
+
+     doctest.testmod(optionflags=doctest.ELLIPSIS)
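Read end to end, the new mixin is a pull-style reporter: the runner asks `generate_status_summary` for an `InterviewStatisticsCollection` built from its completed tasks and elapsed time, then renders it with `display_status_table`. A minimal sketch of that flow, using only names defined above plus the `Interview.example()` fixture from the doctests (not a snippet from the package itself):

    # Sketch of the reporting flow; assumes the edsl imports above resolve.
    import time
    from edsl.jobs.interviews.Interview import Interview
    from rich.console import Console

    interviews = [Interview.example()]
    completed_tasks = []          # finished asyncio.Tasks would accumulate here
    start = time.monotonic()

    summary = JobsRunnerStatusMixin.generate_status_summary(
        completed_tasks=completed_tasks,
        elapsed_time=time.monotonic() - start,
        interviews=interviews,
        include_model_queues=False,   # True adds per-model token/queue rows
    )
    Console().print(JobsRunnerStatusMixin.display_status_table(summary))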
edsl/jobs/tasks/TaskHistory.py
@@ -50,18 +50,6 @@ class TaskHistory:
          """
          return [i.exceptions for k, i in self._interviews.items() if i.exceptions != {}]

-     @property
-     def unfixed_exceptions(self):
-         """
-         >>> len(TaskHistory.example().unfixed_exceptions)
-         4
-         """
-         return [
-             i.exceptions
-             for k, i in self._interviews.items()
-             if i.exceptions.num_unfixed() > 0
-         ]
-
      @property
      def indices(self):
          return [k for k, i in self._interviews.items() if i.exceptions != {}]
@@ -90,11 +78,6 @@ class TaskHistory:
          """
          return len(self.exceptions) > 0

-     @property
-     def has_unfixed_exceptions(self) -> bool:
-         """Return True if there are any exceptions."""
-         return len(self.unfixed_exceptions) > 0
-
      def _repr_html_(self):
          """Return an HTML representation of the TaskHistory."""
          from edsl.utilities.utilities import data_to_html
edsl/language_models/LanguageModel.py
@@ -164,20 +164,20 @@ class LanguageModel(
          None  # This should be something like ["choices", 0, "message", "content"]
      )
      __rate_limits = None
+     __default_rate_limits = {
+         "rpm": 10_000,
+         "tpm": 2_000_000,
+     }  # TODO: Use the OpenAI Tier 1 rate limits
      _safety_factor = 0.8

-     def __init__(
-         self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
-     ):
+     def __init__(self, tpm=None, rpm=None, **kwargs):
          """Initialize the LanguageModel."""
          self.model = getattr(self, "_model_", None)
          default_parameters = getattr(self, "_parameters_", None)
          parameters = self._overide_default_parameters(kwargs, default_parameters)
          self.parameters = parameters
          self.remote = False
-         self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string

-         # self._rpm / _tpm comes from the class
          if rpm is not None:
              self._rpm = rpm

@@ -286,40 +286,35 @@ class LanguageModel(
          >>> m.RPM
          100
          """
-         if rpm is not None:
-             self._rpm = rpm
-         if tpm is not None:
-             self._tpm = tpm
-         return None
-         # self._set_rate_limits(rpm=rpm, tpm=tpm)
+         self._set_rate_limits(rpm=rpm, tpm=tpm)

-     # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
-     #     """Set the rate limits for the model.
+     def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+         """Set the rate limits for the model.

-     #     If the model does not have rate limits, use the default rate limits."""
-     #     if rpm is not None and tpm is not None:
-     #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
-     #         return
+         If the model does not have rate limits, use the default rate limits."""
+         if rpm is not None and tpm is not None:
+             self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+             return

-     #     if self.__rate_limits is None:
-     #         if hasattr(self, "get_rate_limits"):
-     #             self.__rate_limits = self.get_rate_limits()
-     #         else:
-     #             self.__rate_limits = self.__default_rate_limits
+         if self.__rate_limits is None:
+             if hasattr(self, "get_rate_limits"):
+                 self.__rate_limits = self.get_rate_limits()
+             else:
+                 self.__rate_limits = self.__default_rate_limits

      @property
      def RPM(self):
          """Model's requests-per-minute limit."""
          # self._set_rate_limits()
          # return self._safety_factor * self.__rate_limits["rpm"]
-         return self._rpm
+         return self.rpm

      @property
      def TPM(self):
          """Model's tokens-per-minute limit."""
          # self._set_rate_limits()
          # return self._safety_factor * self.__rate_limits["tpm"]
-         return self._tpm
+         return self.tpm

      @property
      def rpm(self):
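The net effect of un-commenting `_set_rate_limits` is a three-step fallback: explicit `rpm`/`tpm` arguments win, then a service-specific `get_rate_limits()` hook if the subclass defines one, then the class-level `__default_rate_limits`. A standalone restatement of that precedence, for illustration only (the function name and `service_hook` parameter are invented; the caching of a previously computed `__rate_limits` value is omitted):

    # Illustrative restatement of the fallback order in _set_rate_limits.
    DEFAULT_RATE_LIMITS = {"rpm": 10_000, "tpm": 2_000_000}

    def resolve_rate_limits(rpm=None, tpm=None, service_hook=None):
        if rpm is not None and tpm is not None:
            return {"rpm": rpm, "tpm": tpm}      # explicit arguments win
        if service_hook is not None:
            return service_hook()                # e.g. a get_rate_limits() method
        return DEFAULT_RATE_LIMITS               # class-level defaults

    assert resolve_rate_limits(rpm=100, tpm=50_000) == {"rpm": 100, "tpm": 50_000}
    assert resolve_rate_limits() == DEFAULT_RATE_LIMITS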
@@ -337,6 +332,17 @@ class LanguageModel(
      def tpm(self, value):
          self._tpm = value

+     @property
+     def TPM(self):
+         """Model's tokens-per-minute limit.
+
+         >>> m = LanguageModel.example()
+         >>> m.TPM > 0
+         True
+         """
+         self._set_rate_limits()
+         return self._safety_factor * self.__rate_limits["tpm"]
+
      @staticmethod
      def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
          """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
@@ -465,13 +471,12 @@ class LanguageModel(
          if encoded_image:
              # the image hash is appended to the user_prompt for hash-lookup purposes
              image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
-             user_prompt += f" {image_hash}"

          cache_call_params = {
              "model": str(self.model),
              "parameters": self.parameters,
              "system_prompt": system_prompt,
-             "user_prompt": user_prompt,
+             "user_prompt": user_prompt + "" if not encoded_image else f" {image_hash}",
              "iteration": iteration,
          }
          cached_response, cache_key = cache.fetch(**cache_call_params)
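One subtlety in the new `"user_prompt"` line: Python's conditional expression binds more loosely than `+`, so it parses as `(user_prompt + "") if not encoded_image else f" {image_hash}"`. A quick check of both branches (the values are made up):

    # Precedence check for `a + "" if cond else b`.
    user_prompt, image_hash = "Describe the image.", "abc123"
    for encoded_image in (None, "b64data"):
        value = user_prompt + "" if not encoded_image else f" {image_hash}"
        print(repr(value))
    # 'Describe the image.'   (no image: the prompt alone)
    # ' abc123'               (image present: only the hash suffix)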
edsl/language_models/registry.py
@@ -2,10 +2,10 @@ import textwrap
  from random import random
  from edsl.config import CONFIG

- # if "EDSL_DEFAULT_MODEL" not in CONFIG:
- #     default_model = "test"
- # else:
- #     default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
+ if "EDSL_DEFAULT_MODEL" not in CONFIG:
+     default_model = "test"
+ else:
+     default_model = CONFIG.get("EDSL_DEFAULT_MODEL")


  def get_model_class(model_name, registry=None):
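Un-commenting this block means a missing `EDSL_DEFAULT_MODEL` falls back to `"test"` at import time rather than relying on `CONFIG.get` for a key that may be absent. With a plain mapping standing in for `CONFIG` (the real object lives in `edsl.config`), the logic reduces to:

    # Stand-in demonstration; the real CONFIG comes from edsl.config.
    CONFIG = {}  # no EDSL_DEFAULT_MODEL set

    if "EDSL_DEFAULT_MODEL" not in CONFIG:
        default_model = "test"
    else:
        default_model = CONFIG.get("EDSL_DEFAULT_MODEL")

    assert default_model == "test"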
@@ -33,24 +33,20 @@ class Meta(type):


  class Model(metaclass=Meta):
-     default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
+     default_model = default_model

-     def __new__(
-         cls, model_name=None, registry=None, service_name=None, *args, **kwargs
-     ):
+     def __new__(cls, model_name=None, registry=None, *args, **kwargs):
          # Map index to the respective subclass
          if model_name is None:
-             model_name = (
-                 cls.default_model
-             )  # when model_name is None, use the default model, set in the config file
+             model_name = cls.default_model
          from edsl.inference_services.registry import default

          registry = registry or default

-         if isinstance(model_name, int):  # can refer to a model by index
+         if isinstance(model_name, int):
              model_name = cls.available(name_only=True)[model_name]

-         factory = registry.create_model_factory(model_name, service_name=service_name)
+         factory = registry.create_model_factory(model_name)
          return factory(*args, **kwargs)

      @classmethod
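Taken with the previous hunk, `Model.default_model` now inherits the module-level fallback rather than calling `CONFIG.get` in the class body, and the dev build drops the `service_name` routing argument. A sketch of the surviving constructor paths (the string model name is illustrative, and the import path assumes this is `edsl/language_models/registry.py` as listed above):

    # Sketch of Model() construction paths after this change.
    from edsl.language_models.registry import Model

    m_default = Model()          # model_name=None -> cls.default_model
    m_by_index = Model(0)        # an int indexes Model.available(name_only=True)
    m_by_name = Model("gpt-4o")  # resolved via registry.create_model_factory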
edsl/questions/QuestionBase.py
@@ -103,8 +103,13 @@ class QuestionBase(
          """Validate the answer.
          >>> from edsl.exceptions import QuestionAnswerValidationError
          >>> from edsl import QuestionFreeText as Q
-         >>> Q.example()._validate_answer({'answer': 'Hello', 'generated_tokens': 'Hello'})
-         {'answer': 'Hello', 'generated_tokens': 'Hello'}
+         >>> Q.example()._validate_answer({'answer': 'Hello'})
+         {'answer': 'Hello', 'generated_tokens': None}
+         >>> Q.example()._validate_answer({'shmanswer': 1})
+         Traceback (most recent call last):
+         ...
+         edsl.exceptions.questions.QuestionAnswerValidationError:...
+         ...
          """

          return self.response_validator.validate(answer)
@@ -466,7 +471,6 @@ class QuestionBase(
          self,
          scenario: Optional[dict] = None,
          agent: Optional[dict] = {},
-         answers: Optional[dict] = None,
          include_question_name: bool = False,
          height: Optional[int] = None,
          width: Optional[int] = None,
@@ -478,17 +482,6 @@ class QuestionBase(
          if scenario is None:
              scenario = {}

-         prior_answers_dict = {}
-
-         if isinstance(answers, dict):
-             for key, value in answers.items():
-                 if not key.endswith("_comment") and not key.endswith(
-                     "_generated_tokens"
-                 ):
-                     prior_answers_dict[key] = {"answer": value}
-
-         # breakpoint()
-
          base_template = """
  <div id="{{ question_name }}" class="survey_question" data-type="{{ question_type }}">
      {% if include_question_name %}
@@ -508,40 +501,13 @@ class QuestionBase(

          base_template = Template(base_template)

-         context = {
-             "scenario": scenario,
-             "agent": agent,
-         } | prior_answers_dict
-
-         # Render the question text
-         try:
-             question_text = Template(self.question_text).render(context)
-         except Exception as e:
-             print(
-                 f"Error rendering question: question_text = {self.question_text}, error = {e}"
-             )
-             question_text = self.question_text
-
-         try:
-             question_content = Template(question_content).render(context)
-         except Exception as e:
-             print(
-                 f"Error rendering question: question_content = {question_content}, error = {e}"
-             )
-             question_content = question_content
-
-         try:
-             params = {
-                 "question_name": self.question_name,
-                 "question_text": question_text,
-                 "question_type": self.question_type,
-                 "question_content": question_content,
-                 "include_question_name": include_question_name,
-             }
-         except Exception as e:
-             raise ValueError(
-                 f"Error rendering question: params = {params}, error = {e}"
-             )
+         params = {
+             "question_name": self.question_name,
+             "question_text": Template(self.question_text).render(scenario, agent=agent),
+             "question_type": self.question_type,
+             "question_content": Template(question_content).render(scenario),
+             "include_question_name": include_question_name,
+         }
          rendered_html = base_template.render(**params)

          if iframe:
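The replacement drops the try/except scaffolding and the prior-answers context, rendering the question text directly from the scenario dict plus an `agent` keyword. A minimal jinja2 check of that render call, with invented template text and values:

    # jinja2's render() accepts a positional mapping plus keyword arguments.
    from jinja2 import Template

    scenario = {"item": "coffee"}
    text = Template("How much do you like {{ item }}, {{ agent.persona }}?").render(
        scenario, agent={"persona": "a friendly barista"}
    )
    assert text == "How much do you like coffee, a friendly barista?"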
@@ -560,21 +526,6 @@ class QuestionBase(

          return rendered_html

-     @classmethod
-     def example_model(cls):
-         from edsl import Model
-
-         q = cls.example()
-         m = Model("test", canned_response=cls._simulate_answer(q)["answer"])
-
-         return m
-
-     @classmethod
-     def example_results(cls):
-         m = cls.example_model()
-         q = cls.example()
-         return q.by(m).run(cache=False)
-
      def rich_print(self):
          """Print the question in a rich format."""
          from rich.table import Table