edsl 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- edsl/Base.py +7 -3
- edsl/__version__.py +1 -1
- edsl/agents/InvigilatorBase.py +3 -1
- edsl/agents/PromptConstructor.py +66 -91
- edsl/agents/QuestionInstructionPromptBuilder.py +160 -79
- edsl/agents/QuestionTemplateReplacementsBuilder.py +80 -17
- edsl/agents/question_option_processor.py +15 -6
- edsl/coop/CoopFunctionsMixin.py +3 -4
- edsl/coop/coop.py +171 -96
- edsl/data/RemoteCacheSync.py +10 -9
- edsl/enums.py +3 -3
- edsl/inference_services/AnthropicService.py +11 -9
- edsl/inference_services/AvailableModelFetcher.py +2 -0
- edsl/inference_services/AwsBedrock.py +1 -2
- edsl/inference_services/AzureAI.py +12 -9
- edsl/inference_services/GoogleService.py +9 -4
- edsl/inference_services/InferenceServicesCollection.py +2 -2
- edsl/inference_services/MistralAIService.py +1 -2
- edsl/inference_services/OpenAIService.py +9 -4
- edsl/inference_services/PerplexityService.py +2 -1
- edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
- edsl/inference_services/registry.py +2 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +12 -1
- edsl/jobs/Jobs.py +24 -17
- edsl/jobs/JobsChecks.py +10 -13
- edsl/jobs/JobsPrompts.py +49 -26
- edsl/jobs/JobsRemoteInferenceHandler.py +4 -5
- edsl/jobs/async_interview_runner.py +3 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/data_structures.py +3 -0
- edsl/jobs/interviews/Interview.py +6 -3
- edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
- edsl/jobs/tasks/TaskHistory.py +1 -1
- edsl/language_models/LanguageModel.py +6 -3
- edsl/language_models/PriceManager.py +45 -5
- edsl/language_models/model.py +47 -26
- edsl/questions/QuestionBase.py +21 -0
- edsl/questions/QuestionBasePromptsMixin.py +103 -0
- edsl/questions/QuestionFreeText.py +22 -5
- edsl/questions/descriptors.py +4 -0
- edsl/questions/question_base_gen_mixin.py +96 -29
- edsl/results/Dataset.py +65 -0
- edsl/results/DatasetExportMixin.py +320 -32
- edsl/results/Result.py +27 -0
- edsl/results/Results.py +22 -2
- edsl/results/ResultsGGMixin.py +7 -3
- edsl/scenarios/DocumentChunker.py +2 -0
- edsl/scenarios/FileStore.py +10 -0
- edsl/scenarios/PdfExtractor.py +21 -1
- edsl/scenarios/Scenario.py +25 -9
- edsl/scenarios/ScenarioList.py +226 -24
- edsl/scenarios/handlers/__init__.py +1 -0
- edsl/scenarios/handlers/docx.py +5 -1
- edsl/scenarios/handlers/jpeg.py +39 -0
- edsl/surveys/Survey.py +5 -4
- edsl/surveys/SurveyFlowVisualization.py +91 -43
- edsl/templates/error_reporting/exceptions_table.html +7 -8
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/interviews.html +0 -1
- edsl/templates/error_reporting/overview.html +2 -7
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +1 -1
- edsl/utilities/PrettyList.py +14 -0
- edsl-0.1.46.dist-info/METADATA +246 -0
- {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/RECORD +67 -66
- edsl-0.1.44.dist-info/METADATA +0 -110
- {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/LICENSE +0 -0
- {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/WHEEL +0 -0
edsl/jobs/JobsRemoteInferenceHandler.py CHANGED

@@ -24,7 +24,7 @@ from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
 class RemoteJobConstants:
     """Constants for remote job handling."""
 
-    REMOTE_JOB_POLL_INTERVAL =
+    REMOTE_JOB_POLL_INTERVAL = 4
     REMOTE_JOB_VERBOSE = False
     DISCORD_URL = "https://discord.com/invite/mxAYkjfy9m"
 
@@ -88,8 +88,8 @@ class JobsRemoteInferenceHandler:
         iterations: int = 1,
         remote_inference_description: Optional[str] = None,
         remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
+        fresh: Optional[bool] = False,
     ) -> RemoteJobInfo:
-
         from edsl.config import CONFIG
         from edsl.coop.coop import Coop
 
@@ -106,6 +106,7 @@ class JobsRemoteInferenceHandler:
             status="queued",
             iterations=iterations,
             initial_results_visibility=remote_inference_results_visibility,
+            fresh=fresh,
         )
         logger.update(
             "Your survey is running at the Expected Parrot server...",
@@ -277,9 +278,7 @@ class JobsRemoteInferenceHandler:
         job_in_queue = True
         while job_in_queue:
             result = self._attempt_fetch_job(
-                job_info,
-                remote_job_data_fetcher,
-                object_fetcher
+                job_info, remote_job_data_fetcher, object_fetcher
             )
             if result != "continue":
                 return result
edsl/jobs/async_interview_runner.py CHANGED

@@ -7,6 +7,8 @@ from edsl.data_transfer_models import EDSLResultObjectInput
 
 from edsl.results.Result import Result
 from edsl.jobs.interviews.Interview import Interview
+from edsl.config import Config
+config = Config()
 
 if TYPE_CHECKING:
     from edsl.jobs.Jobs import Jobs
@@ -23,7 +25,7 @@ from edsl.jobs.data_structures import RunConfig
 
 
 class AsyncInterviewRunner:
-    MAX_CONCURRENT =
+    MAX_CONCURRENT = int(config.EDSL_MAX_CONCURRENT_TASKS)
 
     def __init__(self, jobs: "Jobs", run_config: RunConfig):
         self.jobs = jobs
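
The hard-coded concurrency cap becomes configuration-driven here. A minimal, self-contained sketch of the pattern, assuming only that a config object exposes an EDSL_MAX_CONCURRENT_TASKS value as a string; the semaphore wiring below is illustrative, not edsl's actual runner internals:

```python
import asyncio

class FakeConfig:
    # Stand-in for edsl.config.Config; only the attribute name
    # EDSL_MAX_CONCURRENT_TASKS comes from the diff.
    EDSL_MAX_CONCURRENT_TASKS = "10"  # config values typically arrive as strings

config = FakeConfig()

class Runner:
    # Read the cap once at class-definition time, as the diff does.
    MAX_CONCURRENT = int(config.EDSL_MAX_CONCURRENT_TASKS)

    async def run_all(self, coros):
        # Illustrative: bound concurrency with a semaphore.
        sem = asyncio.Semaphore(self.MAX_CONCURRENT)

        async def bounded(coro):
            async with sem:
                return await coro

        return await asyncio.gather(*(bounded(c) for c in coros))

async def main():
    async def task(i):
        await asyncio.sleep(0)
        return i
    print(await Runner().run_all([task(i) for i in range(3)]))  # [0, 1, 2]

asyncio.run(main())
```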
edsl/jobs/check_survey_scenario_compatibility.py CHANGED

@@ -72,11 +72,11 @@ class CheckSurveyScenarioCompatibility:
         if warn:
             warnings.warn(message)
 
-        if self.scenarios.has_jinja_braces:
-            warnings.warn(
-                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
-            )
-            self.scenarios = self.scenarios._convert_jinja_braces()
+        # if self.scenarios.has_jinja_braces:
+        #     warnings.warn(
+        #         "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
+        #     )
+        #     self.scenarios = self.scenarios._convert_jinja_braces()
 
 
 if __name__ == "__main__":
edsl/jobs/data_structures.py CHANGED

@@ -36,6 +36,9 @@ class RunParameters(Base):
     disable_remote_cache: bool = False
     disable_remote_inference: bool = False
     job_uuid: Optional[str] = None
+    fresh: Optional[
+        bool
+    ] = False  # if True, will not use cache and will save new results to cache
 
     def to_dict(self, add_edsl_version=False) -> dict:
         d = asdict(self)
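
The new fresh field is what JobsRemoteInferenceHandler forwards above (fresh=fresh). Per the inline comment, a fresh run skips cache reads but still writes its results back to the cache. A minimal sketch of the field in isolation; the real RunParameters carries many more options:

```python
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class RunParameters:
    # Only the new field is shown here.
    fresh: Optional[bool] = False  # if True, skip cache reads but still store new results

params = RunParameters(fresh=True)
print(asdict(params))  # {'fresh': True}
```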
edsl/jobs/interviews/Interview.py CHANGED

@@ -238,9 +238,6 @@ class Interview:
     >>> run_config = RunConfig(parameters = RunParameters(), environment = RunEnvironment())
     >>> run_config.parameters.stop_on_exception = True
     >>> result, _ = asyncio.run(i.async_conduct_interview(run_config))
-    Traceback (most recent call last):
-    ...
-    asyncio.exceptions.CancelledError
     """
     from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment
 
@@ -262,6 +259,8 @@ class Interview:
         if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
             model_buckets = ModelBuckets.infinity_bucket()
 
+        self.skip_flags = {q.question_name: False for q in self.survey.questions}
+
         # was "self.tasks" - is that necessary?
         self.tasks = self.task_manager.build_question_tasks(
             answer_func=AnswerQuestionFunctionConstructor(
@@ -310,6 +309,10 @@ class Interview:
         def handle_task(task, invigilator):
             try:
                 result: Answers = task.result()
+                if result == "skipped":
+                    result = invigilator.get_failed_task_result(
+                        failure_reason="Task was skipped."
+                    )
             except asyncio.CancelledError as e:  # task was cancelled
                 result = invigilator.get_failed_task_result(
                     failure_reason="Task was cancelled."
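
Interview now seeds a per-question skip flag and treats the literal string "skipped" as a sentinel task result, converting it into a failed-task result with a reason. A self-contained sketch of that handling; the invigilator below is a stand-in, and only the sentinel value, the method name, and the flag dict mirror the diff:

```python
class StubInvigilator:
    # Stand-in for edsl's invigilator; only the method name matches the diff.
    def get_failed_task_result(self, failure_reason: str) -> dict:
        return {"answer": None, "comment": failure_reason}

def handle_task_result(result, invigilator):
    # A task that resolved to the "skipped" sentinel becomes a failed result.
    if result == "skipped":
        result = invigilator.get_failed_task_result(failure_reason="Task was skipped.")
    return result

question_names = ["q0", "q1"]
skip_flags = {name: False for name in question_names}  # one flag per question
print(handle_task_result("skipped", StubInvigilator()))
# {'answer': None, 'comment': 'Task was skipped.'}
```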
edsl/jobs/interviews/InterviewExceptionEntry.py CHANGED

@@ -166,6 +166,9 @@ class InterviewExceptionEntry:
         >>> entry = InterviewExceptionEntry.example()
         >>> _ = entry.to_dict()
         """
+        import json
+        from edsl.exceptions.questions import QuestionAnswerValidationError
+
         invigilator = (
             self.invigilator.to_dict() if self.invigilator is not None else None
         )
@@ -174,7 +177,16 @@
             "time": self.time,
             "traceback": self.traceback,
             "invigilator": invigilator,
+            "additional_data": {},
         }
+
+        if isinstance(self.exception, QuestionAnswerValidationError):
+            d["additional_data"]["edsl_response"] = json.dumps(self.exception.data)
+            d["additional_data"]["validating_model"] = json.dumps(
+                self.exception.model.model_json_schema()
+            )
+            d["additional_data"]["error_message"] = str(self.exception.message)
+
         return d
 
     @classmethod
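
For answer-validation failures, the serialized entry now records the offending response, the JSON schema of the validating pydantic model, and the error message under additional_data. A sketch of what gets captured; the key names match the diff, while AnswerModel and the message string are illustrative stand-ins for the exception's .model and .message attributes:

```python
import json
from pydantic import BaseModel

class AnswerModel(BaseModel):
    # Stand-in for the validating model carried by the exception.
    answer: str

bad_response = {"answer": 42}  # the response that failed validation
additional_data = {
    "edsl_response": json.dumps(bad_response),
    "validating_model": json.dumps(AnswerModel.model_json_schema()),
    "error_message": "answer must be a string",  # illustrative
}
print(sorted(additional_data))
# ['edsl_response', 'error_message', 'validating_model']
```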
edsl/jobs/tasks/TaskHistory.py CHANGED

@@ -419,7 +419,7 @@ class TaskHistory(RepresentationMixin):
         filename: Optional[str] = None,
         return_link=False,
         css=None,
-        cta="
+        cta="<br><span style='font-size: 18px; font-weight: medium-bold; text-decoration: underline;'>Click to open the report in a new tab</span><br><br>",
         open_in_browser=False,
     ):
         """Return an HTML report."""
edsl/language_models/LanguageModel.py CHANGED

@@ -379,8 +379,10 @@ class LanguageModel(
         cached_response, cache_key = cache.fetch(**cache_call_params)
 
         if cache_used := cached_response is not None:
+            # print("cache used")
             response = json.loads(cached_response)
         else:
+            # print("cache not used")
             f = (
                 self.remote_async_execute_model_call
                 if hasattr(self, "remote") and self.remote
@@ -394,14 +396,16 @@ class LanguageModel(
             from edsl.config import CONFIG
 
             TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
             response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
             new_cache_key = cache.store(
                 **cache_call_params, response=response
             )  # store the response in the cache
             assert new_cache_key == cache_key  # should be the same
 
+        #breakpoint()
+
         cost = self.cost(response)
+        #breakpoint()
         return ModelResponse(
             response=response,
             cache_used=cache_used,
@@ -466,6 +470,7 @@ class LanguageModel(
             model_outputs=model_outputs,
             edsl_dict=edsl_dict,
         )
+        #breakpoint()
         return agent_response_dict
 
     get_response = sync_wrapper(async_get_response)
@@ -518,8 +523,6 @@ class LanguageModel(
         """
        from edsl.language_models.model import get_model_class
 
-        # breakpoint()
-
         model_class = get_model_class(
             data["model"], service_name=data.get("inference_service", None)
         )
edsl/language_models/PriceManager.py CHANGED

@@ -30,19 +30,22 @@ class PriceManager:
         except Exception as e:
             print(f"Error fetching prices: {str(e)}")
 
-    def get_price(self, inference_service: str, model: str) ->
+    def get_price(self, inference_service: str, model: str) -> Dict:
         """
         Get the price information for a specific service and model combination.
+        If no specific price is found, returns a fallback price.
 
         Args:
             inference_service (str): The name of the inference service
             model (str): The model identifier
 
         Returns:
-
+            Dict: Price information (either actual or fallback prices)
         """
         key = (inference_service, model)
-        return self._price_lookup.get(key)
+        return self._price_lookup.get(key) or self._get_fallback_price(
+            inference_service
+        )
 
     def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
         """
@@ -53,6 +56,45 @@ class PriceManager:
         """
         return self._price_lookup.copy()
 
+    def _get_fallback_price(self, inference_service: str) -> Dict:
+        """
+        Get fallback prices for a service.
+        - First fallback: The highest input and output prices for that service from the price lookup.
+        - Second fallback: $1.00 per million tokens (for both input and output).
+
+        Args:
+            inference_service (str): The inference service name
+
+        Returns:
+            Dict: Price information
+        """
+        service_prices = [
+            prices
+            for (service, _), prices in self._price_lookup.items()
+            if service == inference_service
+        ]
+
+        input_tokens_per_usd = [
+            float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
+        ]
+        if input_tokens_per_usd:
+            min_input_tokens = min(input_tokens_per_usd)
+        else:
+            min_input_tokens = 1_000_000
+
+        output_tokens_per_usd = [
+            float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
+        ]
+        if output_tokens_per_usd:
+            min_output_tokens = min(output_tokens_per_usd)
+        else:
+            min_output_tokens = 1_000_000
+
+        return {
+            "input": {"one_usd_buys": min_input_tokens},
+            "output": {"one_usd_buys": min_output_tokens},
+        }
+
     def calculate_cost(
         self,
         inference_service: str,
@@ -75,8 +117,6 @@ class PriceManager:
             Union[float, str]: Total cost if calculation successful, error message string if not
         """
         relevant_prices = self.get_price(inference_service, model)
-        if relevant_prices is None:
-            return f"Could not find price for model {model} in the price lookup."
 
         # Extract token counts
         try:
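
Since prices are stored as tokens-per-USD (one_usd_buys), taking the minimum across a service's known models selects that service's highest price, which is the first fallback; a service with no entries at all falls back to 1,000,000 tokens per USD, i.e. $1.00 per million tokens. A standalone sketch of the same logic with sample data:

```python
# Sample price lookup keyed by (service, model), as in PriceManager.
price_lookup = {
    ("openai", "model-a"): {"input": {"one_usd_buys": 2_000_000},
                            "output": {"one_usd_buys": 500_000}},
    ("openai", "model-b"): {"input": {"one_usd_buys": 4_000_000},
                            "output": {"one_usd_buys": 1_500_000}},
}

def fallback_price(service: str) -> dict:
    service_prices = [p for (s, _), p in price_lookup.items() if s == service]

    def floor(kind: str) -> float:
        # Fewest tokens per USD == most expensive known rate; else $1.00/M tokens.
        vals = [float(p[kind]["one_usd_buys"]) for p in service_prices if kind in p]
        return min(vals) if vals else 1_000_000

    return {"input": {"one_usd_buys": floor("input")},
            "output": {"one_usd_buys": floor("output")}}

print(fallback_price("openai"))   # {'input': {'one_usd_buys': 2000000.0}, 'output': {'one_usd_buys': 500000.0}}
print(fallback_price("unknown"))  # {'input': {'one_usd_buys': 1000000}, 'output': {'one_usd_buys': 1000000}}
```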
edsl/language_models/model.py CHANGED

@@ -17,7 +17,11 @@ if TYPE_CHECKING:
     from edsl.results.Dataset import Dataset
 
 
-def get_model_class(
+def get_model_class(
+    model_name,
+    registry: Optional[InferenceServicesCollection] = None,
+    service_name: Optional[InferenceServiceLiteral] = None,
+):
     from edsl.inference_services.registry import default
 
     registry = registry or default
@@ -40,6 +44,9 @@ class Meta(type):
         To get the default model, you can leave out the model name.
         To see the available models, you can do:
         >>> Model.available()
+
+        Or to see the models for a specific service, you can do:
+        >>> Model.available(service='openai')
         """
     )
 
@@ -97,7 +104,10 @@ class Model(metaclass=Meta):
         *args,
         **kwargs,
     ):
-        "Instantiate a new language model.
+        """Instantiate a new language model.
+        >>> Model()
+        Model(...)
+        """
         # Map index to the respective subclass
         if model_name is None:
             model_name = cls.default_model
@@ -127,28 +137,25 @@ class Model(metaclass=Meta):
         >>> Model.service_classes()
         [...]
         """
-        return [r for r in cls.services(
+        return [r for r in cls.services()]
 
     @classmethod
     def services(cls, name_only: bool = False) -> List[str]:
-        """Returns a list of services
-
-
-
-
-
-
-            [r._inference_service_ for r in cls.get_registry().services],
-            columns=["Service Name"],
-        )
-        else:
-            return PrettyList(
+        """Returns a list of services excluding 'test', sorted alphabetically.
+
+        >>> Model.services()
+        [...]
+        """
+        return PrettyList(
+            sorted(
                 [
-
+                    [r._inference_service_]
                     for r in cls.get_registry().services
-
-
-                )
+                    if r._inference_service_.lower() != "test"
+                ]
+            ),
+            columns=["Service Name"],
+        )
 
     @classmethod
     def services_with_local_keys(cls) -> set:
@@ -198,7 +205,15 @@ class Model(metaclass=Meta):
         search_term: str = None,
         name_only: bool = False,
         service: Optional[str] = None,
+        force_refresh: bool = False,
     ):
+        """Get available models
+
+        >>> Model.available()
+        [...]
+        >>> Model.available(service='openai')
+        [...]
+        """
         # if search_term is None and service is None:
         #     print("Getting available models...")
         #     print("You have local keys for the following services:")
@@ -209,13 +224,16 @@ class Model(metaclass=Meta):
         #     return None
 
         if service is not None:
-
+            known_services = [x[0] for x in cls.services(name_only=True)]
+            if service not in known_services:
                 raise ValueError(
                     f"Service {service} not found in available services.",
-                    f"Available services are: {
+                    f"Available services are: {known_services}",
                 )
 
-        full_list = cls.get_registry().available(
+        full_list = cls.get_registry().available(
+            service=service, force_refresh=force_refresh
+        )
 
         if search_term is None:
             if name_only:
@@ -319,6 +337,9 @@ class Model(metaclass=Meta):
         """
         Returns an example Model instance.
 
+        >>> Model.example()
+        Model(...)
+
         :param randomize: If True, the temperature is set to a random decimal between 0 and 1.
         """
         temperature = 0.5 if not randomize else round(random(), 2)
@@ -331,7 +352,7 @@ if __name__ == "__main__":
 
     doctest.testmod(optionflags=doctest.ELLIPSIS)
 
-    available = Model.available()
-    m = Model("gpt-4-1106-preview")
-    results = m.execute_model_call("Hello world")
-    print(results)
+    # available = Model.available()
+    # m = Model("gpt-4-1106-preview")
+    # results = m.execute_model_call("Hello world")
+    # print(results)
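
Taken together, the model.py changes extend model discovery: services() now excludes the internal 'test' service and sorts the rest, while available() gains a service filter and a force_refresh flag that is forwarded to the registry (presumably to bypass a cached model listing). Usage per the new doctests; these calls query live service registries:

```python
from edsl import Model

Model.services()                                       # service names, 'test' excluded, sorted
Model.available()                                      # every known model
Model.available(service="openai")                      # models for one service
Model.available(service="openai", force_refresh=True)  # re-fetch instead of using a cached listing
```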
edsl/questions/QuestionBase.py CHANGED

@@ -85,6 +85,9 @@ class QuestionBase(
         >>> Q.example()._simulate_answer()
         {'answer': '...', 'generated_tokens': ...}
         """
+        if self.question_type == "free_text":
+            return {"answer": "Hello, how are you?", 'generated_tokens': "Hello, how are you?"}
+
         simulated_answer = self.fake_data_factory.build().dict()
         if human_readable and hasattr(self, "question_options") and self.use_code:
             simulated_answer["answer"] = [
@@ -432,6 +435,24 @@
 
         return Survey([self])
 
+    def humanize(
+        self,
+        project_name: str = "Project",
+        survey_description: Optional[str] = None,
+        survey_alias: Optional[str] = None,
+        survey_visibility: Optional["VisibilityType"] = "unlisted",
+    ) -> dict:
+        """
+        Turn a single question into a survey and send the survey to Coop.
+
+        Then, create a project on Coop so you can share the survey with human respondents.
+        """
+        s = self.to_survey()
+        project_details = s.humanize(
+            project_name, survey_description, survey_alias, survey_visibility
+        )
+        return project_details
+
     def by(self, *args) -> "Jobs":
         """Turn a single question into a survey and then a Job."""
         from edsl.surveys.Survey import Survey
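
The new humanize() composes two existing pieces: to_survey() wraps the question in a one-question Survey, and Survey.humanize() creates the Coop project. A usage sketch; it requires a configured Expected Parrot account, and the returned dict of project details comes from Coop:

```python
from edsl import QuestionFreeText

q = QuestionFreeText(
    question_name="feedback",
    question_text="What did you think of the product?",
)
# Wraps the question in a survey and creates a shareable Coop project.
details = q.humanize(project_name="Product feedback")
print(details)  # project details returned by Coop
```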
edsl/questions/QuestionBasePromptsMixin.py CHANGED

@@ -187,6 +187,73 @@ class QuestionBasePromptsMixin:
         from edsl.prompts import Prompt
 
         return Prompt(self.question_presentation) + Prompt(self.answering_instructions)
+
+
+    def detailed_parameters_by_key(self) -> dict[str, set[tuple[str, ...]]]:
+        """
+        Return a dictionary of parameters by key.
+
+        >>> from edsl import QuestionMultipleChoice
+        >>> QuestionMultipleChoice.example().detailed_parameters_by_key()
+        {'question_name': set(), 'question_text': set()}
+
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_name = "example", question_text = "What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> r = q.detailed_parameters_by_key()
+        >>> r == {'question_name': set(), 'question_text': {('q0', 'answer'), ('nickname',)}}
+        True
+        """
+        params_by_key = {}
+        for key, value in self.data.items():
+            if isinstance(value, str):
+                params_by_key[key] = self.extract_parameters(value)
+        return params_by_key
+
+    @staticmethod
+    def extract_parameters(txt: str) -> set[tuple[str, ...]]:
+        """Return all parameters of the question as tuples representing their full paths.
+
+        :param txt: The text to extract parameters from.
+        :return: A set of tuples representing the parameters.
+
+        >>> from edsl.questions import QuestionMultipleChoice
+        >>> d = QuestionMultipleChoice.example().extract_parameters("What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> d == {('nickname',), ('q0', 'answer')}
+        True
+        """
+        from jinja2 import Environment, nodes
+
+        env = Environment()
+        #txt = self._all_text()
+        ast = env.parse(txt)
+
+        variables = set()
+        processed_nodes = set()  # Keep track of nodes we've processed
+
+        def visit_node(node, path=()):
+            if id(node) in processed_nodes:
+                return
+            processed_nodes.add(id(node))
+
+            if isinstance(node, nodes.Name):
+                # Only add the name if we're not in the middle of building a longer path
+                if not path:
+                    variables.add((node.name,))
+                else:
+                    variables.add((node.name,) + path)
+            elif isinstance(node, nodes.Getattr):
+                # Build path from bottom up
+                new_path = (node.attr,) + path
+                visit_node(node.node, new_path)
+
+        for node in ast.find_all((nodes.Name, nodes.Getattr)):
+            visit_node(node)
+
+        return variables
+
+    @property
+    def detailed_parameters(self):
+        return [".".join(p) for p in self.extract_parameters(self._all_text())]
 
     @property
     def parameters(self) -> set[str]:
@@ -219,3 +286,39 @@ class QuestionBasePromptsMixin:
             return self.new_default_instructions
         else:
             return self.applicable_prompts(model)[0]()
+
+    @staticmethod
+    def sequence_in_dict(d: dict, path: tuple[str, ...]) -> tuple[bool, any]:
+        """Check if a sequence of nested keys exists in a dictionary and return the value.
+
+        Args:
+            d: The dictionary to check
+            path: Tuple of keys representing the nested path
+
+        Returns:
+            tuple[bool, any]: (True, value) if the path exists, (False, None) otherwise
+
+        Example:
+            >>> sequence_in_dict = QuestionBasePromptsMixin.sequence_in_dict
+            >>> d = {'a': {'b': {'c': 1}}}
+            >>> sequence_in_dict(d, ('a', 'b', 'c'))
+            (True, 1)
+            >>> sequence_in_dict(d, ('a', 'b', 'd'))
+            (False, None)
+            >>> sequence_in_dict(d, ('x',))
+            (False, None)
+        """
+        try:
+            current = d
+            for key in path:
+                current = current.get(key)
+                if current is None:
+                    return (False, None)
+            return (True, current)
+        except (AttributeError, TypeError):
+            return (False, None)
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
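
The core of extract_parameters is a walk over Jinja2's parse tree: a bare {{ name }} yields a one-element tuple, while attribute chains like {{ q0.answer }} are folded bottom-up into full paths. A self-contained demonstration of the same walk:

```python
from jinja2 import Environment, nodes

env = Environment()
ast = env.parse("What is your name, {{ nickname }}, based on {{ q0.answer }}?")

variables, seen = set(), set()

def visit(node, path=()):
    if id(node) in seen:  # each node contributes at most once
        return
    seen.add(id(node))
    if isinstance(node, nodes.Name):
        variables.add((node.name,) + path)  # bare name, or root of an attribute chain
    elif isinstance(node, nodes.Getattr):
        visit(node.node, (node.attr,) + path)  # fold attributes bottom-up

for n in ast.find_all((nodes.Name, nodes.Getattr)):
    visit(n)

print(variables)  # {('nickname',), ('q0', 'answer')}
```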
edsl/questions/QuestionFreeText.py CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Any, Optional
 from uuid import uuid4
 
-from pydantic import field_validator
+from pydantic import field_validator, model_validator
 
 from edsl.questions.QuestionBase import QuestionBase
 from edsl.questions.response_validator_abc import ResponseValidatorABC
@@ -24,6 +24,17 @@ class FreeTextResponse(BaseModel):
     answer: str
     generated_tokens: Optional[str] = None
 
+    @model_validator(mode='after')
+    def validate_tokens_match_answer(self):
+        if self.generated_tokens is not None:  # If generated_tokens exists
+            # Ensure exact string equality
+            if self.answer.strip() != self.generated_tokens.strip():  # They MUST match exactly
+                raise ValueError(
+                    f"answer '{self.answer}' must exactly match generated_tokens '{self.generated_tokens}'. "
+                    f"Type of answer: {type(self.answer)}, Type of tokens: {type(self.generated_tokens)}"
+                )
+        return self
+
 
 class FreeTextResponseValidator(ResponseValidatorABC):
     required_params = []
@@ -37,10 +48,16 @@ class FreeTextResponseValidator(ResponseValidatorABC):
     ]
 
     def fix(self, response, verbose=False):
-
-
-
-
+        if response.get("generated_tokens") != response.get("answer"):
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }
+        else:
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }
 
 
 class QuestionFreeText(QuestionBase):
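
The new model validator makes answer/generated_tokens mismatches fail pydantic validation, and the rewritten fix() then repairs such responses by copying generated_tokens into answer. A minimal reproduction of the invariant:

```python
from typing import Optional
from pydantic import BaseModel, ValidationError, model_validator

class FreeTextResponse(BaseModel):
    answer: str
    generated_tokens: Optional[str] = None

    @model_validator(mode="after")
    def validate_tokens_match_answer(self):
        # When tokens are present they must equal the answer (modulo whitespace).
        if self.generated_tokens is not None:
            if self.answer.strip() != self.generated_tokens.strip():
                raise ValueError("answer must exactly match generated_tokens")
        return self

FreeTextResponse(answer="hi", generated_tokens="hi")  # validates
try:
    FreeTextResponse(answer="hi", generated_tokens="bye")
except ValidationError:
    print("mismatch rejected")
```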
edsl/questions/descriptors.py CHANGED

@@ -2,6 +2,7 @@
 
 from abc import ABC, abstractmethod
 import re
+import textwrap
 from typing import Any, Callable, List, Optional
 from edsl.exceptions.questions import (
     QuestionCreationValidationError,
@@ -404,6 +405,9 @@ class QuestionTextDescriptor(BaseDescriptor):
             raise Exception("Question is too short!")
         if not isinstance(value, str):
             raise Exception("Question must be a string!")
+
+        #value = textwrap.dedent(value).strip()
+
         if contains_single_braced_substring(value):
             import warnings
 
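
The descriptors change only stages a possible normalization: textwrap is imported and the dedent call is left commented out. For reference, this is what the disabled line would do to an indented triple-quoted question text:

```python
import textwrap

value = """
    What is your favorite color?
    Please answer in one word.
"""
# Strip the common leading indentation, then the surrounding blank lines.
print(textwrap.dedent(value).strip())
# What is your favorite color?
# Please answer in one word.
```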