edsl 0.1.50__py3-none-any.whl → 0.1.52__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +45 -34
- edsl/__version__.py +1 -1
- edsl/base/base_exception.py +2 -2
- edsl/buckets/bucket_collection.py +1 -1
- edsl/buckets/exceptions.py +32 -0
- edsl/buckets/token_bucket_api.py +26 -10
- edsl/caching/cache.py +5 -2
- edsl/caching/remote_cache_sync.py +5 -5
- edsl/caching/sql_dict.py +12 -11
- edsl/config/__init__.py +1 -1
- edsl/config/config_class.py +4 -2
- edsl/conversation/Conversation.py +9 -5
- edsl/conversation/car_buying.py +1 -3
- edsl/conversation/mug_negotiation.py +2 -6
- edsl/coop/__init__.py +11 -8
- edsl/coop/coop.py +15 -13
- edsl/coop/coop_functions.py +1 -1
- edsl/coop/ep_key_handling.py +1 -1
- edsl/coop/price_fetcher.py +2 -2
- edsl/coop/utils.py +2 -2
- edsl/dataset/dataset.py +144 -63
- edsl/dataset/dataset_operations_mixin.py +14 -6
- edsl/dataset/dataset_tree.py +3 -3
- edsl/dataset/display/table_renderers.py +6 -3
- edsl/dataset/file_exports.py +4 -4
- edsl/dataset/r/ggplot.py +3 -3
- edsl/inference_services/available_model_fetcher.py +2 -2
- edsl/inference_services/data_structures.py +5 -5
- edsl/inference_services/inference_service_abc.py +1 -1
- edsl/inference_services/inference_services_collection.py +1 -1
- edsl/inference_services/service_availability.py +3 -3
- edsl/inference_services/services/azure_ai.py +3 -3
- edsl/inference_services/services/google_service.py +1 -1
- edsl/inference_services/services/test_service.py +1 -1
- edsl/instructions/change_instruction.py +5 -4
- edsl/instructions/instruction.py +1 -0
- edsl/instructions/instruction_collection.py +5 -4
- edsl/instructions/instruction_handler.py +10 -8
- edsl/interviews/answering_function.py +20 -21
- edsl/interviews/exception_tracking.py +3 -2
- edsl/interviews/interview.py +1 -1
- edsl/interviews/interview_status_dictionary.py +1 -1
- edsl/interviews/interview_task_manager.py +7 -4
- edsl/interviews/request_token_estimator.py +3 -2
- edsl/interviews/statistics.py +2 -2
- edsl/invigilators/invigilators.py +34 -6
- edsl/jobs/__init__.py +39 -2
- edsl/jobs/async_interview_runner.py +1 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/data_structures.py +2 -2
- edsl/jobs/html_table_job_logger.py +494 -257
- edsl/jobs/jobs.py +2 -2
- edsl/jobs/jobs_checks.py +5 -5
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_pricing_estimation.py +1 -1
- edsl/jobs/jobs_runner_asyncio.py +2 -2
- edsl/jobs/jobs_status_enums.py +1 -0
- edsl/jobs/remote_inference.py +47 -13
- edsl/jobs/results_exceptions_handler.py +2 -2
- edsl/language_models/language_model.py +151 -145
- edsl/notebooks/__init__.py +24 -1
- edsl/notebooks/exceptions.py +82 -0
- edsl/notebooks/notebook.py +7 -3
- edsl/notebooks/notebook_to_latex.py +1 -1
- edsl/prompts/__init__.py +23 -2
- edsl/prompts/prompt.py +1 -1
- edsl/questions/__init__.py +4 -4
- edsl/questions/answer_validator_mixin.py +0 -5
- edsl/questions/compose_questions.py +2 -2
- edsl/questions/descriptors.py +1 -1
- edsl/questions/question_base.py +32 -3
- edsl/questions/question_base_prompts_mixin.py +4 -4
- edsl/questions/question_budget.py +503 -102
- edsl/questions/question_check_box.py +658 -156
- edsl/questions/question_dict.py +176 -2
- edsl/questions/question_extract.py +401 -61
- edsl/questions/question_free_text.py +77 -9
- edsl/questions/question_functional.py +118 -9
- edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
- edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
- edsl/questions/question_list.py +246 -26
- edsl/questions/question_matrix.py +586 -73
- edsl/questions/question_multiple_choice.py +213 -47
- edsl/questions/question_numerical.py +360 -29
- edsl/questions/question_rank.py +401 -124
- edsl/questions/question_registry.py +3 -3
- edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
- edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
- edsl/questions/register_questions_meta.py +2 -1
- edsl/questions/response_validator_abc.py +6 -2
- edsl/questions/response_validator_factory.py +10 -12
- edsl/results/report.py +1 -1
- edsl/results/result.py +7 -4
- edsl/results/results.py +500 -271
- edsl/results/results_selector.py +2 -2
- edsl/scenarios/construct_download_link.py +3 -3
- edsl/scenarios/scenario.py +1 -2
- edsl/scenarios/scenario_list.py +41 -23
- edsl/surveys/survey_css.py +3 -3
- edsl/surveys/survey_simulator.py +2 -1
- edsl/tasks/__init__.py +22 -2
- edsl/tasks/exceptions.py +72 -0
- edsl/tasks/task_history.py +48 -11
- edsl/templates/error_reporting/base.html +37 -4
- edsl/templates/error_reporting/exceptions_table.html +105 -33
- edsl/templates/error_reporting/interview_details.html +130 -126
- edsl/templates/error_reporting/overview.html +21 -25
- edsl/templates/error_reporting/report.css +215 -46
- edsl/templates/error_reporting/report.js +122 -20
- edsl/tokens/__init__.py +27 -1
- edsl/tokens/exceptions.py +37 -0
- edsl/tokens/interview_token_usage.py +3 -2
- edsl/tokens/token_usage.py +4 -3
- {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/METADATA +1 -1
- {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/RECORD +118 -116
- edsl/questions/derived/__init__.py +0 -0
- {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/LICENSE +0 -0
- {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/WHEEL +0 -0
- {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/entry_points.txt +0 -0
@@ -55,7 +55,7 @@ class AvailableModelFetcher:
|
|
55
55
|
|
56
56
|
:param service: Optional[InferenceServiceABC] - If specified, only fetch models for this service.
|
57
57
|
|
58
|
-
>>> from
|
58
|
+
>>> from .services.open_ai_service import OpenAIService
|
59
59
|
>>> af = AvailableModelFetcher([OpenAIService()], {})
|
60
60
|
>>> af.available(service="openai")
|
61
61
|
[LanguageModelInfo(model_name='...', service_name='openai'), ...]
|
@@ -155,7 +155,7 @@ class AvailableModelFetcher:
|
|
155
155
|
"""The service name is the _inference_service_ attribute of the service."""
|
156
156
|
if service_name in self._service_map:
|
157
157
|
return self._service_map[service_name]
|
158
|
-
from
|
158
|
+
from .exceptions import InferenceServiceValueError
|
159
159
|
raise InferenceServiceValueError(f"Service {service_name} not found")
|
160
160
|
|
161
161
|
def _get_all_models(self, force_refresh=False) -> List[LanguageModelInfo]:
|
@@ -43,7 +43,7 @@ class LanguageModelInfo:
|
|
43
43
|
elif key == 1:
|
44
44
|
return self.service_name
|
45
45
|
else:
|
46
|
-
from
|
46
|
+
from .exceptions import InferenceServiceIndexError
|
47
47
|
raise InferenceServiceIndexError("Index out of range")
|
48
48
|
|
49
49
|
@classmethod
|
@@ -70,7 +70,7 @@ class AvailableModels(UserList):
|
|
70
70
|
return self.to_dataset().print()
|
71
71
|
|
72
72
|
def to_dataset(self):
|
73
|
-
from
|
73
|
+
from ..scenarios.scenario_list import ScenarioList
|
74
74
|
|
75
75
|
models, services = zip(
|
76
76
|
*[(model.model_name, model.service_name) for model in self]
|
@@ -106,14 +106,14 @@ class AvailableModels(UserList):
|
|
106
106
|
]
|
107
107
|
)
|
108
108
|
if len(avm) == 0:
|
109
|
-
from
|
109
|
+
from .exceptions import InferenceServiceValueError
|
110
110
|
raise InferenceServiceValueError(
|
111
111
|
"No models found matching the search pattern: " + pattern
|
112
112
|
)
|
113
113
|
else:
|
114
114
|
return avm
|
115
115
|
except re.error as e:
|
116
|
-
from
|
116
|
+
from .exceptions import InferenceServiceValueError
|
117
117
|
raise InferenceServiceValueError(f"Invalid regular expression pattern: {e}")
|
118
118
|
|
119
119
|
|
@@ -128,7 +128,7 @@ class ServiceToModelsMapping(UserDict):
|
|
128
128
|
def _validate_service_names(self):
|
129
129
|
for service in self.service_names:
|
130
130
|
if service not in InferenceServiceLiteral:
|
131
|
-
from
|
131
|
+
from .exceptions import InferenceServiceValueError
|
132
132
|
raise InferenceServiceValueError(f"Invalid service name: {service}")
|
133
133
|
|
134
134
|
def model_to_services(self) -> dict:
|
@@ -26,7 +26,7 @@ class InferenceServiceABC(ABC):
|
|
26
26
|
]
|
27
27
|
for attr in must_have_attributes:
|
28
28
|
if not hasattr(cls, attr):
|
29
|
-
from
|
29
|
+
from .exceptions import InferenceServiceNotImplementedError
|
30
30
|
raise InferenceServiceNotImplementedError(
|
31
31
|
f"Class {cls.__name__} must have a '{attr}' attribute."
|
32
32
|
)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
from collections import defaultdict
|
2
2
|
from typing import Optional, Protocol, Dict, List, Tuple, TYPE_CHECKING
|
3
3
|
|
4
|
-
from
|
4
|
+
from ..enums import InferenceServiceLiteral
|
5
5
|
from .inference_service_abc import InferenceServiceABC
|
6
6
|
from .available_model_fetcher import AvailableModelFetcher
|
7
7
|
from .exceptions import InferenceServiceError
|
@@ -42,7 +42,7 @@ class ServiceAvailability:
|
|
42
42
|
@classmethod
|
43
43
|
def models_from_coop(cls) -> AvailableModels:
|
44
44
|
if not cls._coop_model_list:
|
45
|
-
from
|
45
|
+
from ..coop.coop import Coop
|
46
46
|
|
47
47
|
c = Coop()
|
48
48
|
coop_model_list = c.fetch_models()
|
@@ -74,7 +74,7 @@ class ServiceAvailability:
|
|
74
74
|
continue
|
75
75
|
|
76
76
|
# If we get here, all sources failed
|
77
|
-
from
|
77
|
+
from .exceptions import InferenceServiceRuntimeError
|
78
78
|
raise InferenceServiceRuntimeError(
|
79
79
|
f"All sources failed to fetch models. Last error: {last_error}"
|
80
80
|
)
|
@@ -93,7 +93,7 @@ class ServiceAvailability:
|
|
93
93
|
@staticmethod
|
94
94
|
def _fetch_from_cache(service: "InferenceServiceABC") -> ModelNamesList:
|
95
95
|
"""Fetch models from local cache."""
|
96
|
-
from
|
96
|
+
from .models_available_cache import models_available
|
97
97
|
|
98
98
|
return models_available.get(service._inference_service_, [])
|
99
99
|
|
@@ -46,7 +46,7 @@ class AzureAIService(InferenceServiceABC):
|
|
46
46
|
out = []
|
47
47
|
azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
|
48
48
|
if not azure_endpoints:
|
49
|
-
from
|
49
|
+
from ..exceptions import InferenceServiceEnvironmentError
|
50
50
|
raise InferenceServiceEnvironmentError("AZURE_ENDPOINT_URL_AND_KEY is not defined")
|
51
51
|
azure_endpoints = azure_endpoints.split(",")
|
52
52
|
for data in azure_endpoints:
|
@@ -135,7 +135,7 @@ class AzureAIService(InferenceServiceABC):
|
|
135
135
|
api_key = None
|
136
136
|
|
137
137
|
if not api_key:
|
138
|
-
from
|
138
|
+
from ..exceptions import InferenceServiceEnvironmentError
|
139
139
|
raise InferenceServiceEnvironmentError(
|
140
140
|
f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
|
141
141
|
)
|
@@ -146,7 +146,7 @@ class AzureAIService(InferenceServiceABC):
|
|
146
146
|
endpoint = None
|
147
147
|
|
148
148
|
if not endpoint:
|
149
|
-
from
|
149
|
+
from ..exceptions import InferenceServiceEnvironmentError
|
150
150
|
raise InferenceServiceEnvironmentError(
|
151
151
|
f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
|
152
152
|
)
|
@@ -5,7 +5,7 @@ import google.generativeai as genai
|
|
5
5
|
from google.generativeai.types import GenerationConfig
|
6
6
|
from google.api_core.exceptions import InvalidArgument
|
7
7
|
|
8
|
-
# from
|
8
|
+
# from ...exceptions.general import MissingAPIKeyError
|
9
9
|
from ..inference_service_abc import InferenceServiceABC
|
10
10
|
from ...language_models import LanguageModel
|
11
11
|
|
@@ -74,7 +74,7 @@ class TestService(InferenceServiceABC):
|
|
74
74
|
p = 1
|
75
75
|
|
76
76
|
if random.random() < p:
|
77
|
-
from
|
77
|
+
from ..exceptions import InferenceServiceError
|
78
78
|
raise InferenceServiceError("This is a test error")
|
79
79
|
|
80
80
|
if hasattr(self, "func"):
|
@@ -1,5 +1,5 @@
|
|
1
1
|
from typing import List, Optional
|
2
|
-
from
|
2
|
+
from ..utilities.remove_edsl_version import remove_edsl_version
|
3
3
|
|
4
4
|
|
5
5
|
class ChangeInstruction:
|
@@ -9,11 +9,12 @@ class ChangeInstruction:
|
|
9
9
|
drop: Optional[List[str]] = None,
|
10
10
|
):
|
11
11
|
if keep is None and drop is None:
|
12
|
-
from
|
12
|
+
from .exceptions import InstructionValueError
|
13
13
|
raise InstructionValueError("Keep and drop cannot both be None")
|
14
14
|
|
15
15
|
self.keep = keep or []
|
16
16
|
self.drop = drop or []
|
17
|
+
self.pseudo_index = 0.0
|
17
18
|
|
18
19
|
def include_instruction(self, instruction_name) -> bool:
|
19
20
|
return (instruction_name in self.keep) or (instruction_name not in self.drop)
|
@@ -30,7 +31,7 @@ class ChangeInstruction:
|
|
30
31
|
"drop": self.drop,
|
31
32
|
}
|
32
33
|
if add_edsl_version:
|
33
|
-
from
|
34
|
+
from .. import __version__
|
34
35
|
|
35
36
|
d["edsl_version"] = __version__
|
36
37
|
d["edsl_class_name"] = "ChangeInstruction"
|
@@ -39,7 +40,7 @@ class ChangeInstruction:
|
|
39
40
|
|
40
41
|
def __hash__(self) -> int:
|
41
42
|
"""Return a hash of the question."""
|
42
|
-
from
|
43
|
+
from ..utilities.utilities import dict_hash
|
43
44
|
|
44
45
|
return dict_hash(self.to_dict(add_edsl_version=False))
|
45
46
|
|
edsl/instructions/instruction.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import TYPE_CHECKING, Dict, List, Generator, Union
|
1
|
+
from typing import TYPE_CHECKING, Dict, List, Generator, Union, Tuple
|
2
2
|
from collections import UserDict
|
3
3
|
|
4
4
|
from .instruction import Instruction
|
@@ -40,13 +40,14 @@ class InstructionCollection(UserDict):
|
|
40
40
|
|
41
41
|
def _entries_before(
|
42
42
|
self, question_name
|
43
|
-
) ->
|
43
|
+
) -> Tuple[List[Instruction], List[ChangeInstruction]]:
|
44
44
|
if question_name not in self.question_names:
|
45
|
-
from
|
45
|
+
from .exceptions import InstructionCollectionError
|
46
46
|
raise InstructionCollectionError(
|
47
47
|
f"Question name not found in the list of questions: got '{question_name}'; list is {self.question_names}"
|
48
48
|
)
|
49
|
-
instructions
|
49
|
+
instructions: List[Instruction] = []
|
50
|
+
changes: List[ChangeInstruction] = []
|
50
51
|
|
51
52
|
index = self.question_index(question_name)
|
52
53
|
for instruction in self.instruction_names_to_instruction.values():
|
@@ -1,11 +1,13 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
2
|
|
3
3
|
|
4
|
+
from typing import Dict, List, Any
|
5
|
+
|
4
6
|
@dataclass
|
5
7
|
class SeparatedComponents:
|
6
|
-
true_questions:
|
7
|
-
instruction_names_to_instructions:
|
8
|
-
pseudo_indices:
|
8
|
+
true_questions: List[Any]
|
9
|
+
instruction_names_to_instructions: Dict[str, Any]
|
10
|
+
pseudo_indices: Dict[str, float]
|
9
11
|
|
10
12
|
|
11
13
|
class InstructionHandler:
|
@@ -13,7 +15,7 @@ class InstructionHandler:
|
|
13
15
|
self.survey = survey
|
14
16
|
|
15
17
|
@staticmethod
|
16
|
-
def separate_questions_and_instructions(questions_and_instructions: list) ->
|
18
|
+
def separate_questions_and_instructions(questions_and_instructions: list) -> SeparatedComponents:
|
17
19
|
"""
|
18
20
|
The 'pseudo_indices' attribute is a dictionary that maps question names to pseudo-indices
|
19
21
|
that are used to order questions and instructions in the survey.
|
@@ -48,9 +50,9 @@ class InstructionHandler:
|
|
48
50
|
"""
|
49
51
|
from .instruction import Instruction
|
50
52
|
from .change_instruction import ChangeInstruction
|
51
|
-
from
|
53
|
+
from ..questions import QuestionBase
|
52
54
|
|
53
|
-
true_questions = []
|
55
|
+
true_questions: List[QuestionBase] = []
|
54
56
|
instruction_names_to_instructions = {}
|
55
57
|
|
56
58
|
num_change_instructions = 0
|
@@ -63,7 +65,7 @@ class InstructionHandler:
|
|
63
65
|
num_change_instructions += 1
|
64
66
|
for prior_instruction in entry.keep + entry.drop:
|
65
67
|
if prior_instruction not in instruction_names_to_instructions:
|
66
|
-
from
|
68
|
+
from .exceptions import InstructionValueError
|
67
69
|
raise InstructionValueError(
|
68
70
|
f"ChangeInstruction {entry.name} references instruction {prior_instruction} which does not exist."
|
69
71
|
)
|
@@ -77,7 +79,7 @@ class InstructionHandler:
|
|
77
79
|
instructions_run_length = 0
|
78
80
|
true_questions.append(entry)
|
79
81
|
else:
|
80
|
-
from
|
82
|
+
from .exceptions import InstructionValueError
|
81
83
|
raise InstructionValueError(
|
82
84
|
f"Entry {repr(entry)} is not a QuestionBase or an Instruction."
|
83
85
|
)
|
@@ -26,7 +26,6 @@ class RetryConfig:
|
|
26
26
|
|
27
27
|
|
28
28
|
class SkipHandler:
|
29
|
-
|
30
29
|
def __init__(self, interview: "Interview"):
|
31
30
|
self.interview = interview
|
32
31
|
self.question_index = self.interview.to_index
|
@@ -47,7 +46,7 @@ class SkipHandler:
|
|
47
46
|
|
48
47
|
def _current_info_env(self) -> dict[str, Any]:
|
49
48
|
"""
|
50
|
-
- The current answers are "generated_tokens" and "comment"
|
49
|
+
- The current answers are "generated_tokens" and "comment"
|
51
50
|
- The scenario should have "scenario." added to the keys
|
52
51
|
- The agent traits should have "agent." added to the keys
|
53
52
|
"""
|
@@ -65,10 +64,14 @@ class SkipHandler:
|
|
65
64
|
processed_answers[f"{key}.answer"] = value
|
66
65
|
|
67
66
|
# Process scenario dictionary
|
68
|
-
processed_scenario = {
|
67
|
+
processed_scenario = {
|
68
|
+
f"scenario.{k}": v for k, v in self.interview.scenario.items()
|
69
|
+
}
|
69
70
|
|
70
71
|
# Process agent traits
|
71
|
-
processed_agent = {
|
72
|
+
processed_agent = {
|
73
|
+
f"agent.{k}": v for k, v in self.interview.agent["traits"].items()
|
74
|
+
}
|
72
75
|
|
73
76
|
return processed_answers | processed_scenario | processed_agent
|
74
77
|
|
@@ -85,21 +88,22 @@ class SkipHandler:
|
|
85
88
|
# )
|
86
89
|
|
87
90
|
# Get the index of the next question, which could also be the end of the survey
|
88
|
-
next_question: Union[
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
91
|
+
next_question: Union[
|
92
|
+
int, EndOfSurvey
|
93
|
+
] = self.interview.survey.rule_collection.next_question(
|
94
|
+
q_now=current_question_index,
|
95
|
+
answers=answers,
|
93
96
|
)
|
94
97
|
|
95
|
-
|
96
98
|
def cancel_between(start, end):
|
97
99
|
"""Cancel the tasks for questions between the start and end indices."""
|
98
100
|
for i in range(start, end):
|
99
|
-
#print(f"Cancelling task {i}")
|
100
|
-
#self.interview.tasks[i].cancel()
|
101
|
-
#self.interview.tasks[i].set_result("skipped")
|
102
|
-
self.interview.skip_flags[
|
101
|
+
# print(f"Cancelling task {i}")
|
102
|
+
# self.interview.tasks[i].cancel()
|
103
|
+
# self.interview.tasks[i].set_result("skipped")
|
104
|
+
self.interview.skip_flags[
|
105
|
+
self.interview.survey.questions[i].question_name
|
106
|
+
] = True
|
103
107
|
|
104
108
|
if (next_question_index := next_question.next_q) == EndOfSurvey:
|
105
109
|
cancel_between(
|
@@ -111,8 +115,6 @@ class SkipHandler:
|
|
111
115
|
cancel_between(current_question_index + 1, next_question_index)
|
112
116
|
|
113
117
|
|
114
|
-
|
115
|
-
|
116
118
|
class AnswerQuestionFunctionConstructor:
|
117
119
|
"""Constructs a function that answers a question and records the answer."""
|
118
120
|
|
@@ -137,7 +139,6 @@ class AnswerQuestionFunctionConstructor:
|
|
137
139
|
):
|
138
140
|
"""Handle an exception that occurred while answering a question."""
|
139
141
|
|
140
|
-
|
141
142
|
answers = copy.copy(
|
142
143
|
self.interview.answers
|
143
144
|
) # copy to freeze the answers here for logging
|
@@ -171,7 +172,6 @@ class AnswerQuestionFunctionConstructor:
|
|
171
172
|
question: "QuestionBase",
|
172
173
|
task=None,
|
173
174
|
) -> "EDSLResultObjectInput":
|
174
|
-
|
175
175
|
from tenacity import (
|
176
176
|
RetryError,
|
177
177
|
retry,
|
@@ -196,7 +196,6 @@ class AnswerQuestionFunctionConstructor:
|
|
196
196
|
return invigilator.get_failed_task_result(
|
197
197
|
failure_reason="Question skipped."
|
198
198
|
)
|
199
|
-
|
200
199
|
if self.skip_handler.should_skip(question):
|
201
200
|
return invigilator.get_failed_task_result(
|
202
201
|
failure_reason="Question skipped."
|
@@ -240,7 +239,6 @@ class AnswerQuestionFunctionConstructor:
|
|
240
239
|
raise LanguageModelNoResponseError(
|
241
240
|
f"Language model did not return a response for question '{question.question_name}.'"
|
242
241
|
)
|
243
|
-
|
244
242
|
if (
|
245
243
|
question.question_name in self.interview.exceptions
|
246
244
|
and had_language_model_no_response_error
|
@@ -250,7 +248,8 @@ class AnswerQuestionFunctionConstructor:
|
|
250
248
|
return response
|
251
249
|
|
252
250
|
try:
|
253
|
-
|
251
|
+
out = await attempt_answer()
|
252
|
+
return out
|
254
253
|
except RetryError as retry_error:
|
255
254
|
original_error = retry_error.last_attempt.exception()
|
256
255
|
self._handle_exception(
|
@@ -81,6 +81,7 @@ class InterviewExceptionEntry:
|
|
81
81
|
raise_validation_errors=True,
|
82
82
|
disable_remote_cache=True,
|
83
83
|
disable_remote_inference=True,
|
84
|
+
cache=False,
|
84
85
|
)
|
85
86
|
return results.task_history.exceptions[0]["how_are_you"][0]
|
86
87
|
|
@@ -97,8 +98,8 @@ class InterviewExceptionEntry:
|
|
97
98
|
lines.append(f"q = {repr(self.invigilator.question)}")
|
98
99
|
lines.append(f"scenario = {repr(self.invigilator.scenario)}")
|
99
100
|
lines.append(f"agent = {repr(self.invigilator.agent)}")
|
100
|
-
lines.append(f"
|
101
|
-
lines.append("results = q.by(
|
101
|
+
lines.append(f"model = {repr(self.invigilator.model)}")
|
102
|
+
lines.append("results = q.by(model).by(agent).by(scenario).run()")
|
102
103
|
code_str = "\n".join(lines)
|
103
104
|
|
104
105
|
if run:
|
edsl/interviews/interview.py
CHANGED
@@ -365,7 +365,7 @@ class Interview:
|
|
365
365
|
bool: True if the interviews are equivalent, False otherwise
|
366
366
|
|
367
367
|
Examples:
|
368
|
-
>>> from
|
368
|
+
>>> from . import Interview
|
369
369
|
>>> i = Interview.example()
|
370
370
|
>>> d = i.to_dict()
|
371
371
|
>>> i2 = Interview.from_dict(d)
|
@@ -27,7 +27,7 @@ class InterviewStatusDictionary(UserDict):
|
|
27
27
|
) -> "InterviewStatusDictionary":
|
28
28
|
"""Adds two InterviewStatusDictionaries together."""
|
29
29
|
if not isinstance(other, InterviewStatusDictionary):
|
30
|
-
from
|
30
|
+
from .exceptions import InterviewStatusError
|
31
31
|
raise InterviewStatusError(f"Can't add {type(other)} to InterviewStatusDictionary")
|
32
32
|
new_dict = {}
|
33
33
|
for key in self.keys():
|
@@ -24,12 +24,13 @@ class InterviewTaskManager:
|
|
24
24
|
for index, question_name in enumerate(self.survey.question_names)
|
25
25
|
}
|
26
26
|
self._task_status_log_dict = InterviewStatusLog()
|
27
|
+
self.survey_dag = None
|
27
28
|
|
28
29
|
def build_question_tasks(
|
29
30
|
self, answer_func, token_estimator, model_buckets
|
30
31
|
) -> list[asyncio.Task]:
|
31
32
|
"""Create tasks for all questions with proper dependencies."""
|
32
|
-
tasks = []
|
33
|
+
tasks: list[asyncio.Task] = []
|
33
34
|
for question in self.survey.questions:
|
34
35
|
dependencies = self._get_task_dependencies(tasks, question)
|
35
36
|
task = self._create_single_task(
|
@@ -40,14 +41,15 @@ class InterviewTaskManager:
|
|
40
41
|
model_buckets=model_buckets,
|
41
42
|
)
|
42
43
|
tasks.append(task)
|
43
|
-
return
|
44
|
+
return tasks
|
44
45
|
|
45
46
|
def _get_task_dependencies(
|
46
47
|
self, existing_tasks: list[asyncio.Task], question: "QuestionBase"
|
47
48
|
) -> list[asyncio.Task]:
|
48
49
|
"""Get tasks that must be completed before the given question."""
|
49
|
-
|
50
|
-
|
50
|
+
if self.survey_dag is None:
|
51
|
+
self.survey_dag = self.survey.dag(textify=True)
|
52
|
+
parents = self.survey_dag.get(question.question_name, [])
|
51
53
|
return [existing_tasks[self.to_index[parent_name]] for parent_name in parents]
|
52
54
|
|
53
55
|
def _create_single_task(
|
@@ -100,4 +102,5 @@ class InterviewTaskManager:
|
|
100
102
|
|
101
103
|
if __name__ == "__main__":
|
102
104
|
import doctest
|
105
|
+
|
103
106
|
doctest.testmod()
|
@@ -26,9 +26,10 @@ class RequestTokenEstimator:
|
|
26
26
|
if isinstance(file, FileStore):
|
27
27
|
file_tokens += file.size * 0.25
|
28
28
|
else:
|
29
|
-
from
|
29
|
+
from .exceptions import InterviewTokenError
|
30
30
|
raise InterviewTokenError(f"Prompt is of type {type(prompt)}")
|
31
|
-
|
31
|
+
result: float = len(combined_text) / 4.0 + file_tokens
|
32
|
+
return result
|
32
33
|
|
33
34
|
|
34
35
|
|
edsl/interviews/statistics.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
from collections import UserDict
|
2
|
-
from typing import DefaultDict
|
2
|
+
from typing import DefaultDict, Union, Optional
|
3
3
|
|
4
4
|
from ..tokens import InterviewTokenUsage
|
5
5
|
|
@@ -43,7 +43,7 @@ class InterviewStatistic(UserDict):
|
|
43
43
|
value: float,
|
44
44
|
digits: int = 0,
|
45
45
|
units: str = "",
|
46
|
-
pretty_name: str = None,
|
46
|
+
pretty_name: Optional[str] = None,
|
47
47
|
):
|
48
48
|
"""Create a new InterviewStatistic object."""
|
49
49
|
self.name = name
|
@@ -6,7 +6,7 @@ from typing import Literal
|
|
6
6
|
|
7
7
|
from ..utilities.decorators import sync_wrapper
|
8
8
|
from ..questions.exceptions import QuestionAnswerValidationError
|
9
|
-
from ..data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
9
|
+
from ..base.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
10
10
|
from ..utilities.decorators import jupyter_nb_handler
|
11
11
|
|
12
12
|
from .prompt_constructor import PromptConstructor
|
@@ -24,11 +24,11 @@ if TYPE_CHECKING:
|
|
24
24
|
from ..key_management import KeyLookup
|
25
25
|
|
26
26
|
|
27
|
-
|
28
27
|
PromptType = Literal["user_prompt", "system_prompt", "encoded_image", "files_list"]
|
29
28
|
|
30
29
|
NA = "Not Applicable"
|
31
30
|
|
31
|
+
|
32
32
|
class InvigilatorBase(ABC):
|
33
33
|
"""An invigiator (someone who administers an exam) is a class that is responsible for administering a question to an agent.
|
34
34
|
|
@@ -261,13 +261,14 @@ class InvigilatorBase(ABC):
|
|
261
261
|
current_answers=current_answers,
|
262
262
|
)
|
263
263
|
|
264
|
+
|
264
265
|
class InvigilatorAI(InvigilatorBase):
|
265
266
|
"""An invigilator that uses an AI model to answer questions."""
|
266
267
|
|
267
268
|
def get_prompts(self) -> Dict[PromptType, "Prompt"]:
|
268
269
|
"""Return the prompts used."""
|
269
270
|
return self.prompt_constructor.get_prompts()
|
270
|
-
|
271
|
+
|
271
272
|
def get_captured_variables(self) -> dict:
|
272
273
|
"""Get the captured variables."""
|
273
274
|
return self.prompt_constructor.get_captured_variables()
|
@@ -280,7 +281,8 @@ class InvigilatorAI(InvigilatorBase):
|
|
280
281
|
}
|
281
282
|
if "encoded_image" in prompts:
|
282
283
|
params["encoded_image"] = prompts["encoded_image"]
|
283
|
-
from
|
284
|
+
from .exceptions import InvigilatorNotImplementedError
|
285
|
+
|
284
286
|
raise InvigilatorNotImplementedError("encoded_image not implemented")
|
285
287
|
|
286
288
|
if "files_list" in prompts:
|
@@ -307,7 +309,8 @@ class InvigilatorAI(InvigilatorBase):
|
|
307
309
|
"""
|
308
310
|
agent_response_dict: AgentResponseDict = await self.async_get_agent_response()
|
309
311
|
self.store_response(agent_response_dict)
|
310
|
-
|
312
|
+
out = self._extract_edsl_result_entry_and_validate(agent_response_dict)
|
313
|
+
return out
|
311
314
|
|
312
315
|
def _remove_from_cache(self, cache_key) -> None:
|
313
316
|
"""Remove an entry from the cache."""
|
@@ -389,6 +392,30 @@ class InvigilatorAI(InvigilatorBase):
|
|
389
392
|
edsl_dict = agent_response_dict.edsl_dict._asdict()
|
390
393
|
exception_occurred = None
|
391
394
|
validated = False
|
395
|
+
|
396
|
+
if agent_response_dict.model_outputs.cache_used:
|
397
|
+
data = {
|
398
|
+
"answer": agent_response_dict.edsl_dict.answer
|
399
|
+
if type(agent_response_dict.edsl_dict.answer) is str
|
400
|
+
else "",
|
401
|
+
"comment": agent_response_dict.edsl_dict.comment
|
402
|
+
if agent_response_dict.edsl_dict.comment
|
403
|
+
else "",
|
404
|
+
"generated_tokens": agent_response_dict.edsl_dict.generated_tokens,
|
405
|
+
"question_name": self.question.question_name,
|
406
|
+
"prompts": self.get_prompts(),
|
407
|
+
"cached_response": agent_response_dict.model_outputs.cached_response,
|
408
|
+
"raw_model_response": agent_response_dict.model_outputs.response,
|
409
|
+
"cache_used": agent_response_dict.model_outputs.cache_used,
|
410
|
+
"cache_key": agent_response_dict.model_outputs.cache_key,
|
411
|
+
"validated": True,
|
412
|
+
"exception_occurred": exception_occurred,
|
413
|
+
"cost": agent_response_dict.model_outputs.cost,
|
414
|
+
}
|
415
|
+
|
416
|
+
result = EDSLResultObjectInput(**data)
|
417
|
+
return result
|
418
|
+
|
392
419
|
try:
|
393
420
|
# if the question has jinja parameters, it is easier to make a new question with the parameters
|
394
421
|
if self.question.parameters:
|
@@ -405,7 +432,7 @@ class InvigilatorAI(InvigilatorBase):
|
|
405
432
|
self.question.question_options = new_question_options
|
406
433
|
|
407
434
|
question_with_validators = self.question.render(
|
408
|
-
self.scenario | prior_answers_dict | {
|
435
|
+
self.scenario | prior_answers_dict | {"agent": self.agent.traits}
|
409
436
|
)
|
410
437
|
question_with_validators.use_code = self.question.use_code
|
411
438
|
else:
|
@@ -426,6 +453,7 @@ class InvigilatorAI(InvigilatorBase):
|
|
426
453
|
exception_occurred = non_validation_error
|
427
454
|
finally:
|
428
455
|
# even if validation failes, we still return the result
|
456
|
+
|
429
457
|
data = {
|
430
458
|
"answer": answer,
|
431
459
|
"comment": comment,
|
edsl/jobs/__init__.py
CHANGED
@@ -1,7 +1,44 @@
|
|
1
|
+
"""
|
2
|
+
The jobs module provides tools for running and managing EDSL jobs.
|
3
|
+
|
4
|
+
It includes classes for job configuration, execution, pricing estimation,
|
5
|
+
and management of concurrent language model API calls.
|
6
|
+
"""
|
7
|
+
|
1
8
|
from .jobs import Jobs
|
2
9
|
from .jobs import RunConfig, RunParameters, RunEnvironment # noqa: F401
|
3
10
|
from .remote_inference import JobsRemoteInferenceHandler # noqa: F401
|
4
11
|
from .jobs_runner_status import JobsRunnerStatusBase # noqa: F401
|
12
|
+
from .exceptions import (
|
13
|
+
JobsErrors,
|
14
|
+
JobsRunError,
|
15
|
+
MissingRemoteInferenceError,
|
16
|
+
InterviewError,
|
17
|
+
InterviewErrorPriorTaskCanceled,
|
18
|
+
InterviewTimeoutError,
|
19
|
+
JobsValueError,
|
20
|
+
JobsCompatibilityError,
|
21
|
+
JobsImplementationError,
|
22
|
+
RemoteInferenceError,
|
23
|
+
JobsTypeError
|
24
|
+
)
|
5
25
|
|
6
|
-
|
7
|
-
|
26
|
+
__all__ = [
|
27
|
+
"Jobs",
|
28
|
+
"JobsErrors",
|
29
|
+
"JobsRunError",
|
30
|
+
"MissingRemoteInferenceError",
|
31
|
+
"InterviewError",
|
32
|
+
"InterviewErrorPriorTaskCanceled",
|
33
|
+
"InterviewTimeoutError",
|
34
|
+
"JobsValueError",
|
35
|
+
"JobsCompatibilityError",
|
36
|
+
"JobsImplementationError",
|
37
|
+
"RemoteInferenceError",
|
38
|
+
"JobsTypeError",
|
39
|
+
"JobsRemoteInferenceHandler",
|
40
|
+
"JobsRunnerStatusBase",
|
41
|
+
"RunConfig",
|
42
|
+
"RunParameters",
|
43
|
+
"RunEnvironment"
|
44
|
+
]
|