edsl 0.1.41__py3-none-any.whl → 0.1.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +4 -3
- edsl/agents/InvigilatorBase.py +2 -1
- edsl/agents/PromptConstructor.py +92 -21
- edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
- edsl/agents/QuestionTemplateReplacementsBuilder.py +7 -2
- edsl/agents/prompt_helpers.py +2 -2
- edsl/coop/coop.py +97 -19
- edsl/enums.py +3 -1
- edsl/exceptions/coop.py +4 -0
- edsl/exceptions/jobs.py +1 -9
- edsl/exceptions/language_models.py +8 -4
- edsl/exceptions/questions.py +8 -11
- edsl/inference_services/AvailableModelFetcher.py +4 -1
- edsl/inference_services/DeepSeekService.py +18 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +60 -34
- edsl/jobs/JobsPrompts.py +64 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +42 -25
- edsl/jobs/JobsRemoteInferenceLogger.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +30 -0
- edsl/jobs/data_structures.py +1 -0
- edsl/jobs/interviews/Interview.py +1 -1
- edsl/jobs/loggers/HTMLTableJobLogger.py +6 -1
- edsl/jobs/results_exceptions_handler.py +2 -7
- edsl/jobs/tasks/TaskHistory.py +49 -17
- edsl/language_models/LanguageModel.py +7 -4
- edsl/language_models/ModelList.py +1 -1
- edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
- edsl/language_models/key_management/models.py +10 -4
- edsl/language_models/model.py +49 -0
- edsl/prompts/Prompt.py +124 -61
- edsl/questions/descriptors.py +37 -23
- edsl/questions/question_base_gen_mixin.py +1 -0
- edsl/results/DatasetExportMixin.py +35 -6
- edsl/results/Result.py +9 -3
- edsl/results/Results.py +180 -2
- edsl/results/ResultsGGMixin.py +117 -60
- edsl/scenarios/PdfExtractor.py +3 -6
- edsl/scenarios/Scenario.py +35 -1
- edsl/scenarios/ScenarioList.py +22 -3
- edsl/scenarios/ScenarioListPdfMixin.py +9 -3
- edsl/surveys/Survey.py +1 -1
- edsl/templates/error_reporting/base.html +2 -4
- edsl/templates/error_reporting/exceptions_table.html +35 -0
- edsl/templates/error_reporting/interview_details.html +67 -53
- edsl/templates/error_reporting/interviews.html +4 -17
- edsl/templates/error_reporting/overview.html +31 -5
- edsl/templates/error_reporting/performance_plot.html +1 -1
- {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/METADATA +2 -3
- {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/RECORD +53 -51
- {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/LICENSE +0 -0
- {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/WHEEL +0 -0
@@ -244,7 +244,7 @@ class LanguageModel(
|
|
244
244
|
|
245
245
|
>>> m = LanguageModel.example()
|
246
246
|
>>> hash(m)
|
247
|
-
|
247
|
+
325654563661254408
|
248
248
|
"""
|
249
249
|
from edsl.utilities.utilities import dict_hash
|
250
250
|
|
@@ -495,11 +495,12 @@ class LanguageModel(
|
|
495
495
|
|
496
496
|
>>> m = LanguageModel.example()
|
497
497
|
>>> m.to_dict()
|
498
|
-
{'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
|
498
|
+
{'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'inference_service': 'openai', 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
|
499
499
|
"""
|
500
500
|
d = {
|
501
501
|
"model": self.model,
|
502
502
|
"parameters": self.parameters,
|
503
|
+
"inference_service": self._inference_service_,
|
503
504
|
}
|
504
505
|
if add_edsl_version:
|
505
506
|
from edsl import __version__
|
@@ -511,7 +512,10 @@ class LanguageModel(
|
|
511
512
|
@classmethod
|
512
513
|
@remove_edsl_version
|
513
514
|
def from_dict(cls, data: dict) -> Type[LanguageModel]:
|
514
|
-
"""Convert dictionary to a LanguageModel child instance.
|
515
|
+
"""Convert dictionary to a LanguageModel child instance.
|
516
|
+
|
517
|
+
NB: This method does not use the stored inference_service but rather just fetches a model class based on the name.
|
518
|
+
"""
|
515
519
|
from edsl.language_models.model import get_model_class
|
516
520
|
|
517
521
|
model_class = get_model_class(data["model"])
|
@@ -558,7 +562,6 @@ class LanguageModel(
|
|
558
562
|
>>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!", throw_exception = True)
|
559
563
|
>>> r = q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True, print_exceptions = True)
|
560
564
|
Exception report saved to ...
|
561
|
-
Also see: ...
|
562
565
|
"""
|
563
566
|
from edsl.language_models.model import Model
|
564
567
|
|
@@ -61,7 +61,14 @@ class KeyLookupBuilder:
|
|
61
61
|
DEFAULT_RPM = int(CONFIG.get("EDSL_SERVICE_RPM_BASELINE"))
|
62
62
|
DEFAULT_TPM = int(CONFIG.get("EDSL_SERVICE_TPM_BASELINE"))
|
63
63
|
|
64
|
-
def __init__(
|
64
|
+
def __init__(
|
65
|
+
self,
|
66
|
+
fetch_order: Optional[tuple[str]] = None,
|
67
|
+
coop: Optional["Coop"] = None,
|
68
|
+
):
|
69
|
+
from edsl.coop import Coop
|
70
|
+
|
71
|
+
# Fetch order goes from lowest priority to highest priority
|
65
72
|
if fetch_order is None:
|
66
73
|
self.fetch_order = ("config", "env")
|
67
74
|
else:
|
@@ -70,6 +77,11 @@ class KeyLookupBuilder:
|
|
70
77
|
if not isinstance(self.fetch_order, tuple):
|
71
78
|
raise ValueError("fetch_order must be a tuple")
|
72
79
|
|
80
|
+
if coop is None:
|
81
|
+
self.coop = Coop()
|
82
|
+
else:
|
83
|
+
self.coop = coop
|
84
|
+
|
73
85
|
self.limit_data = {}
|
74
86
|
self.key_data = {}
|
75
87
|
self.id_data = {}
|
@@ -131,7 +143,8 @@ class KeyLookupBuilder:
|
|
131
143
|
service=service,
|
132
144
|
rpm=self.DEFAULT_RPM,
|
133
145
|
tpm=self.DEFAULT_TPM,
|
134
|
-
|
146
|
+
rpm_source="default",
|
147
|
+
tpm_source="default",
|
135
148
|
)
|
136
149
|
|
137
150
|
if limit_entry.rpm is None:
|
@@ -145,7 +158,8 @@ class KeyLookupBuilder:
|
|
145
158
|
tpm=int(limit_entry.tpm),
|
146
159
|
api_id=api_id,
|
147
160
|
token_source=api_key_entry.source,
|
148
|
-
|
161
|
+
rpm_source=limit_entry.rpm_source,
|
162
|
+
tpm_source=limit_entry.tpm_source,
|
149
163
|
id_source=id_source,
|
150
164
|
)
|
151
165
|
|
@@ -156,10 +170,7 @@ class KeyLookupBuilder:
|
|
156
170
|
return dict(list(os.environ.items()))
|
157
171
|
|
158
172
|
def _coop_key_value_pairs(self):
|
159
|
-
|
160
|
-
|
161
|
-
c = Coop()
|
162
|
-
return dict(list(c.fetch_rate_limit_config_vars().items()))
|
173
|
+
return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
|
163
174
|
|
164
175
|
def _config_key_value_pairs(self):
|
165
176
|
from edsl.config import CONFIG
|
@@ -169,7 +180,7 @@ class KeyLookupBuilder:
|
|
169
180
|
@staticmethod
|
170
181
|
def extract_service(key: str) -> str:
|
171
182
|
"""Extract the service and limit type from the key"""
|
172
|
-
limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_")
|
183
|
+
limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_", 1)
|
173
184
|
return service_raw.lower(), limit_type.lower()
|
174
185
|
|
175
186
|
def get_key_value_pairs(self) -> dict:
|
@@ -187,17 +198,17 @@ class KeyLookupBuilder:
|
|
187
198
|
d[k] = (v, source)
|
188
199
|
return d
|
189
200
|
|
190
|
-
def _entry_type(self, key
|
201
|
+
def _entry_type(self, key: str) -> str:
|
191
202
|
"""Determine the type of entry from a key.
|
192
203
|
|
193
204
|
>>> builder = KeyLookupBuilder()
|
194
|
-
>>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI"
|
205
|
+
>>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI")
|
195
206
|
'limit'
|
196
|
-
>>> builder._entry_type("OPENAI_API_KEY"
|
207
|
+
>>> builder._entry_type("OPENAI_API_KEY")
|
197
208
|
'api_key'
|
198
|
-
>>> builder._entry_type("AWS_ACCESS_KEY_ID"
|
209
|
+
>>> builder._entry_type("AWS_ACCESS_KEY_ID")
|
199
210
|
'api_id'
|
200
|
-
>>> builder._entry_type("UNKNOWN_KEY"
|
211
|
+
>>> builder._entry_type("UNKNOWN_KEY")
|
201
212
|
'unknown'
|
202
213
|
"""
|
203
214
|
if key.startswith("EDSL_SERVICE_"):
|
@@ -243,11 +254,13 @@ class KeyLookupBuilder:
|
|
243
254
|
service, limit_type = self.extract_service(key)
|
244
255
|
if service in self.limit_data:
|
245
256
|
setattr(self.limit_data[service], limit_type.lower(), value)
|
257
|
+
setattr(self.limit_data[service], f"{limit_type}_source", source)
|
246
258
|
else:
|
247
259
|
new_limit_entry = LimitEntry(
|
248
|
-
service=service, rpm=None, tpm=None,
|
260
|
+
service=service, rpm=None, tpm=None, rpm_source=None, tpm_source=None
|
249
261
|
)
|
250
262
|
setattr(new_limit_entry, limit_type.lower(), value)
|
263
|
+
setattr(new_limit_entry, f"{limit_type}_source", source)
|
251
264
|
self.limit_data[service] = new_limit_entry
|
252
265
|
|
253
266
|
def _add_api_key(self, key: str, value: str, source: str) -> None:
|
@@ -265,13 +278,27 @@ class KeyLookupBuilder:
|
|
265
278
|
else:
|
266
279
|
self.key_data[service].append(new_entry)
|
267
280
|
|
268
|
-
def
|
269
|
-
"""
|
270
|
-
|
281
|
+
def update_from_dict(self, d: dict) -> None:
|
282
|
+
"""
|
283
|
+
Update data from a dictionary of key-value pairs.
|
284
|
+
Each key is a key name, and each value is a tuple of (value, source).
|
285
|
+
|
286
|
+
>>> builder = KeyLookupBuilder()
|
287
|
+
>>> builder.update_from_dict({"OPENAI_API_KEY": ("sk-1234", "custodial_keys")})
|
288
|
+
>>> 'sk-1234' == builder.key_data["openai"][-1].value
|
289
|
+
True
|
290
|
+
>>> 'custodial_keys' == builder.key_data["openai"][-1].source
|
291
|
+
True
|
292
|
+
"""
|
293
|
+
for key, value_pair in d.items():
|
271
294
|
value, source = value_pair
|
272
|
-
if
|
295
|
+
if self._entry_type(key) == "limit":
|
273
296
|
self._add_limit(key, value, source)
|
274
|
-
elif
|
297
|
+
elif self._entry_type(key) == "api_key":
|
275
298
|
self._add_api_key(key, value, source)
|
276
|
-
elif
|
299
|
+
elif self._entry_type(key) == "api_id":
|
277
300
|
self._add_id(key, value, source)
|
301
|
+
|
302
|
+
def process_key_value_pairs(self) -> None:
|
303
|
+
"""Process all key-value pairs from the configured sources."""
|
304
|
+
self.update_from_dict(self.get_key_value_pairs())
|
@@ -40,18 +40,23 @@ class LimitEntry:
|
|
40
40
|
60
|
41
41
|
>>> limit.tpm
|
42
42
|
100000
|
43
|
-
>>> limit.
|
43
|
+
>>> limit.rpm_source
|
44
44
|
'config'
|
45
|
+
>>> limit.tpm_source
|
46
|
+
'env'
|
45
47
|
"""
|
46
48
|
|
47
49
|
service: str
|
48
50
|
rpm: int
|
49
51
|
tpm: int
|
50
|
-
|
52
|
+
rpm_source: Optional[str] = None
|
53
|
+
tpm_source: Optional[str] = None
|
51
54
|
|
52
55
|
@classmethod
|
53
56
|
def example(cls):
|
54
|
-
return LimitEntry(
|
57
|
+
return LimitEntry(
|
58
|
+
service="openai", rpm=60, tpm=100000, rpm_source="config", tpm_source="env"
|
59
|
+
)
|
55
60
|
|
56
61
|
|
57
62
|
@dataclass
|
@@ -108,7 +113,8 @@ class LanguageModelInput:
|
|
108
113
|
tpm: int
|
109
114
|
api_id: Optional[str] = None
|
110
115
|
token_source: Optional[str] = None
|
111
|
-
|
116
|
+
rpm_source: Optional[str] = None
|
117
|
+
tpm_source: Optional[str] = None
|
112
118
|
id_source: Optional[str] = None
|
113
119
|
|
114
120
|
def to_dict(self):
|
edsl/language_models/model.py
CHANGED
@@ -233,6 +233,55 @@ class Model(metaclass=Meta):
|
|
233
233
|
print("OK!")
|
234
234
|
print("\n")
|
235
235
|
|
236
|
+
@classmethod
|
237
|
+
def check_working_models(
|
238
|
+
cls,
|
239
|
+
service: Optional[str] = None,
|
240
|
+
works_with_text: Optional[bool] = None,
|
241
|
+
works_with_images: Optional[bool] = None,
|
242
|
+
) -> list[dict]:
|
243
|
+
from edsl.coop import Coop
|
244
|
+
|
245
|
+
c = Coop()
|
246
|
+
working_models = c.fetch_working_models()
|
247
|
+
|
248
|
+
if service is not None:
|
249
|
+
working_models = [m for m in working_models if m["service"] == service]
|
250
|
+
if works_with_text is not None:
|
251
|
+
working_models = [
|
252
|
+
m for m in working_models if m["works_with_text"] == works_with_text
|
253
|
+
]
|
254
|
+
if works_with_images is not None:
|
255
|
+
working_models = [
|
256
|
+
m for m in working_models if m["works_with_images"] == works_with_images
|
257
|
+
]
|
258
|
+
|
259
|
+
if len(working_models) == 0:
|
260
|
+
return []
|
261
|
+
|
262
|
+
else:
|
263
|
+
return PrettyList(
|
264
|
+
[
|
265
|
+
[
|
266
|
+
m["service"],
|
267
|
+
m["model"],
|
268
|
+
m["works_with_text"],
|
269
|
+
m["works_with_images"],
|
270
|
+
m["usd_per_1M_input_tokens"],
|
271
|
+
m["usd_per_1M_output_tokens"],
|
272
|
+
]
|
273
|
+
for m in working_models
|
274
|
+
],
|
275
|
+
columns=[
|
276
|
+
"Service",
|
277
|
+
"Model",
|
278
|
+
"Works with text",
|
279
|
+
"Works with images",
|
280
|
+
"Price per 1M input tokens (USD)",
|
281
|
+
"Price per 1M output tokens (USD)",
|
282
|
+
],
|
283
|
+
)
|
284
|
+
|
236
285
|
@classmethod
|
237
286
|
def example(cls, randomize: bool = False) -> "Model":
|
238
287
|
"""
|
edsl/prompts/Prompt.py
CHANGED
@@ -10,6 +10,48 @@ from edsl.Base import PersistenceMixin, RepresentationMixin
|
|
10
10
|
|
11
11
|
MAX_NESTING = 100
|
12
12
|
|
13
|
+
from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
|
14
|
+
from functools import lru_cache
|
15
|
+
|
16
|
+
class PreserveUndefined(Undefined):
|
17
|
+
def __str__(self):
|
18
|
+
return "{{ " + str(self._undefined_name) + " }}"
|
19
|
+
|
20
|
+
# Create environment once at module level
|
21
|
+
_env = Environment(undefined=PreserveUndefined)
|
22
|
+
|
23
|
+
@lru_cache(maxsize=1024)
|
24
|
+
def _compile_template(text: str):
|
25
|
+
return _env.from_string(text)
|
26
|
+
|
27
|
+
@lru_cache(maxsize=1024)
|
28
|
+
def _find_template_variables(template: str) -> list[str]:
|
29
|
+
"""Find and return the template variables."""
|
30
|
+
ast = _env.parse(template)
|
31
|
+
return list(meta.find_undeclared_variables(ast))
|
32
|
+
|
33
|
+
def _make_hashable(value):
|
34
|
+
"""Convert unhashable types to hashable ones."""
|
35
|
+
if isinstance(value, list):
|
36
|
+
return tuple(_make_hashable(item) for item in value)
|
37
|
+
if isinstance(value, dict):
|
38
|
+
return frozenset((k, _make_hashable(v)) for k, v in value.items())
|
39
|
+
return value
|
40
|
+
|
41
|
+
@lru_cache(maxsize=1024)
|
42
|
+
def _cached_render(text: str, frozen_replacements: frozenset) -> str:
|
43
|
+
"""Cached version of template rendering with frozen replacements."""
|
44
|
+
# Print cache info on every call
|
45
|
+
cache_info = _cached_render.cache_info()
|
46
|
+
print(f"\t\t\t\t\t Cache status - hits: {cache_info.hits}, misses: {cache_info.misses}, current size: {cache_info.currsize}")
|
47
|
+
|
48
|
+
# Convert back to dict with original types for rendering
|
49
|
+
replacements = {k: v for k, v in frozen_replacements}
|
50
|
+
|
51
|
+
template = _compile_template(text)
|
52
|
+
result = template.render(replacements)
|
53
|
+
|
54
|
+
return result
|
13
55
|
|
14
56
|
class Prompt(PersistenceMixin, RepresentationMixin):
|
15
57
|
"""Class for creating a prompt to be used in a survey."""
|
@@ -145,33 +187,8 @@ class Prompt(PersistenceMixin, RepresentationMixin):
|
|
145
187
|
return f'Prompt(text="""{self.text}""")'
|
146
188
|
|
147
189
|
def template_variables(self) -> list[str]:
|
148
|
-
"""Return the
|
149
|
-
|
150
|
-
Example:
|
151
|
-
|
152
|
-
>>> p = Prompt("Hello, {{person}}")
|
153
|
-
>>> p.template_variables()
|
154
|
-
['person']
|
155
|
-
|
156
|
-
"""
|
157
|
-
return self._template_variables(self.text)
|
158
|
-
|
159
|
-
@staticmethod
|
160
|
-
def _template_variables(template: str) -> list[str]:
|
161
|
-
"""Find and return the template variables.
|
162
|
-
|
163
|
-
:param template: The template to find the variables in.
|
164
|
-
|
165
|
-
"""
|
166
|
-
from jinja2 import Environment, meta, Undefined
|
167
|
-
|
168
|
-
class PreserveUndefined(Undefined):
|
169
|
-
def __str__(self):
|
170
|
-
return "{{ " + str(self._undefined_name) + " }}"
|
171
|
-
|
172
|
-
env = Environment(undefined=PreserveUndefined)
|
173
|
-
ast = env.parse(template)
|
174
|
-
return list(meta.find_undeclared_variables(ast))
|
190
|
+
"""Return the variables in the template."""
|
191
|
+
return _find_template_variables(self.text)
|
175
192
|
|
176
193
|
def undefined_template_variables(self, replacement_dict: dict):
|
177
194
|
"""Return the variables in the template that are not in the replacement_dict.
|
@@ -239,45 +256,39 @@ class Prompt(PersistenceMixin, RepresentationMixin):
|
|
239
256
|
return self
|
240
257
|
|
241
258
|
@staticmethod
|
242
|
-
def _render(
|
243
|
-
text
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
Allows for nested variable resolution up to a specified maximum nesting depth.
|
252
|
-
|
253
|
-
Example:
|
254
|
-
|
255
|
-
>>> codebook = {"age": "Age"}
|
256
|
-
>>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
|
257
|
-
>>> p.render({"name": "John", "age": 44}, codebook=codebook)
|
258
|
-
Prompt(text=\"""You are an agent named John. Age: 44\""")
|
259
|
-
"""
|
260
|
-
from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
|
261
|
-
|
262
|
-
class PreserveUndefined(Undefined):
|
263
|
-
def __str__(self):
|
264
|
-
return "{{ " + str(self._undefined_name) + " }}"
|
265
|
-
|
266
|
-
env = Environment(undefined=PreserveUndefined)
|
259
|
+
def _render(text: str, primary_replacement, **additional_replacements) -> "PromptBase":
|
260
|
+
"""Render the template text with variables replaced."""
|
261
|
+
import time
|
262
|
+
|
263
|
+
# if there are no replacements, return the text
|
264
|
+
if not primary_replacement and not additional_replacements:
|
265
|
+
return text
|
266
|
+
|
267
267
|
try:
|
268
|
+
variables = _find_template_variables(text)
|
269
|
+
|
270
|
+
if not variables: # if there are no variables, return the text
|
271
|
+
return text
|
272
|
+
|
273
|
+
# Combine all replacements
|
274
|
+
all_replacements = {**primary_replacement, **additional_replacements}
|
275
|
+
|
268
276
|
previous_text = None
|
277
|
+
current_text = text
|
278
|
+
iteration = 0
|
279
|
+
|
269
280
|
for _ in range(MAX_NESTING):
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
)
|
274
|
-
|
275
|
-
|
281
|
+
iteration += 1
|
282
|
+
|
283
|
+
template = _compile_template(current_text)
|
284
|
+
rendered_text = template.render(all_replacements)
|
285
|
+
|
286
|
+
if rendered_text == current_text:
|
276
287
|
return rendered_text
|
277
|
-
|
278
|
-
|
288
|
+
|
289
|
+
previous_text = current_text
|
290
|
+
current_text = rendered_text
|
279
291
|
|
280
|
-
# If the loop exits without returning, it indicates too much nesting
|
281
292
|
raise TemplateRenderError(
|
282
293
|
"Too much nesting - you created an infinite loop here, pal"
|
283
294
|
)
|
@@ -331,6 +342,58 @@ class Prompt(PersistenceMixin, RepresentationMixin):
|
|
331
342
|
"""Return an example of the prompt."""
|
332
343
|
return cls(cls.default_instructions)
|
333
344
|
|
345
|
+
def get_prompts(self) -> Dict[str, Any]:
|
346
|
+
"""Get the prompts for the question."""
|
347
|
+
start = time.time()
|
348
|
+
|
349
|
+
# Build all the components
|
350
|
+
instr_start = time.time()
|
351
|
+
agent_instructions = self.agent_instructions_prompt
|
352
|
+
instr_end = time.time()
|
353
|
+
logger.debug(f"Time taken for agent instructions: {instr_end - instr_start:.4f}s")
|
354
|
+
|
355
|
+
persona_start = time.time()
|
356
|
+
agent_persona = self.agent_persona_prompt
|
357
|
+
persona_end = time.time()
|
358
|
+
logger.debug(f"Time taken for agent persona: {persona_end - persona_start:.4f}s")
|
359
|
+
|
360
|
+
q_instr_start = time.time()
|
361
|
+
question_instructions = self.question_instructions_prompt
|
362
|
+
q_instr_end = time.time()
|
363
|
+
logger.debug(f"Time taken for question instructions: {q_instr_end - q_instr_start:.4f}s")
|
364
|
+
|
365
|
+
memory_start = time.time()
|
366
|
+
prior_question_memory = self.prior_question_memory_prompt
|
367
|
+
memory_end = time.time()
|
368
|
+
logger.debug(f"Time taken for prior question memory: {memory_end - memory_start:.4f}s")
|
369
|
+
|
370
|
+
# Get components dict
|
371
|
+
components = {
|
372
|
+
"agent_instructions": agent_instructions.text,
|
373
|
+
"agent_persona": agent_persona.text,
|
374
|
+
"question_instructions": question_instructions.text,
|
375
|
+
"prior_question_memory": prior_question_memory.text,
|
376
|
+
}
|
377
|
+
|
378
|
+
# Use PromptPlan's get_prompts method
|
379
|
+
plan_start = time.time()
|
380
|
+
prompts = self.prompt_plan.get_prompts(**components)
|
381
|
+
plan_end = time.time()
|
382
|
+
logger.debug(f"Time taken for prompt processing: {plan_end - plan_start:.4f}s")
|
383
|
+
|
384
|
+
# Handle file keys if present
|
385
|
+
if hasattr(self, 'question_file_keys') and self.question_file_keys:
|
386
|
+
files_start = time.time()
|
387
|
+
files_list = []
|
388
|
+
for key in self.question_file_keys:
|
389
|
+
files_list.append(self.scenario[key])
|
390
|
+
prompts["files_list"] = files_list
|
391
|
+
files_end = time.time()
|
392
|
+
logger.debug(f"Time taken for file key processing: {files_end - files_start:.4f}s")
|
393
|
+
|
394
|
+
end = time.time()
|
395
|
+
logger.debug(f"Total time in get_prompts: {end - start:.4f}s")
|
396
|
+
return prompts
|
334
397
|
|
335
398
|
if __name__ == "__main__":
|
336
399
|
print("Running doctests...")
|
edsl/questions/descriptors.py
CHANGED
@@ -249,7 +249,28 @@ class QuestionNameDescriptor(BaseDescriptor):
|
|
249
249
|
|
250
250
|
|
251
251
|
class QuestionOptionsDescriptor(BaseDescriptor):
|
252
|
-
"""Validate that `question_options` is a list, does not exceed the min/max lengths, and has unique items.
|
252
|
+
"""Validate that `question_options` is a list, does not exceed the min/max lengths, and has unique items.
|
253
|
+
|
254
|
+
>>> import warnings
|
255
|
+
>>> q_class = QuestionOptionsDescriptor.example()
|
256
|
+
>>> with warnings.catch_warnings(record=True) as w:
|
257
|
+
... _ = q_class(["a ", "b", "c"]) # Has trailing space
|
258
|
+
... assert len(w) == 1
|
259
|
+
... assert "trailing whitespace" in str(w[0].message)
|
260
|
+
|
261
|
+
>>> _ = q_class(["a", "b", "c", "d", "d"])
|
262
|
+
Traceback (most recent call last):
|
263
|
+
...
|
264
|
+
edsl.exceptions.questions.QuestionCreationValidationError: Question options must be unique (got ['a', 'b', 'c', 'd', 'd']).
|
265
|
+
|
266
|
+
We allow dynamic question options, which are strings of the form '{{ question_options }}'.
|
267
|
+
|
268
|
+
>>> _ = q_class("{{dynamic_options}}")
|
269
|
+
>>> _ = q_class("dynamic_options")
|
270
|
+
Traceback (most recent call last):
|
271
|
+
...
|
272
|
+
edsl.exceptions.questions.QuestionCreationValidationError: ...
|
273
|
+
"""
|
253
274
|
|
254
275
|
@classmethod
|
255
276
|
def example(cls):
|
@@ -273,23 +294,7 @@ class QuestionOptionsDescriptor(BaseDescriptor):
|
|
273
294
|
self.q_budget = q_budget
|
274
295
|
|
275
296
|
def validate(self, value: Any, instance) -> None:
|
276
|
-
"""Validate the question options.
|
277
|
-
|
278
|
-
>>> q_class = QuestionOptionsDescriptor.example()
|
279
|
-
>>> _ = q_class(["a", "b", "c"])
|
280
|
-
>>> _ = q_class(["a", "b", "c", "d", "d"])
|
281
|
-
Traceback (most recent call last):
|
282
|
-
...
|
283
|
-
edsl.exceptions.questions.QuestionCreationValidationError: Question options must be unique (got ['a', 'b', 'c', 'd', 'd']).
|
284
|
-
|
285
|
-
We allow dynamic question options, which are strings of the form '{{ question_options }}'.
|
286
|
-
|
287
|
-
>>> _ = q_class("{{dynamic_options}}")
|
288
|
-
>>> _ = q_class("dynamic_options")
|
289
|
-
Traceback (most recent call last):
|
290
|
-
...
|
291
|
-
edsl.exceptions.questions.QuestionCreationValidationError: ...
|
292
|
-
"""
|
297
|
+
"""Validate the question options."""
|
293
298
|
if isinstance(value, str):
|
294
299
|
# Check if the string is a dynamic question option
|
295
300
|
if "{{" in value and "}}" in value:
|
@@ -302,10 +307,10 @@ class QuestionOptionsDescriptor(BaseDescriptor):
|
|
302
307
|
raise QuestionCreationValidationError(
|
303
308
|
f"Question options must be a list (got {value})."
|
304
309
|
)
|
305
|
-
if len(value) > Settings.MAX_NUM_OPTIONS:
|
306
|
-
|
307
|
-
|
308
|
-
|
310
|
+
# if len(value) > Settings.MAX_NUM_OPTIONS:
|
311
|
+
# raise QuestionCreationValidationError(
|
312
|
+
# f"Too many question options (got {value})."
|
313
|
+
# )
|
309
314
|
if len(value) < Settings.MIN_NUM_OPTIONS:
|
310
315
|
raise QuestionCreationValidationError(
|
311
316
|
f"Too few question options (got {value})."
|
@@ -343,6 +348,15 @@ class QuestionOptionsDescriptor(BaseDescriptor):
|
|
343
348
|
f"All question options must be at least 1 character long but less than {Settings.MAX_OPTION_LENGTH} characters long (got {value})."
|
344
349
|
)
|
345
350
|
|
351
|
+
# Check for trailing whitespace in string options
|
352
|
+
if any(isinstance(x, str) and (x != x.strip()) for x in value):
|
353
|
+
import warnings
|
354
|
+
|
355
|
+
warnings.warn(
|
356
|
+
"Some question options contain trailing whitespace. This may cause unexpected behavior.",
|
357
|
+
UserWarning,
|
358
|
+
)
|
359
|
+
|
346
360
|
if hasattr(instance, "min_selections") and instance.min_selections != None:
|
347
361
|
if instance.min_selections > len(value):
|
348
362
|
raise QuestionCreationValidationError(
|
@@ -408,7 +422,7 @@ class QuestionTextDescriptor(BaseDescriptor):
|
|
408
422
|
# Automatically replace single braces with double braces
|
409
423
|
# This is here because if the user is using an f-string, the double brace will get converted to a single brace.
|
410
424
|
# This undoes that.
|
411
|
-
value = re.sub(r"\{([^\{\}]+)\}", r"{{\1}}", value)
|
425
|
+
# value = re.sub(r"\{([^\{\}]+)\}", r"{{\1}}", value)
|
412
426
|
return value
|
413
427
|
|
414
428
|
# iterate through all double braces and check if they are valid python identifiers
|