edsl 0.1.40.dev2__py3-none-any.whl → 0.1.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the registry.
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +1 -1
- edsl/agents/Invigilator.py +6 -4
- edsl/agents/InvigilatorBase.py +2 -1
- edsl/agents/QuestionTemplateReplacementsBuilder.py +7 -2
- edsl/coop/coop.py +37 -2
- edsl/data/Cache.py +7 -0
- edsl/data/RemoteCacheSync.py +16 -16
- edsl/enums.py +3 -0
- edsl/exceptions/jobs.py +1 -9
- edsl/exceptions/language_models.py +8 -4
- edsl/exceptions/questions.py +8 -11
- edsl/inference_services/DeepSeekService.py +18 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +1 -1
- edsl/jobs/Jobs.py +42 -34
- edsl/jobs/JobsPrompts.py +11 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +1 -0
- edsl/jobs/JobsRemoteInferenceLogger.py +1 -1
- edsl/jobs/interviews/Interview.py +2 -6
- edsl/jobs/interviews/InterviewExceptionEntry.py +14 -4
- edsl/jobs/loggers/HTMLTableJobLogger.py +6 -1
- edsl/jobs/results_exceptions_handler.py +2 -7
- edsl/jobs/runners/JobsRunnerAsyncio.py +18 -6
- edsl/jobs/runners/JobsRunnerStatus.py +2 -1
- edsl/jobs/tasks/TaskHistory.py +49 -17
- edsl/language_models/LanguageModel.py +7 -4
- edsl/language_models/ModelList.py +1 -1
- edsl/language_models/key_management/KeyLookupBuilder.py +7 -3
- edsl/language_models/model.py +49 -0
- edsl/questions/QuestionBudget.py +2 -2
- edsl/questions/QuestionDict.py +343 -0
- edsl/questions/QuestionExtract.py +1 -1
- edsl/questions/__init__.py +1 -0
- edsl/questions/answer_validator_mixin.py +29 -0
- edsl/questions/derived/QuestionLinearScale.py +1 -1
- edsl/questions/descriptors.py +49 -5
- edsl/questions/question_registry.py +1 -1
- edsl/questions/templates/dict/__init__.py +0 -0
- edsl/questions/templates/dict/answering_instructions.jinja +21 -0
- edsl/questions/templates/dict/question_presentation.jinja +1 -0
- edsl/results/Result.py +25 -3
- edsl/results/Results.py +17 -5
- edsl/scenarios/FileStore.py +32 -0
- edsl/scenarios/PdfExtractor.py +3 -6
- edsl/scenarios/Scenario.py +2 -1
- edsl/scenarios/handlers/csv.py +11 -0
- edsl/surveys/Survey.py +5 -1
- edsl/templates/error_reporting/base.html +2 -4
- edsl/templates/error_reporting/exceptions_table.html +35 -0
- edsl/templates/error_reporting/interview_details.html +67 -53
- edsl/templates/error_reporting/interviews.html +4 -17
- edsl/templates/error_reporting/overview.html +31 -5
- edsl/templates/error_reporting/performance_plot.html +1 -1
- {edsl-0.1.40.dev2.dist-info → edsl-0.1.42.dist-info}/METADATA +1 -1
- {edsl-0.1.40.dev2.dist-info → edsl-0.1.42.dist-info}/RECORD +59 -53
- {edsl-0.1.40.dev2.dist-info → edsl-0.1.42.dist-info}/LICENSE +0 -0
- {edsl-0.1.40.dev2.dist-info → edsl-0.1.42.dist-info}/WHEEL +0 -0
edsl/__init__.py
CHANGED
@@ -21,6 +21,7 @@ from edsl.questions import QuestionFunctional
 from edsl.questions import QuestionLikertFive
 from edsl.questions import QuestionList
 from edsl.questions import QuestionMatrix
+from edsl.questions import QuestionDict
 from edsl.questions import QuestionLinearScale
 from edsl.questions import QuestionNumerical
 from edsl.questions import QuestionYesNo
edsl/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.40.dev2"
+__version__ = "0.1.42"
edsl/agents/Agent.py
CHANGED
@@ -906,7 +906,7 @@ class Agent(Base):
         {'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}, 'instruction': 'Have fun.', 'edsl_version': '...', 'edsl_class_name': 'Agent'}
         """
         d = {}
-        d["traits"] = copy.deepcopy(self.
+        d["traits"] = copy.deepcopy(dict(self._traits))
         if self.name:
             d["name"] = self.name
         if self.set_instructions:
edsl/agents/Invigilator.py
CHANGED
@@ -1,6 +1,6 @@
 """Module for creating Invigilators, which are objects to administer a question to an Agent."""

-from typing import Dict, Any, Optional, TYPE_CHECKING
+from typing import Dict, Any, Optional, TYPE_CHECKING, Literal

 from edsl.utilities.decorators import sync_wrapper
 from edsl.exceptions.questions import QuestionAnswerValidationError
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
     from edsl.scenarios.Scenario import Scenario
     from edsl.surveys.Survey import Survey

+PromptType = Literal["user_prompt", "system_prompt", "encoded_image", "files_list"]

 NA = "Not Applicable"

@@ -19,7 +20,7 @@ NA = "Not Applicable"
 class InvigilatorAI(InvigilatorBase):
     """An invigilator that uses an AI model to answer questions."""

-    def get_prompts(self) -> Dict[
+    def get_prompts(self) -> Dict[PromptType, "Prompt"]:
         """Return the prompts used."""
         return self.prompt_constructor.get_prompts()

@@ -48,13 +49,14 @@ class InvigilatorAI(InvigilatorBase):
         """Store the response in the invigilator, in case it is needed later because of validation failure."""
         self.raw_model_response = agent_response_dict.model_outputs.response
         self.generated_tokens = agent_response_dict.edsl_dict.generated_tokens
+        self.cache_key = agent_response_dict.model_outputs.cache_key

-    async def async_answer_question(self) ->
+    async def async_answer_question(self) -> EDSLResultObjectInput:
         """Answer a question using the AI model.

         >>> i = InvigilatorAI.example()
         """
-        agent_response_dict = await self.async_get_agent_response()
+        agent_response_dict: AgentResponseDict = await self.async_get_agent_response()
         self.store_response(agent_response_dict)
         return self._extract_edsl_result_entry_and_validate(agent_response_dict)
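The PromptType alias added above narrows the keys of the prompts dictionary so a static checker can reject lookups outside the allowed set. A minimal, self-contained sketch of the idiom (plain str values stand in for edsl's Prompt objects):

    from typing import Dict, Literal

    PromptType = Literal["user_prompt", "system_prompt", "encoded_image", "files_list"]

    def get_prompts() -> Dict[PromptType, str]:
        # mypy/pyright will flag any key outside the Literal set
        return {"user_prompt": "What is 2 + 2?", "system_prompt": "You are an agent."}

    prompts = get_prompts()
    print(prompts["user_prompt"])  # accepted
    # prompts["bad_key"]           # rejected by the type checker at lint time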
edsl/agents/InvigilatorBase.py
CHANGED
@@ -135,6 +135,7 @@ class InvigilatorBase(ABC):
         d["additional_prompt_data"] = data["additional_prompt_data"]

         d = cls(**d)
+        return d

     def __repr__(self) -> str:
         """Return a string representation of the Invigilator.
@@ -143,7 +144,7 @@ class InvigilatorBase(ABC):
         'InvigilatorExample(...)'

         """
-        return f"{self.__class__.__name__}(agent={repr(self.agent)}, question={repr(self.question)},
+        return f"{self.__class__.__name__}(agent={repr(self.agent)}, question={repr(self.question)}, scenario={repr(self.scenario)}, model={repr(self.model)}, memory_plan={repr(self.memory_plan)}, current_answers={repr(self.current_answers)}, iteration={repr(self.iteration)}, additional_prompt_data={repr(self.additional_prompt_data)}, cache={repr(self.cache)})"

     def get_failed_task_result(self, failure_reason: str) -> EDSLResultObjectInput:
         """Return an AgentResponseDict used in case the question-asking fails.
edsl/agents/QuestionTemplateReplacementsBuilder.py
CHANGED
@@ -1,4 +1,4 @@
-from jinja2 import Environment, meta
+from jinja2 import Environment, meta, TemplateSyntaxError
 from typing import Any, Set, TYPE_CHECKING

 if TYPE_CHECKING:
@@ -29,7 +29,12 @@ class QuestionTemplateReplacementsBuilder:
         Set[str]: A set of variable names found in the template
         """
         env = Environment()
-        ast = env.parse(template_str)
+        try:
+            ast = env.parse(template_str)
+        except TemplateSyntaxError:
+            print(f"Error parsing template: {template_str}")
+            raise
+
         return meta.find_undeclared_variables(ast)

     @staticmethod
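For context, env.parse and meta.find_undeclared_variables are standard jinja2 APIs; the new try/except simply names the offending template before re-raising. A runnable illustration:

    from jinja2 import Environment, meta

    env = Environment()
    ast = env.parse("Hello {{ name }}, you are {{ age }} years old")
    print(meta.find_undeclared_variables(ast))  # {'age', 'name'}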
edsl/coop/coop.py
CHANGED
@@ -111,13 +111,13 @@ class Coop(CoopFunctionsMixin):
         url = f"{self.api_url}/{uri}"
         method = method.upper()
         if payload is None:
-            timeout =
+            timeout = 40
         elif (
             method.upper() == "POST"
             and "json_string" in payload
             and payload.get("json_string") is not None
         ):
-            timeout = max(
+            timeout = max(40, (len(payload.get("json_string", "")) // (1024 * 1024)))
         try:
             if method in ["GET", "DELETE"]:
                 response = requests.request(
@@ -533,6 +533,7 @@ class Coop(CoopFunctionsMixin):
             uri="api/v0/remote-cache/many",
             method="POST",
             payload=payload,
+            timeout=40,
         )
         self._resolve_server_response(response)
         response_json = response.json()
@@ -563,6 +564,7 @@ class Coop(CoopFunctionsMixin):
             uri="api/v0/remote-cache/get-many",
             method="POST",
             payload={"keys": exclude_keys},
+            timeout=40,
         )
         self._resolve_server_response(response)
         return [
@@ -581,6 +583,7 @@ class Coop(CoopFunctionsMixin):
             uri="api/v0/remote-cache/get-diff",
             method="POST",
             payload={"keys": client_cacheentry_keys},
+            timeout=40,
         )
         self._resolve_server_response(response)
         response_json = response.json()
@@ -891,6 +894,38 @@ class Coop(CoopFunctionsMixin):
         data = response.json()
         return ServiceToModelsMapping(data)

+    def fetch_working_models(self) -> list[dict]:
+        """
+        Fetch a list of working models from Coop.
+
+        Example output:
+
+        [
+            {
+                "service": "openai",
+                "model": "gpt-4o",
+                "works_with_text": True,
+                "works_with_images": True,
+                "usd_per_1M_input_tokens": 2.5,
+                "usd_per_1M_output_tokens": 10.0,
+            }
+        ]
+        """
+        response = self._send_server_request(uri="api/v0/working-models", method="GET")
+        self._resolve_server_response(response)
+        data = response.json()
+        return [
+            {
+                "service": record.get("service"),
+                "model": record.get("model"),
+                "works_with_text": record.get("works_with_text"),
+                "works_with_images": record.get("works_with_images"),
+                "usd_per_1M_input_tokens": record.get("input_price_per_1M_tokens"),
+                "usd_per_1M_output_tokens": record.get("output_price_per_1M_tokens"),
+            }
+            for record in data
+        ]
+
     def fetch_rate_limit_config_vars(self) -> dict:
         """
         Fetch a dict of rate limit config vars from Coop.
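A sketch of how a caller might consume the new fetch_working_models output, assuming the record shape shown in its docstring (a configured Expected Parrot API key is required):

    from edsl.coop.coop import Coop

    models = Coop().fetch_working_models()

    # e.g. pick the cheapest vision-capable model by input-token price
    vision = [m for m in models
              if m["works_with_images"] and m["usd_per_1M_input_tokens"] is not None]
    cheapest = min(vision, key=lambda m: m["usd_per_1M_input_tokens"])
    print(cheapest["service"], cheapest["model"])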
edsl/data/Cache.py
CHANGED
@@ -535,6 +535,13 @@ class Cache(Base):
         """
         return html

+    def subset(self, keys: list[str]) -> Cache:
+        """
+        Return a subset of the Cache with the specified keys.
+        """
+        new_data = {k: v for k, v in self.data.items() if k in keys}
+        return Cache(data=new_data)
+
     def view(self) -> None:
         """View the Cache in a new browser tab."""
         import tempfile
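A small usage sketch for the new subset method; Cache.example() is assumed to exist here, as it does for other edsl Base subclasses:

    from edsl.data.Cache import Cache

    cache = Cache.example()
    wanted = list(cache.data.keys())[:1]

    smaller = cache.subset(wanted)  # a new Cache holding only the requested keys
    assert set(smaller.data.keys()) <= set(wanted)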
edsl/data/RemoteCacheSync.py
CHANGED
@@ -112,18 +112,18 @@ class RemoteCacheSync(AbstractContextManager):
         missing_count = len(diff.client_missing_entries)

         if missing_count == 0:
-            self._output("No new entries to add to local cache.")
+            # self._output("No new entries to add to local cache.")
             return

-        self._output(
-            f"Updating local cache with {missing_count:,} new "
-            f"{'entry' if missing_count == 1 else 'entries'} from remote..."
-        )
+        # self._output(
+        #     f"Updating local cache with {missing_count:,} new "
+        #     f"{'entry' if missing_count == 1 else 'entries'} from remote..."
+        # )

         self.cache.add_from_dict(
             {entry.key: entry for entry in diff.client_missing_entries}
         )
-        self._output("Local cache updated!")
+        # self._output("Local cache updated!")

     def _get_entries_to_upload(self, diff: CacheDifference) -> CacheEntriesList:
         """Determines which entries need to be uploaded to remote cache."""
@@ -154,23 +154,23 @@ class RemoteCacheSync(AbstractContextManager):
         upload_count = len(entries_to_upload)

         if upload_count > 0:
-            self._output(
-                f"Updating remote cache with {upload_count:,} new "
-                f"{'entry' if upload_count == 1 else 'entries'}..."
-            )
+            # self._output(
+            #     f"Updating remote cache with {upload_count:,} new "
+            #     f"{'entry' if upload_count == 1 else 'entries'}..."
+            # )

             self.coop.remote_cache_create_many(
                 entries_to_upload,
                 visibility="private",
                 description=self.remote_cache_description,
             )
-            self._output("Remote cache updated!")
-        else:
-            self._output("No new entries to add to remote cache.")
+            # self._output("Remote cache updated!")
+        # else:
+        #     self._output("No new entries to add to remote cache.")

-        self._output(
-            f"There are {len(self.cache.keys()):,} entries in the local cache."
-        )
+        # self._output(
+        #     f"There are {len(self.cache.keys()):,} entries in the local cache."
+        # )


 if __name__ == "__main__":
edsl/enums.py
CHANGED
@@ -66,6 +66,7 @@ class InferenceServiceType(EnumWithChecks):
     MISTRAL = "mistral"
     TOGETHER = "together"
     PERPLEXITY = "perplexity"
+    DEEPSEEK = "deepseek"


 # unavoidable violation of the DRY principle but it is necessary
@@ -84,6 +85,7 @@ InferenceServiceLiteral = Literal[
     "mistral",
     "together",
     "perplexity",
+    "deepseek",
 ]

 available_models_urls = {
@@ -107,6 +109,7 @@ service_to_api_keyname = {
     InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
     InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
     InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
+    InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY",
 }
edsl/exceptions/jobs.py
CHANGED
@@ -10,15 +10,7 @@ class JobsRunError(JobsErrors):


 class MissingRemoteInferenceError(JobsErrors):
-    def __init__(self):
-        message = dedent(
-            """\
-            You are trying to run the job remotely, but you have not set the EXPECTED_PARROT_INFERENCE_URL environment variable.
-            This remote running service is not quite ready yet!
-            But please see https://docs.expectedparrot.com/en/latest/coop.html for what we are working on.
-            """
-        )
-        super().__init__(message)
+    pass


 class InterviewError(Exception):
edsl/exceptions/language_models.py
CHANGED
@@ -34,11 +34,15 @@ class LanguageModelNotFound(LanguageModelExceptions):
         msg = dedent(
             f"""\
             Model {model_name} not found.
-            To create an instance,
-
+            To create an instance of this model, pass the model name to a `Model` object.
+            You can optionally pass additional parameters to the model, e.g.:
+            >>> m = Model('gpt-4-1106-preview', temperature=0.5)

-            To
-            To
+            To use the default model, simply run your job without specifying a model.
+            To check the default model, run the following code:
+            >>> Model()
+
+            To see information about all available models, run the following code:
             >>> Model.available()

             See https://docs.expectedparrot.com/en/latest/language_models.html#available-models for more details.
edsl/exceptions/questions.py
CHANGED
@@ -16,7 +16,8 @@ class QuestionErrors(Exception):
 class QuestionAnswerValidationError(QuestionErrors):
     documentation = "https://docs.expectedparrot.com/en/latest/exceptions.html"

-    explanation = """
+    explanation = """
+    This can occur when the answer coming from the Language Model does not conform to the expectations for the question type.
     For example, if the question is a multiple choice question, the answer should be drawn from the list of options provided.
     """
@@ -52,28 +53,24 @@ class QuestionAnswerValidationError(QuestionErrors):

     def to_html_dict(self):
         return {
-            "
-            "
-            "
-            "What model returned",
+            "Exception type": ("p", "/p", self.__class__.__name__),
+            "Explanation": ("p", "/p", self.explanation),
+            "EDSL response": (
                 "pre",
                 "/pre",
                 json.dumps(self.data, indent=2),
             ),
-            "
-            "Pydantic model for answers",
+            "Validating model": (
                 "pre",
                 "/pre",
                 json.dumps(self.model.model_json_schema(), indent=2),
            ),
-            "
-            "Error message Pydantic returned",
+            "Error message": (
                 "p",
                 "/p",
                 self.message,
             ),
-            "
-            "URL to EDSL docs",
+            "Documentation": (
                 f"a href='{self.documentation}'",
                 "/a",
                 self.documentation,
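The rewritten to_html_dict maps a human-readable label to an (open_tag, close_tag, content) triple. A sketch of how such a dict could be rendered into an HTML fragment (the renderer is illustrative, not part of edsl):

    def render_html_dict(html_dict: dict) -> str:
        # each value is an (open_tag, close_tag, content) triple, as above
        rows = [
            f"<h3>{label}</h3><{open_tag}>{content}<{close_tag}>"
            for label, (open_tag, close_tag, content) in html_dict.items()
        ]
        return "\n".join(rows)

    example = {"Error message": ("p", "/p", "Answer was not in the options list.")}
    print(render_html_dict(example))
    # <h3>Error message</h3><p>Answer was not in the options list.</p>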
edsl/inference_services/DeepSeekService.py
ADDED
@@ -0,0 +1,18 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class DeepSeekService(OpenAIService):
+    """DeepInfra service class."""
+
+    _inference_service_ = "deepseek"
+    _env_key_name_ = "DEEPSEEK_API_KEY"
+    _base_url_ = "https://api.deepseek.com"
+    _models_list_cache: List[str] = []
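DeepSeek exposes an OpenAI-compatible API, which is why the service can subclass OpenAIService and only override the endpoint and key name (the docstring still reads "DeepInfra", an apparent copy-paste slip upstream). With DEEPSEEK_API_KEY set, selection should follow the usual edsl pattern; the model name below is DeepSeek's public chat model and is an assumption, not taken from this diff:

    import os
    from edsl import Model

    os.environ["DEEPSEEK_API_KEY"] = "sk-..."  # normally set in your shell or .env

    m = Model("deepseek-chat")  # resolved through the new "deepseek" service
    # Model.available() shows what the registry actually exposes per service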
edsl/inference_services/registry.py
CHANGED
@@ -13,6 +13,7 @@ from edsl.inference_services.OllamaService import OllamaService
 from edsl.inference_services.TestService import TestService
 from edsl.inference_services.TogetherAIService import TogetherAIService
 from edsl.inference_services.PerplexityService import PerplexityService
+from edsl.inference_services.DeepSeekService import DeepSeekService

 try:
     from edsl.inference_services.MistralAIService import MistralAIService
@@ -33,6 +34,7 @@ services = [
     TestService,
     TogetherAIService,
     PerplexityService,
+    DeepSeekService,
 ]

 if mistral_available:
edsl/jobs/Jobs.py
CHANGED
@@ -499,7 +499,6 @@ class Jobs(Base):
         jc.check_api_keys()

     async def _execute_with_remote_cache(self, run_job_async: bool) -> Results:
-
         use_remote_cache = self.use_remote_cache()

         from edsl.coop.coop import Coop
@@ -508,43 +507,48 @@ class Jobs(Base):

         assert isinstance(self.run_config.environment.cache, Cache)

-        with RemoteCacheSync(
-            coop=Coop(),
-            cache=self.run_config.environment.cache,
-            output_func=self._output,
-            remote_cache=use_remote_cache,
-            remote_cache_description=self.run_config.parameters.remote_cache_description,
-        ):
-            runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
-            if run_job_async:
-                results = await runner.run_async(self.run_config.parameters)
-            else:
-                results = runner.run(self.run_config.parameters)
+        # with RemoteCacheSync(
+        #     coop=Coop(),
+        #     cache=self.run_config.environment.cache,
+        #     output_func=self._output,
+        #     remote_cache=use_remote_cache,
+        #     remote_cache_description=self.run_config.parameters.remote_cache_description,
+        # ):
+        runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
+        if run_job_async:
+            results = await runner.run_async(self.run_config.parameters)
+        else:
+            results = runner.run(self.run_config.parameters)
         return results

-    def _setup_and_check(self) -> Tuple[RunConfig, Optional[Results]]:
+    # def _setup_and_check(self) -> Tuple[RunConfig, Optional[Results]]:
+    #     self._prepare_to_run()
+    #     self._check_if_remote_keys_ok()

-        self._prepare_to_run()
-        self._check_if_remote_keys_ok()
-        # first try to run the job remotely
-        results = self._remote_results()
-        if results is not None:
-            return results
+    #     # first try to run the job remotely
+    #     results = self._remote_results()
+    #     #breakpoint()
+    #     if results is not None:
+    #         return results

-        self._check_if_local_keys_ok()
-        return None
+    #     self._check_if_local_keys_ok()
+    #     return None

     @property
     def num_interviews(self):
         if self.run_config.parameters.n is None:
             return len(self)
         else:
-            len(self) * self.run_config.parameters.n
+            return len(self) * self.run_config.parameters.n

-    def _run(self, config: RunConfig):
+    def _run(self, config: RunConfig) -> Union[None, "Results"]:
         "Shared code for run and run_async"
         if config.environment.cache is not None:
             self.run_config.environment.cache = config.environment.cache
+        if config.environment.jobs_runner_status is not None:
+            self.run_config.environment.jobs_runner_status = (
+                config.environment.jobs_runner_status
+            )

         if config.environment.bucket_collection is not None:
             self.run_config.environment.bucket_collection = (
@@ -579,7 +583,7 @@ class Jobs(Base):
         # first try to run the job remotely
         if results := self._remote_results():
             return results
-
+
         self._check_if_local_keys_ok()

         if config.environment.bucket_collection is None:
@@ -587,6 +591,8 @@ class Jobs(Base):
                 self.create_bucket_collection()
             )

+        return None
+
     @with_config
     def run(self, *, config: RunConfig) -> "Results":
         """
@@ -606,7 +612,10 @@ class Jobs(Base):
         :param bucket_collection: A BucketCollection object to track API calls
         :param key_lookup: A KeyLookup object to manage API keys
         """
-        self._run(config)
+        potentially_completed_results = self._run(config)
+
+        if potentially_completed_results is not None:
+            return potentially_completed_results

         return asyncio.run(self._execute_with_remote_cache(run_job_async=False))
@@ -646,20 +655,19 @@ class Jobs(Base):
         }

     def __len__(self) -> int:
-        """Return the
-
+        """Return the number of interviews that will be conducted for one iteration of this job.
+        An interview is the result of one survey, taken by one agent, with one model, with one scenario.

         >>> from edsl.jobs import Jobs
         >>> len(Jobs.example())
-
+        4
         """
-
+        number_of_interviews = (
             len(self.agents or [1])
             * len(self.scenarios or [1])
             * len(self.models or [1])
-            * len(self.survey)
         )
-        return
+        return number_of_interviews
@@ -810,9 +818,9 @@ def main():
     from edsl.data.Cache import Cache

     job = Jobs.example()
-    len(job) ==
+    len(job) == 4
     results = job.run(cache=Cache())
-    len(results) ==
+    len(results) == 4
     results
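The corrected __len__ doctest can be checked directly: per the new docstring the count is agents x scenarios x models, and survey length no longer enters the product (the example job evidently combines to four interviews):

    from edsl.jobs import Jobs

    job = Jobs.example()
    # one interview = one survey, taken by one agent, with one model, with one scenario
    assert len(job) == 4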
edsl/jobs/JobsPrompts.py
CHANGED
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
 # from edsl.surveys.Survey import Survey

 from edsl.jobs.FetchInvigilator import FetchInvigilator
+from edsl.data.CacheEntry import CacheEntry


 class JobsPrompts:
@@ -47,7 +48,7 @@ class JobsPrompts:
         agent_indices = []
         models = []
         costs = []
-
+        cache_keys = []
         for interview_index, interview in enumerate(interviews):
             invigilators = [
                 FetchInvigilator(interview)(question)
@@ -76,6 +77,14 @@ class JobsPrompts:
             )
             costs.append(prompt_cost["cost_usd"])

+            cache_key = CacheEntry.gen_key(
+                model=invigilator.model.model,
+                parameters=invigilator.model.parameters,
+                system_prompt=system_prompt,
+                user_prompt=user_prompt,
+                iteration=0,  # TODO how to handle when there are multiple iterations?
+            )
+            cache_keys.append(cache_key)
         d = Dataset(
             [
                 {"user_prompt": user_prompts},
@@ -86,6 +95,7 @@ class JobsPrompts:
             {"agent_index": agent_indices},
             {"model": models},
             {"estimated_cost": costs},
+            {"cache_key": cache_keys},
         ]
     )
     return d
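Because each prompt row now carries its cache_key, a caller can check up front which responses an existing cache would satisfy, pairing the new column with the Cache.subset method added in this release. A sketch, assuming Jobs.prompts() delegates to this builder; the column extraction pokes at the Dataset's list-of-column-dicts layout visible in the constructor call above, so adjust if a public accessor is available:

    from edsl.jobs.Jobs import Jobs
    from edsl.data.Cache import Cache

    prompts = Jobs.example().prompts()  # Dataset, now including a "cache_key" column
    cache_keys = next(col["cache_key"] for col in prompts.data if "cache_key" in col)

    hits = Cache().subset(cache_keys)  # an empty cache yields no hits
    print(f"{len(hits.data)} of {len(cache_keys)} prompts already cached")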
edsl/jobs/JobsRemoteInferenceHandler.py
CHANGED
@@ -219,6 +219,7 @@ class JobsRemoteInferenceHandler:
         job_info.logger.add_info("results_uuid", results_uuid)
         results = object_fetcher(results_uuid, expected_object_type="results")
         results_url = remote_job_data.get("results_url")
+        job_info.logger.add_info("results_url", results_url)
         job_info.logger.update(
             f"Job completed and Results stored on Coop: {results_url}",
             status=JobsStatus.COMPLETED,
edsl/jobs/JobsRemoteInferenceLogger.py
CHANGED
@@ -32,7 +32,7 @@ class JobsInfo:
     pretty_names = {
         "job_uuid": "Job UUID",
         "progress_bar_url": "Progress Bar URL",
-        "error_report_url": "
+        "error_report_url": "Exceptions Report URL",
         "results_uuid": "Results UUID",
         "results_url": "Results URL",
     }
edsl/jobs/interviews/Interview.py
CHANGED
@@ -153,7 +153,7 @@ class Interview:

     >>> i = Interview.example()
     >>> hash(i)
-
+    767745459362662063
     """
     d = {
         "agent": self.agent.to_dict(add_edsl_version=add_edsl_version),
@@ -213,10 +213,6 @@ class Interview:
     async def async_conduct_interview(
         self,
         run_config: Optional["RunConfig"] = None,
-        # model_buckets: Optional[ModelBuckets] = None,
-        # stop_on_exception: bool = False,
-        # raise_validation_errors: bool = True,
-        # key_lookup: Optional[KeyLookup] = None,
     ) -> tuple["Answers", List[dict[str, Any]]]:
         """
         Conduct an Interview asynchronously.
@@ -313,7 +309,7 @@ class Interview:

     def handle_task(task, invigilator):
         try:
-            result = task.result()
+            result: Answers = task.result()
         except asyncio.CancelledError as e:  # task was cancelled
             result = invigilator.get_failed_task_result(
                 failure_reason="Task was cancelled."