edsl 0.1.36.dev5__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +92 -41
- edsl/agents/AgentList.py +15 -2
- edsl/agents/InvigilatorBase.py +15 -25
- edsl/agents/PromptConstructor.py +149 -108
- edsl/agents/descriptors.py +17 -4
- edsl/conjure/AgentConstructionMixin.py +11 -3
- edsl/conversation/Conversation.py +66 -14
- edsl/conversation/chips.py +95 -0
- edsl/coop/coop.py +148 -39
- edsl/data/Cache.py +1 -1
- edsl/data/RemoteCacheSync.py +25 -12
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +7 -3
- edsl/exceptions/agents.py +17 -19
- edsl/exceptions/results.py +11 -8
- edsl/exceptions/scenarios.py +22 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AwsBedrock.py +7 -2
- edsl/inference_services/InferenceServicesCollection.py +42 -13
- edsl/inference_services/models_available_cache.py +25 -1
- edsl/jobs/Jobs.py +306 -71
- edsl/jobs/interviews/Interview.py +24 -14
- edsl/jobs/interviews/InterviewExceptionCollection.py +1 -1
- edsl/jobs/interviews/InterviewExceptionEntry.py +17 -13
- edsl/jobs/interviews/ReportErrors.py +2 -2
- edsl/jobs/runners/JobsRunnerAsyncio.py +10 -9
- edsl/jobs/tasks/TaskHistory.py +1 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +47 -59
- edsl/language_models/__init__.py +1 -0
- edsl/prompts/Prompt.py +11 -12
- edsl/questions/QuestionBase.py +53 -13
- edsl/questions/QuestionBasePromptsMixin.py +1 -33
- edsl/questions/QuestionFreeText.py +1 -0
- edsl/questions/QuestionFunctional.py +2 -2
- edsl/questions/descriptors.py +23 -28
- edsl/results/DatasetExportMixin.py +25 -1
- edsl/results/Result.py +27 -10
- edsl/results/Results.py +34 -121
- edsl/results/ResultsDBMixin.py +1 -1
- edsl/results/Selector.py +18 -1
- edsl/scenarios/FileStore.py +20 -5
- edsl/scenarios/Scenario.py +52 -13
- edsl/scenarios/ScenarioHtmlMixin.py +7 -2
- edsl/scenarios/ScenarioList.py +12 -1
- edsl/scenarios/__init__.py +2 -0
- edsl/surveys/Rule.py +10 -4
- edsl/surveys/Survey.py +100 -77
- edsl/utilities/utilities.py +18 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/METADATA +1 -1
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/RECORD +55 -51
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/LICENSE +0 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/WHEEL +0 -0
@@ -67,7 +67,11 @@ class InterviewExceptionEntry:
|
|
67
67
|
m = LanguageModel.example(test_model=True)
|
68
68
|
q = QuestionFreeText.example(exception_to_throw=ValueError)
|
69
69
|
results = q.by(m).run(
|
70
|
-
skip_retry=True,
|
70
|
+
skip_retry=True,
|
71
|
+
print_exceptions=False,
|
72
|
+
raise_validation_errors=True,
|
73
|
+
disable_remote_cache=True,
|
74
|
+
disable_remote_inference=True,
|
71
75
|
)
|
72
76
|
return results.task_history.exceptions[0]["how_are_you"][0]
|
73
77
|
|
@@ -132,18 +136,25 @@ class InterviewExceptionEntry:
|
|
132
136
|
)
|
133
137
|
console.print(tb)
|
134
138
|
return html_output.getvalue()
|
135
|
-
|
139
|
+
|
136
140
|
@staticmethod
|
137
141
|
def serialize_exception(exception: Exception) -> dict:
|
138
142
|
return {
|
139
143
|
"type": type(exception).__name__,
|
140
144
|
"message": str(exception),
|
141
|
-
"traceback": "".join(
|
145
|
+
"traceback": "".join(
|
146
|
+
traceback.format_exception(
|
147
|
+
type(exception), exception, exception.__traceback__
|
148
|
+
)
|
149
|
+
),
|
142
150
|
}
|
143
|
-
|
151
|
+
|
144
152
|
@staticmethod
|
145
153
|
def deserialize_exception(data: dict) -> Exception:
|
146
|
-
|
154
|
+
try:
|
155
|
+
exception_class = globals()[data["type"]]
|
156
|
+
except KeyError:
|
157
|
+
exception_class = Exception
|
147
158
|
return exception_class(data["message"])
|
148
159
|
|
149
160
|
def to_dict(self) -> dict:
|
@@ -158,7 +169,7 @@ class InterviewExceptionEntry:
|
|
158
169
|
"traceback": self.traceback,
|
159
170
|
"invigilator": self.invigilator.to_dict(),
|
160
171
|
}
|
161
|
-
|
172
|
+
|
162
173
|
@classmethod
|
163
174
|
def from_dict(cls, data: dict) -> "InterviewExceptionEntry":
|
164
175
|
"""Create an InterviewExceptionEntry from a dictionary."""
|
@@ -168,13 +179,6 @@ class InterviewExceptionEntry:
|
|
168
179
|
invigilator = InvigilatorAI.from_dict(data["invigilator"])
|
169
180
|
return cls(exception=exception, invigilator=invigilator)
|
170
181
|
|
171
|
-
def push(self):
|
172
|
-
from edsl import Coop
|
173
|
-
|
174
|
-
coop = Coop()
|
175
|
-
results = coop.error_create(self.to_dict())
|
176
|
-
return results
|
177
|
-
|
178
182
|
|
179
183
|
if __name__ == "__main__":
|
180
184
|
import doctest
|
@@ -36,8 +36,8 @@ class ReportErrors:
|
|
36
36
|
print("No input received within the timeout period.")
|
37
37
|
|
38
38
|
def upload(self):
|
39
|
-
|
40
|
-
|
39
|
+
# The previous implementation was removed because it relied on the old Coop ErrorModel
|
40
|
+
pass
|
41
41
|
|
42
42
|
|
43
43
|
def main():
|
@@ -19,6 +19,7 @@ from edsl.results.Results import Results
|
|
19
19
|
from edsl.language_models.LanguageModel import LanguageModel
|
20
20
|
from edsl.data.Cache import Cache
|
21
21
|
|
22
|
+
|
22
23
|
class StatusTracker(UserList):
|
23
24
|
def __init__(self, total_tasks: int):
|
24
25
|
self.total_tasks = total_tasks
|
@@ -164,20 +165,20 @@ class JobsRunnerAsyncio:
|
|
164
165
|
|
165
166
|
prompt_dictionary = {}
|
166
167
|
for answer_key_name in answer_key_names:
|
167
|
-
prompt_dictionary[
|
168
|
-
|
169
|
-
|
170
|
-
prompt_dictionary[
|
171
|
-
|
172
|
-
|
168
|
+
prompt_dictionary[
|
169
|
+
answer_key_name + "_user_prompt"
|
170
|
+
] = question_name_to_prompts[answer_key_name]["user_prompt"]
|
171
|
+
prompt_dictionary[
|
172
|
+
answer_key_name + "_system_prompt"
|
173
|
+
] = question_name_to_prompts[answer_key_name]["system_prompt"]
|
173
174
|
|
174
175
|
raw_model_results_dictionary = {}
|
175
176
|
cache_used_dictionary = {}
|
176
177
|
for result in valid_results:
|
177
178
|
question_name = result.question_name
|
178
|
-
raw_model_results_dictionary[
|
179
|
-
|
180
|
-
|
179
|
+
raw_model_results_dictionary[
|
180
|
+
question_name + "_raw_model_response"
|
181
|
+
] = result.raw_model_response
|
181
182
|
raw_model_results_dictionary[question_name + "_cost"] = result.cost
|
182
183
|
one_use_buys = (
|
183
184
|
"NA"
|
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -0,0 +1,30 @@
|
|
1
|
+
import os
|
2
|
+
from collections import UserDict
|
3
|
+
|
4
|
+
from edsl.enums import service_to_api_keyname
|
5
|
+
from edsl.exceptions import MissingAPIKeyError
|
6
|
+
|
7
|
+
|
8
|
+
class KeyLookup(UserDict):
|
9
|
+
@classmethod
|
10
|
+
def from_os_environ(cls):
|
11
|
+
"""Create an instance of KeyLookupAPI with keys from os.environ"""
|
12
|
+
return cls({key: value for key, value in os.environ.items()})
|
13
|
+
|
14
|
+
def get_api_token(self, service: str, remote: bool = False):
|
15
|
+
key_name = service_to_api_keyname.get(service, "NOT FOUND")
|
16
|
+
|
17
|
+
if service == "bedrock":
|
18
|
+
api_token = [self.get(key_name[0]), self.get(key_name[1])]
|
19
|
+
missing_token = any(token is None for token in api_token)
|
20
|
+
else:
|
21
|
+
api_token = self.get(key_name)
|
22
|
+
missing_token = api_token is None
|
23
|
+
|
24
|
+
if missing_token and service != "test" and not remote:
|
25
|
+
raise MissingAPIKeyError(
|
26
|
+
f"""The key for service: `{service}` is not set.
|
27
|
+
Need a key with name {key_name} in your .env file."""
|
28
|
+
)
|
29
|
+
|
30
|
+
return api_token
|
@@ -17,9 +17,7 @@ import warnings
|
|
17
17
|
from functools import wraps
|
18
18
|
import asyncio
|
19
19
|
import json
|
20
|
-
import time
|
21
20
|
import os
|
22
|
-
import hashlib
|
23
21
|
from typing import (
|
24
22
|
Coroutine,
|
25
23
|
Any,
|
@@ -30,6 +28,7 @@ from typing import (
|
|
30
28
|
get_type_hints,
|
31
29
|
TypedDict,
|
32
30
|
Optional,
|
31
|
+
TYPE_CHECKING,
|
33
32
|
)
|
34
33
|
from abc import ABC, abstractmethod
|
35
34
|
|
@@ -49,34 +48,16 @@ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
|
49
48
|
from edsl.language_models.repair import repair
|
50
49
|
from edsl.enums import InferenceServiceType
|
51
50
|
from edsl.Base import RichPrintingMixin, PersistenceMixin
|
52
|
-
from edsl.enums import service_to_api_keyname
|
53
|
-
from edsl.exceptions import MissingAPIKeyError
|
54
51
|
from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
|
55
52
|
from edsl.exceptions.language_models import LanguageModelBadResponseError
|
56
53
|
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
def convert_answer(response_part):
|
61
|
-
import json
|
62
|
-
|
63
|
-
response_part = response_part.strip()
|
64
|
-
|
65
|
-
if response_part == "None":
|
66
|
-
return None
|
67
|
-
|
68
|
-
repaired = repair_json(response_part)
|
69
|
-
if repaired == '""':
|
70
|
-
# it was a literal string
|
71
|
-
return response_part
|
54
|
+
from edsl.language_models.KeyLookup import KeyLookup
|
72
55
|
|
73
|
-
|
74
|
-
return json.loads(repaired)
|
75
|
-
except json.JSONDecodeError as j:
|
76
|
-
# last resort
|
77
|
-
return response_part
|
56
|
+
TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
|
78
57
|
|
79
58
|
|
59
|
+
# you might be tempated to move this to be a static method of LanguageModel, but this doesn't work
|
60
|
+
# for reasons I don't understand. So leave it here.
|
80
61
|
def extract_item_from_raw_response(data, key_sequence):
|
81
62
|
if isinstance(data, str):
|
82
63
|
try:
|
@@ -167,7 +148,12 @@ class LanguageModel(
|
|
167
148
|
_safety_factor = 0.8
|
168
149
|
|
169
150
|
def __init__(
|
170
|
-
self,
|
151
|
+
self,
|
152
|
+
tpm: float = None,
|
153
|
+
rpm: float = None,
|
154
|
+
omit_system_prompt_if_empty_string: bool = True,
|
155
|
+
key_lookup: Optional[KeyLookup] = None,
|
156
|
+
**kwargs,
|
171
157
|
):
|
172
158
|
"""Initialize the LanguageModel."""
|
173
159
|
self.model = getattr(self, "_model_", None)
|
@@ -200,29 +186,26 @@ class LanguageModel(
|
|
200
186
|
# Skip the API key check. Sometimes this is useful for testing.
|
201
187
|
self._api_token = None
|
202
188
|
|
189
|
+
if key_lookup is not None:
|
190
|
+
self.key_lookup = key_lookup
|
191
|
+
else:
|
192
|
+
self.key_lookup = KeyLookup.from_os_environ()
|
193
|
+
|
203
194
|
def ask_question(self, question):
|
204
195
|
user_prompt = question.get_instructions().render(question.data).text
|
205
196
|
system_prompt = "You are a helpful agent pretending to be a human."
|
206
197
|
return self.execute_model_call(user_prompt, system_prompt)
|
207
198
|
|
199
|
+
def set_key_lookup(self, key_lookup: KeyLookup):
|
200
|
+
del self._api_token
|
201
|
+
self.key_lookup = key_lookup
|
202
|
+
|
208
203
|
@property
|
209
204
|
def api_token(self) -> str:
|
210
205
|
if not hasattr(self, "_api_token"):
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
# Check if any of the tokens are None
|
215
|
-
missing_token = any(token is None for token in self._api_token)
|
216
|
-
else:
|
217
|
-
self._api_token = os.getenv(key_name)
|
218
|
-
missing_token = self._api_token is None
|
219
|
-
if missing_token and self._inference_service_ != "test" and not self.remote:
|
220
|
-
print("raising error")
|
221
|
-
raise MissingAPIKeyError(
|
222
|
-
f"""The key for service: `{self._inference_service_}` is not set.
|
223
|
-
Need a key with name {key_name} in your .env file."""
|
224
|
-
)
|
225
|
-
|
206
|
+
self._api_token = self.key_lookup.get_api_token(
|
207
|
+
self._inference_service_, self.remote
|
208
|
+
)
|
226
209
|
return self._api_token
|
227
210
|
|
228
211
|
def __getitem__(self, key):
|
@@ -291,21 +274,6 @@ class LanguageModel(
|
|
291
274
|
if tpm is not None:
|
292
275
|
self._tpm = tpm
|
293
276
|
return None
|
294
|
-
# self._set_rate_limits(rpm=rpm, tpm=tpm)
|
295
|
-
|
296
|
-
# def _set_rate_limits(self, rpm=None, tpm=None) -> None:
|
297
|
-
# """Set the rate limits for the model.
|
298
|
-
|
299
|
-
# If the model does not have rate limits, use the default rate limits."""
|
300
|
-
# if rpm is not None and tpm is not None:
|
301
|
-
# self.__rate_limits = {"rpm": rpm, "tpm": tpm}
|
302
|
-
# return
|
303
|
-
|
304
|
-
# if self.__rate_limits is None:
|
305
|
-
# if hasattr(self, "get_rate_limits"):
|
306
|
-
# self.__rate_limits = self.get_rate_limits()
|
307
|
-
# else:
|
308
|
-
# self.__rate_limits = self.__default_rate_limits
|
309
277
|
|
310
278
|
@property
|
311
279
|
def RPM(self):
|
@@ -416,6 +384,26 @@ class LanguageModel(
|
|
416
384
|
)
|
417
385
|
return extract_item_from_raw_response(raw_response, cls.usage_sequence)
|
418
386
|
|
387
|
+
@staticmethod
|
388
|
+
def convert_answer(response_part):
|
389
|
+
import json
|
390
|
+
|
391
|
+
response_part = response_part.strip()
|
392
|
+
|
393
|
+
if response_part == "None":
|
394
|
+
return None
|
395
|
+
|
396
|
+
repaired = repair_json(response_part)
|
397
|
+
if repaired == '""':
|
398
|
+
# it was a literal string
|
399
|
+
return response_part
|
400
|
+
|
401
|
+
try:
|
402
|
+
return json.loads(repaired)
|
403
|
+
except json.JSONDecodeError as j:
|
404
|
+
# last resort
|
405
|
+
return response_part
|
406
|
+
|
419
407
|
@classmethod
|
420
408
|
def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
|
421
409
|
"""Parses the API response and returns the response text."""
|
@@ -425,13 +413,13 @@ class LanguageModel(
|
|
425
413
|
if last_newline == -1:
|
426
414
|
# There is no comment
|
427
415
|
edsl_dict = {
|
428
|
-
"answer": convert_answer(generated_token_string),
|
416
|
+
"answer": cls.convert_answer(generated_token_string),
|
429
417
|
"generated_tokens": generated_token_string,
|
430
418
|
"comment": None,
|
431
419
|
}
|
432
420
|
else:
|
433
421
|
edsl_dict = {
|
434
|
-
"answer": convert_answer(generated_token_string[:last_newline]),
|
422
|
+
"answer": cls.convert_answer(generated_token_string[:last_newline]),
|
435
423
|
"comment": generated_token_string[last_newline + 1 :].strip(),
|
436
424
|
"generated_tokens": generated_token_string,
|
437
425
|
}
|
@@ -492,7 +480,7 @@ class LanguageModel(
|
|
492
480
|
params = {
|
493
481
|
"user_prompt": user_prompt,
|
494
482
|
"system_prompt": system_prompt,
|
495
|
-
"files_list": files_list
|
483
|
+
"files_list": files_list,
|
496
484
|
# **({"encoded_image": encoded_image} if encoded_image else {}),
|
497
485
|
}
|
498
486
|
# response = await f(**params)
|
@@ -699,7 +687,7 @@ class LanguageModel(
|
|
699
687
|
True
|
700
688
|
>>> from edsl import QuestionFreeText
|
701
689
|
>>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
|
702
|
-
>>> q.by(m).run(cache = False).select('example').first()
|
690
|
+
>>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
|
703
691
|
'WOWZA!'
|
704
692
|
"""
|
705
693
|
from edsl import Model
|
edsl/language_models/__init__.py
CHANGED
edsl/prompts/Prompt.py
CHANGED
@@ -17,14 +17,6 @@ class PreserveUndefined(Undefined):
|
|
17
17
|
|
18
18
|
|
19
19
|
from edsl.exceptions.prompts import TemplateRenderError
|
20
|
-
|
21
|
-
# from edsl.prompts.prompt_config import (
|
22
|
-
# C2A,
|
23
|
-
# names_to_component_types,
|
24
|
-
# ComponentTypes,
|
25
|
-
# NEGATIVE_INFINITY,
|
26
|
-
# )
|
27
|
-
# from edsl.prompts.registry import RegisterPromptsMeta
|
28
20
|
from edsl.Base import PersistenceMixin, RichPrintingMixin
|
29
21
|
|
30
22
|
MAX_NESTING = 100
|
@@ -60,6 +52,9 @@ class Prompt(PersistenceMixin, RichPrintingMixin):
|
|
60
52
|
text = self.default_instructions
|
61
53
|
else:
|
62
54
|
text = ""
|
55
|
+
if isinstance(text, Prompt):
|
56
|
+
# make it idempotent w/ a prompt
|
57
|
+
text = text.text
|
63
58
|
self._text = text
|
64
59
|
|
65
60
|
@classmethod
|
@@ -245,10 +240,14 @@ class Prompt(PersistenceMixin, RichPrintingMixin):
|
|
245
240
|
>>> p.render({"person": "Mr. {{last_name}}"})
|
246
241
|
Prompt(text=\"""Hello, Mr. {{ last_name }}\""")
|
247
242
|
"""
|
248
|
-
|
249
|
-
self.
|
250
|
-
|
251
|
-
|
243
|
+
try:
|
244
|
+
new_text = self._render(
|
245
|
+
self.text, primary_replacement, **additional_replacements
|
246
|
+
)
|
247
|
+
return self.__class__(text=new_text)
|
248
|
+
except Exception as e:
|
249
|
+
print(f"Error rendering prompt: {e}")
|
250
|
+
return self
|
252
251
|
|
253
252
|
@staticmethod
|
254
253
|
def _render(
|
edsl/questions/QuestionBase.py
CHANGED
@@ -150,14 +150,21 @@ class QuestionBase(
|
|
150
150
|
"_include_comment",
|
151
151
|
"_fake_data_factory",
|
152
152
|
"_use_code",
|
153
|
-
"_answering_instructions",
|
154
|
-
"_question_presentation",
|
155
153
|
"_model_instructions",
|
156
154
|
]
|
155
|
+
only_if_not_na_list = ["_answering_instructions", "_question_presentation"]
|
156
|
+
|
157
|
+
def ok(key, value):
|
158
|
+
if not key.startswith("_"):
|
159
|
+
return False
|
160
|
+
if key in exclude_list:
|
161
|
+
return False
|
162
|
+
if key in only_if_not_na_list and value is None:
|
163
|
+
return False
|
164
|
+
return True
|
165
|
+
|
157
166
|
candidate_data = {
|
158
|
-
k.replace("_", "", 1): v
|
159
|
-
for k, v in self.__dict__.items()
|
160
|
-
if k.startswith("_") and k not in exclude_list
|
167
|
+
k.replace("_", "", 1): v for k, v in self.__dict__.items() if ok(k, v)
|
161
168
|
}
|
162
169
|
|
163
170
|
if "func" in candidate_data:
|
@@ -176,7 +183,9 @@ class QuestionBase(
|
|
176
183
|
"""
|
177
184
|
candidate_data = self.data.copy()
|
178
185
|
candidate_data["question_type"] = self.question_type
|
179
|
-
return
|
186
|
+
return {
|
187
|
+
key: value for key, value in candidate_data.items() if value is not None
|
188
|
+
}
|
180
189
|
|
181
190
|
@add_edsl_version
|
182
191
|
def to_dict(self) -> dict[str, Any]:
|
@@ -239,6 +248,8 @@ class QuestionBase(
|
|
239
248
|
show_answer: bool = True,
|
240
249
|
model: Optional["LanguageModel"] = None,
|
241
250
|
cache=False,
|
251
|
+
disable_remote_cache: bool = False,
|
252
|
+
disable_remote_inference: bool = False,
|
242
253
|
**kwargs,
|
243
254
|
):
|
244
255
|
"""Run an example of the question.
|
@@ -247,7 +258,7 @@ class QuestionBase(
|
|
247
258
|
>>> m = Q._get_test_model(canned_response = "Yo, what's up?")
|
248
259
|
>>> m.execute_model_call("", "")
|
249
260
|
{'message': [{'text': "Yo, what's up?"}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
|
250
|
-
>>> Q.run_example(show_answer = True, model = m)
|
261
|
+
>>> Q.run_example(show_answer = True, model = m, disable_remote_cache = True, disable_remote_inference = True)
|
251
262
|
┏━━━━━━━━━━━━━━━━┓
|
252
263
|
┃ answer ┃
|
253
264
|
┃ .how_are_you ┃
|
@@ -259,25 +270,48 @@ class QuestionBase(
|
|
259
270
|
from edsl import Model
|
260
271
|
|
261
272
|
model = Model()
|
262
|
-
results =
|
273
|
+
results = (
|
274
|
+
cls.example(**kwargs)
|
275
|
+
.by(model)
|
276
|
+
.run(
|
277
|
+
cache=cache,
|
278
|
+
disable_remote_cache=disable_remote_cache,
|
279
|
+
disable_remote_inference=disable_remote_inference,
|
280
|
+
)
|
281
|
+
)
|
263
282
|
if show_answer:
|
264
283
|
results.select("answer.*").print()
|
265
284
|
else:
|
266
285
|
return results
|
267
286
|
|
268
|
-
def __call__(
|
287
|
+
def __call__(
|
288
|
+
self,
|
289
|
+
just_answer=True,
|
290
|
+
model=None,
|
291
|
+
agent=None,
|
292
|
+
disable_remote_cache: bool = False,
|
293
|
+
disable_remote_inference: bool = False,
|
294
|
+
**kwargs,
|
295
|
+
):
|
269
296
|
"""Call the question.
|
270
297
|
|
271
298
|
|
272
299
|
>>> from edsl import QuestionFreeText as Q
|
273
300
|
>>> m = Q._get_test_model(canned_response = "Yo, what's up?")
|
274
301
|
>>> q = Q(question_name = "color", question_text = "What is your favorite color?")
|
275
|
-
>>> q(model = m)
|
302
|
+
>>> q(model = m, disable_remote_cache = True, disable_remote_inference = True)
|
276
303
|
"Yo, what's up?"
|
277
304
|
|
278
305
|
"""
|
279
306
|
survey = self.to_survey()
|
280
|
-
results = survey(
|
307
|
+
results = survey(
|
308
|
+
model=model,
|
309
|
+
agent=agent,
|
310
|
+
**kwargs,
|
311
|
+
cache=False,
|
312
|
+
disable_remote_cache=disable_remote_cache,
|
313
|
+
disable_remote_inference=disable_remote_inference,
|
314
|
+
)
|
281
315
|
if just_answer:
|
282
316
|
return results.select(f"answer.{self.question_name}").first()
|
283
317
|
else:
|
@@ -295,6 +329,7 @@ class QuestionBase(
|
|
295
329
|
just_answer: bool = True,
|
296
330
|
model: Optional["Model"] = None,
|
297
331
|
agent: Optional["Agent"] = None,
|
332
|
+
disable_remote_inference: bool = False,
|
298
333
|
**kwargs,
|
299
334
|
) -> Union[Any, "Results"]:
|
300
335
|
"""Call the question asynchronously.
|
@@ -303,12 +338,17 @@ class QuestionBase(
|
|
303
338
|
>>> from edsl import QuestionFreeText as Q
|
304
339
|
>>> m = Q._get_test_model(canned_response = "Blue")
|
305
340
|
>>> q = Q(question_name = "color", question_text = "What is your favorite color?")
|
306
|
-
>>> async def test_run_async(): result = await q.run_async(model=m); print(result)
|
341
|
+
>>> async def test_run_async(): result = await q.run_async(model=m, disable_remote_inference = True); print(result)
|
307
342
|
>>> asyncio.run(test_run_async())
|
308
343
|
Blue
|
309
344
|
"""
|
310
345
|
survey = self.to_survey()
|
311
|
-
results = await survey.run_async(
|
346
|
+
results = await survey.run_async(
|
347
|
+
model=model,
|
348
|
+
agent=agent,
|
349
|
+
disable_remote_inference=disable_remote_inference,
|
350
|
+
**kwargs,
|
351
|
+
)
|
312
352
|
if just_answer:
|
313
353
|
return results.select(f"answer.{self.question_name}").first()
|
314
354
|
else:
|
@@ -30,38 +30,6 @@ template_manager = TemplateManager()
|
|
30
30
|
|
31
31
|
|
32
32
|
class QuestionBasePromptsMixin:
|
33
|
-
# @classmethod
|
34
|
-
# @lru_cache(maxsize=1)
|
35
|
-
# def _read_template(cls, template_name):
|
36
|
-
# with resources.open_text(
|
37
|
-
# f"edsl.questions.templates.{cls.question_type}", template_name
|
38
|
-
# ) as file:
|
39
|
-
# return file.read()
|
40
|
-
|
41
|
-
# @classmethod
|
42
|
-
# def applicable_prompts(
|
43
|
-
# cls, model: Optional[str] = None
|
44
|
-
# ) -> list[type["PromptBase"]]:
|
45
|
-
# """Get the prompts that are applicable to the question type.
|
46
|
-
|
47
|
-
# :param model: The language model to use.
|
48
|
-
|
49
|
-
# >>> from edsl.questions import QuestionFreeText
|
50
|
-
# >>> QuestionFreeText.applicable_prompts()
|
51
|
-
# [<class 'edsl.prompts.library.question_freetext.FreeText'>]
|
52
|
-
|
53
|
-
# :param model: The language model to use. If None, assumes does not matter.
|
54
|
-
|
55
|
-
# """
|
56
|
-
# from edsl.prompts.registry import get_classes as prompt_lookup
|
57
|
-
|
58
|
-
# applicable_prompts = prompt_lookup(
|
59
|
-
# component_type="question_instructions",
|
60
|
-
# question_type=cls.question_type,
|
61
|
-
# model=model,
|
62
|
-
# )
|
63
|
-
# return applicable_prompts
|
64
|
-
|
65
33
|
@property
|
66
34
|
def model_instructions(self) -> dict:
|
67
35
|
"""Get the model-specific instructions for the question."""
|
@@ -231,7 +199,7 @@ class QuestionBasePromptsMixin:
|
|
231
199
|
@property
|
232
200
|
def new_default_instructions(self) -> "Prompt":
|
233
201
|
"This is set up as a property because there are mutable question values that determine how it is rendered."
|
234
|
-
return self.question_presentation + self.answering_instructions
|
202
|
+
return Prompt(self.question_presentation) + Prompt(self.answering_instructions)
|
235
203
|
|
236
204
|
@property
|
237
205
|
def parameters(self) -> set[str]:
|
@@ -19,7 +19,7 @@ class QuestionFunctional(QuestionBase):
|
|
19
19
|
>>> question.activate()
|
20
20
|
>>> scenario = Scenario({"numbers": [1, 2, 3, 4, 5]})
|
21
21
|
>>> agent = Agent(traits={"multiplier": 10})
|
22
|
-
>>> results = question.by(scenario).by(agent).run()
|
22
|
+
>>> results = question.by(scenario).by(agent).run(disable_remote_cache = True, disable_remote_inference = True)
|
23
23
|
>>> results.select("answer.*").to_list()[0] == 150
|
24
24
|
True
|
25
25
|
|
@@ -27,7 +27,7 @@ class QuestionFunctional(QuestionBase):
|
|
27
27
|
|
28
28
|
>>> from edsl.questions.QuestionBase import QuestionBase
|
29
29
|
>>> new_question = QuestionBase.from_dict(question.to_dict())
|
30
|
-
>>> results = new_question.by(scenario).by(agent).run()
|
30
|
+
>>> results = new_question.by(scenario).by(agent).run(disable_remote_cache = True, disable_remote_inference = True)
|
31
31
|
>>> results.select("answer.*").to_list()[0] == 150
|
32
32
|
True
|
33
33
|
|