edsl 0.1.43__py3-none-any.whl → 0.1.45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +15 -6
- edsl/__version__.py +1 -1
- edsl/agents/InvigilatorBase.py +3 -1
- edsl/agents/PromptConstructor.py +62 -34
- edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
- edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
- edsl/agents/question_option_processor.py +15 -6
- edsl/coop/CoopFunctionsMixin.py +3 -4
- edsl/coop/coop.py +56 -10
- edsl/enums.py +4 -1
- edsl/inference_services/AnthropicService.py +12 -8
- edsl/inference_services/AvailableModelFetcher.py +2 -0
- edsl/inference_services/AwsBedrock.py +1 -2
- edsl/inference_services/AzureAI.py +12 -9
- edsl/inference_services/GoogleService.py +10 -3
- edsl/inference_services/InferenceServiceABC.py +1 -0
- edsl/inference_services/InferenceServicesCollection.py +2 -2
- edsl/inference_services/MistralAIService.py +1 -2
- edsl/inference_services/OpenAIService.py +10 -4
- edsl/inference_services/PerplexityService.py +2 -1
- edsl/inference_services/TestService.py +1 -0
- edsl/inference_services/XAIService.py +11 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +9 -0
- edsl/jobs/JobsChecks.py +11 -14
- edsl/jobs/JobsPrompts.py +3 -3
- edsl/jobs/async_interview_runner.py +3 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
- edsl/jobs/tasks/TaskHistory.py +1 -1
- edsl/language_models/LanguageModel.py +3 -3
- edsl/language_models/PriceManager.py +45 -5
- edsl/language_models/model.py +89 -36
- edsl/questions/QuestionBase.py +21 -0
- edsl/questions/QuestionBasePromptsMixin.py +103 -0
- edsl/questions/QuestionFreeText.py +22 -5
- edsl/questions/descriptors.py +4 -0
- edsl/questions/question_base_gen_mixin.py +94 -29
- edsl/results/Dataset.py +65 -0
- edsl/results/DatasetExportMixin.py +299 -32
- edsl/results/Result.py +27 -0
- edsl/results/Results.py +24 -3
- edsl/results/ResultsGGMixin.py +7 -3
- edsl/scenarios/DocumentChunker.py +2 -0
- edsl/scenarios/FileStore.py +29 -8
- edsl/scenarios/PdfExtractor.py +21 -1
- edsl/scenarios/Scenario.py +25 -9
- edsl/scenarios/ScenarioList.py +73 -3
- edsl/scenarios/handlers/__init__.py +1 -0
- edsl/scenarios/handlers/docx.py +5 -1
- edsl/scenarios/handlers/jpeg.py +39 -0
- edsl/surveys/Survey.py +28 -6
- edsl/surveys/SurveyFlowVisualization.py +91 -43
- edsl/templates/error_reporting/exceptions_table.html +7 -8
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/interviews.html +0 -1
- edsl/templates/error_reporting/overview.html +2 -7
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +1 -1
- edsl/utilities/PrettyList.py +14 -0
- edsl-0.1.45.dist-info/METADATA +246 -0
- {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/RECORD +64 -62
- edsl-0.1.43.dist-info/METADATA +0 -110
- {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
- {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
edsl/coop/coop.py
CHANGED
@@ -4,10 +4,8 @@ import requests
 
 from typing import Any, Optional, Union, Literal, TypedDict
 from uuid import UUID
-from collections import UserDict, defaultdict
 
 import edsl
-from pathlib import Path
 
 from edsl.config import CONFIG
 from edsl.data.CacheEntry import CacheEntry
@@ -192,7 +190,7 @@ class Coop(CoopFunctionsMixin):
             server_version_str=server_edsl_version,
         ):
             print(
-                "Please upgrade your EDSL version to access our latest features.
+                "Please upgrade your EDSL version to access our latest features. Open your terminal and run `pip install --upgrade edsl`"
             )
 
         if response.status_code >= 400:
@@ -214,7 +212,7 @@ class Coop(CoopFunctionsMixin):
             print("Your Expected Parrot API key is invalid.")
             self._display_login_url(
                 edsl_auth_token=edsl_auth_token,
-                link_description="\n🔗 Use the link below to log in to
+                link_description="\n🔗 Use the link below to log in to your account and automatically update your API key.",
             )
             api_key = self._poll_for_api_key(edsl_auth_token)
@@ -340,7 +338,7 @@ class Coop(CoopFunctionsMixin):
 
         try:
             response = self._send_server_request(
-                uri="api/v0/edsl-settings", method="GET", timeout=
+                uri="api/v0/edsl-settings", method="GET", timeout=20
             )
             self._resolve_server_response(response, check_api_key=False)
             return response.json()
@@ -866,6 +864,41 @@ class Coop(CoopFunctionsMixin):
             "usd": response_json.get("cost_in_usd"),
         }
 
+    ################
+    # PROJECTS
+    ################
+    def create_project(
+        self,
+        survey: Survey,
+        project_name: str = "Project",
+        survey_description: Optional[str] = None,
+        survey_alias: Optional[str] = None,
+        survey_visibility: Optional[VisibilityType] = "unlisted",
+    ):
+        """
+        Create a survey object on Coop, then create a project from the survey.
+        """
+        survey_details = self.create(
+            object=survey,
+            description=survey_description,
+            alias=survey_alias,
+            visibility=survey_visibility,
+        )
+        survey_uuid = survey_details.get("uuid")
+        response = self._send_server_request(
+            uri=f"api/v0/projects/create-from-survey",
+            method="POST",
+            payload={"project_name": project_name, "survey_uuid": str(survey_uuid)},
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        return {
+            "name": response_json.get("project_name"),
+            "uuid": response_json.get("uuid"),
+            "admin_url": f"{self.url}/home/projects/{response_json.get('uuid')}",
+            "respondent_url": f"{self.url}/respond/{response_json.get('uuid')}",
+        }
+
     ################
     # DUNDER METHODS
     ################
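The new `create_project` helper first uploads the survey via `self.create`, then calls the project-creation endpoint and returns the admin and respondent URLs. A minimal usage sketch (the survey contents are illustrative; import paths follow the file layout in this diff):

```python
from edsl.coop.coop import Coop
from edsl.surveys.Survey import Survey
from edsl.questions.QuestionFreeText import QuestionFreeText

survey = Survey(questions=[
    QuestionFreeText(question_name="feedback", question_text="Any feedback for us?")
])

coop = Coop()  # picks up your Expected Parrot API key from the environment
project = coop.create_project(survey, project_name="Feedback Pilot")
print(project["admin_url"])       # {url}/home/projects/<uuid>
print(project["respondent_url"])  # {url}/respond/<uuid>
```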
@@ -995,15 +1028,28 @@ class Coop(CoopFunctionsMixin):
         - We need this function because URL detection with print() does not work alongside animations in VSCode.
         """
         from rich import print as rich_print
+        from rich.console import Console
+
+        console = Console()
 
         url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
 
-        if
-
-
-
+        if console.is_terminal:
+            # Running in a standard terminal, show the full URL
+            if link_description:
+                rich_print("{link_description}\n[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
+            else:
+                rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
         else:
-
+            # Running in an interactive environment (e.g., Jupyter Notebook), hide the URL
+            if link_description:
+                rich_print(f"{link_description}\n[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]")
+            else:
+                rich_print(f"[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]")
+
+
+
 
     def _get_api_key(self, edsl_auth_token: str):
         """
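The rewritten `_display_login_url` branches on Rich's `Console.is_terminal`, which is `True` for a real TTY and `False` in notebook-style frontends. A standalone sketch of the same pattern (URL illustrative):

```python
from rich.console import Console
from rich import print as rich_print

console = Console()
url = "https://www.expectedparrot.com/login?edsl_auth_token=TOKEN"

if console.is_terminal:
    # Terminal: print the raw URL so the emulator's link detection can grab it
    rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
else:
    # Notebook/IDE: hide the raw URL behind clickable link text
    rich_print(f"[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]")
```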
edsl/enums.py
CHANGED
@@ -67,6 +67,7 @@ class InferenceServiceType(EnumWithChecks):
     TOGETHER = "together"
     PERPLEXITY = "perplexity"
     DEEPSEEK = "deepseek"
+    XAI = "xai"
 
 
 # unavoidable violation of the DRY principle but it is necessary
@@ -86,6 +87,7 @@ InferenceServiceLiteral = Literal[
     "together",
     "perplexity",
     "deepseek",
+    "xai",
 ]
 
 available_models_urls = {
@@ -108,7 +110,8 @@ service_to_api_keyname = {
     InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
     InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
     InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
-    InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY"
+    InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY",
+    InferenceServiceType.XAI.value: "XAI_API_KEY",
 }
 
 
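These three edits register the new service end to end: the enum member, the literal type, and the API-key mapping. A quick check of the mapping:

```python
from edsl.enums import InferenceServiceType, service_to_api_keyname

keyname = service_to_api_keyname[InferenceServiceType.XAI.value]
print(keyname)  # XAI_API_KEY — the .env variable edsl looks up for xAI calls
```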
edsl/inference_services/AnthropicService.py
CHANGED
@@ -17,9 +17,10 @@ class AnthropicService(InferenceServiceABC):
     output_token_name = "output_tokens"
     model_exclude_list = []
 
+    available_models_url = "https://docs.anthropic.com/en/docs/about-claude/models"
+
     @classmethod
     def get_model_list(cls, api_key: str = None):
-
         import requests
 
         if api_key is None:
@@ -92,13 +93,16 @@ class AnthropicService(InferenceServiceABC):
             # breakpoint()
             client = AsyncAnthropic(api_key=self.api_token)
 
-            response = await client.messages.create(
-                model=model_name,
-                max_tokens=self.max_tokens,
-                temperature=self.temperature,
-                system=system_prompt,
-                messages=messages,
-            )
+            try:
+                response = await client.messages.create(
+                    model=model_name,
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
+                    system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
+                    messages=messages,
+                )
+            except Exception as e:
+                return {"message": str(e)}
             return response.model_dump()
 
         LLM.__name__ = model_class_name
edsl/inference_services/AvailableModelFetcher.py
CHANGED
@@ -69,6 +69,8 @@ class AvailableModelFetcher:
 
         Returns a list of [model, service_name, index] entries.
         """
+        if service == "azure":
+            force_refresh = True  # Azure models are listed inside the .env AZURE_ENDPOINT_URL_AND_KEY variable
 
         if service:  # they passed a specific service
             matching_models, _ = self.get_available_models_by_service(
edsl/inference_services/AzureAI.py
CHANGED
@@ -179,15 +179,18 @@ class AzureAIService(InferenceServiceABC):
                 api_version=api_version,
                 api_key=api_key,
             )
-            response = await client.chat.completions.create(
-                model=model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": user_prompt,  # Your question can go here
-                    },
-                ],
-            )
+            try:
+                response = await client.chat.completions.create(
+                    model=model_name,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": user_prompt,  # Your question can go here
+                        },
+                    ],
+                )
+            except Exception as e:
+                return {"message": str(e)}
             return response.model_dump()
 
         # @staticmethod
edsl/inference_services/GoogleService.py
CHANGED
@@ -39,6 +39,10 @@ class GoogleService(InferenceServiceABC):
 
     model_exclude_list = []
 
+    available_models_url = (
+        "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
+    )
+
     @classmethod
     def get_model_list(cls):
         model_list = []
@@ -130,9 +134,12 @@ class GoogleService(InferenceServiceABC):
             )
             combined_prompt.append(gen_ai_file)
 
-        response = await self.generative_model.generate_content_async(
-            combined_prompt, generation_config=generation_config
-        )
+        try:
+            response = await self.generative_model.generate_content_async(
+                combined_prompt, generation_config=generation_config
+            )
+        except Exception as e:
+            return {"message": str(e)}
         return response.to_dict()
 
         LLM.__name__ = model_name
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -104,8 +104,9 @@ class InferenceServicesCollection:
     def available(
         self,
         service: Optional[str] = None,
+        force_refresh: bool = False,
     ) -> List[Tuple[str, str, int]]:
-        return self.availability_fetcher.available(service)
+        return self.availability_fetcher.available(service, force_refresh=force_refresh)
 
     def reset_cache(self) -> None:
         self.availability_fetcher.reset_cache()
@@ -120,7 +121,6 @@ class InferenceServicesCollection:
     def create_model_factory(
         self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
     ) -> "LanguageModel":
-
         if service_name is None:  # we try to find the right service
             service = self.resolver.resolve_model(model_name, service_name)
         else:  # if they passed a service, we'll use that
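The `force_refresh` flag threads through `available()` to the fetcher so callers can skip the cached model list; for `azure` the fetcher now forces it regardless (see the AvailableModelFetcher hunk above). A sketch, assuming the module-level collection that `registry.py` builds from `services` is named `default`:

```python
from edsl.inference_services.registry import default  # assumed name for the collection

# Re-read the model list instead of serving the cache. For "azure" this now
# happens anyway, since models come from AZURE_ENDPOINT_URL_AND_KEY in .env.
models = default.available(service="azure", force_refresh=True)
for model_name, service_name, index in models:  # [model, service_name, index] entries
    print(model_name, service_name)
```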
edsl/inference_services/MistralAIService.py
CHANGED
@@ -111,8 +111,7 @@ class MistralAIService(InferenceServiceABC):
                 ],
             )
         except Exception as e:
-
-
+            return {"message": str(e)}
         return res.model_dump()
 
         LLM.__name__ = model_class_name
edsl/inference_services/OpenAIService.py
CHANGED
@@ -84,6 +84,7 @@ class OpenAIService(InferenceServiceABC):
 
     @classmethod
     def get_model_list(cls, api_key=None):
+        # breakpoint()
         if api_key is None:
             api_key = os.getenv(cls._env_key_name_)
         raw_list = cls.sync_client(api_key).models.list()
@@ -206,8 +207,10 @@ class OpenAIService(InferenceServiceABC):
             {"role": "user", "content": content},
         ]
         if (
-            system_prompt == "" and self.omit_system_prompt_if_empty
-        ):
+            (system_prompt == "" and self.omit_system_prompt_if_empty)
+            or "o1" in self.model
+            or "o3" in self.model
+        ):
             messages = messages[1:]
 
         params = {
@@ -221,14 +224,17 @@ class OpenAIService(InferenceServiceABC):
             "logprobs": self.logprobs,
             "top_logprobs": self.top_logprobs if self.logprobs else None,
         }
-        if "o1" in self.model:
+        if "o1" in self.model or "o3" in self.model:
             params.pop("max_tokens")
             params["max_completion_tokens"] = self.max_tokens
             params["temperature"] = 1
         try:
             response = await client.chat.completions.create(**params)
         except Exception as e:
-
+            #breakpoint()
+            #print(e)
+            #raise e
+            return {'message': str(e)}
         return response.model_dump()
 
         LLM.__name__ = "LanguageModel"
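Taken together, the two OpenAIService hunks give reasoning-model requests special handling: the system message is dropped, `max_tokens` is renamed to `max_completion_tokens`, and temperature is pinned to 1. A condensed restatement of that logic (an illustrative helper, not the actual method):

```python
def adapt_for_reasoning_models(model: str, params: dict, messages: list):
    """Mirror the o1/o3 handling added in this release (illustrative sketch)."""
    if "o1" in model or "o3" in model:
        messages = messages[1:]                                     # no system prompt
        params["max_completion_tokens"] = params.pop("max_tokens")  # renamed parameter
        params["temperature"] = 1                                   # only supported value
    return params, messages
```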
edsl/inference_services/PerplexityService.py
CHANGED
@@ -152,7 +152,8 @@ class PerplexityService(OpenAIService):
         try:
             response = await client.chat.completions.create(**params)
         except Exception as e:
-
+            return {"message": str(e)}
+
         return response.model_dump()
 
         LLM.__name__ = "LanguageModel"
edsl/inference_services/XAIService.py
ADDED
@@ -0,0 +1,11 @@
+from typing import Any, List
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class XAIService(OpenAIService):
+    """Openai service class."""
+
+    _inference_service_ = "xai"
+    _env_key_name_ = "XAI_API_KEY"
+    _base_url_ = "https://api.x.ai/v1"
+    _models_list_cache: List[str] = []
edsl/inference_services/registry.py
CHANGED
@@ -14,6 +14,7 @@ from edsl.inference_services.TestService import TestService
 from edsl.inference_services.TogetherAIService import TogetherAIService
 from edsl.inference_services.PerplexityService import PerplexityService
 from edsl.inference_services.DeepSeekService import DeepSeekService
+from edsl.inference_services.XAIService import XAIService
 
 try:
     from edsl.inference_services.MistralAIService import MistralAIService
@@ -35,6 +36,7 @@ services = [
     TogetherAIService,
     PerplexityService,
     DeepSeekService,
+    XAIService,
 ]
 
 if mistral_available:
edsl/jobs/Jobs.py
CHANGED
@@ -364,6 +364,15 @@ class Jobs(Base):
             self, cache=self.run_config.environment.cache
         ).create_interviews()
 
+    def show_flow(self, filename: Optional[str] = None) -> None:
+        """Show the flow of the survey."""
+        from edsl.surveys.SurveyFlowVisualization import SurveyFlowVisualization
+        if self.scenarios:
+            scenario = self.scenarios[0]
+        else:
+            scenario = None
+        SurveyFlowVisualization(self.survey, scenario=scenario, agent=None).show_flow(filename=filename)
+
     def interviews(self) -> list[Interview]:
         """
         Return a list of :class:`edsl.jobs.interviews.Interview` objects.
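`Jobs.show_flow` mirrors the survey-level flow diagram but passes the job's first scenario so piped values render in the visualization. A usage sketch (survey and scenario contents illustrative):

```python
from edsl.surveys.Survey import Survey
from edsl.questions.QuestionFreeText import QuestionFreeText
from edsl.scenarios.Scenario import Scenario

q = QuestionFreeText(question_name="opinion", question_text="What do you think about {{ topic }}?")
job = Survey(questions=[q]).by(Scenario({"topic": "remote work"}))

job.show_flow()                         # render inline
job.show_flow(filename="job_flow.png")  # or write to disk
```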
edsl/jobs/JobsChecks.py
CHANGED
@@ -24,14 +24,14 @@ class JobsChecks:
 
     def get_missing_api_keys(self) -> set:
         """
-        Returns a list of the
+        Returns a list of the API keys that a user needs to run this job, but does not currently have in their .env file.
        """
         missing_api_keys = set()
 
         from edsl.language_models.model import Model
         from edsl.enums import service_to_api_keyname
 
-        for model in self.jobs.models + [Model()]:
+        for model in self.jobs.models:  # + [Model()]:
             if not model.has_valid_api_key():
                 key_name = service_to_api_keyname.get(
                     model._inference_service_, "NOT FOUND"
@@ -134,22 +134,20 @@ class JobsChecks:
 
         edsl_auth_token = secrets.token_urlsafe(16)
 
-        print("
+        print("\nThe following keys are needed to run this survey: \n")
         for api_key in missing_api_keys:
-            print(f"
+            print(f"🔑 {api_key}")
         print(
-            "
+            """
+            \nYou can provide your own keys for language models or use an Expected Parrot key to access all available models.
+            \nClick the link below to create an account and run your survey with your Expected Parrot key:
+            """
         )
-
-
+
         coop = Coop()
         coop._display_login_url(
             edsl_auth_token=edsl_auth_token,
-            link_description="
-        )
-
-        print(
-            "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
+            # link_description="",
         )
 
         api_key = coop._poll_for_api_key(edsl_auth_token)
@@ -159,8 +157,7 @@ class JobsChecks:
             return
 
         path_to_env = write_api_key_to_env(api_key)
-        print("\n✨
-        print(f"    {path_to_env}")
+        print(f"\n✨ Your Expected Parrot key has been stored at the following path: {path_to_env}\n")
 
         # Retrieve API key so we can continue running the job
         load_dotenv()
edsl/jobs/JobsPrompts.py
CHANGED
@@ -200,10 +200,10 @@ class JobsPrompts:
             import warnings
 
             warnings.warn(
-                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $
+                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $1.00 / 1M tokens; Output: $1.00 / 1M tokens"
             )
-            input_price_per_token = 0.
-            output_price_per_token = 0.
+            input_price_per_token = 0.000001  # $1.00 / 1M tokens
+            output_price_per_token = 0.000001  # $1.00 / 1M tokens
 
         # Compute the number of characters (double if the question involves piping)
         user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
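The corrected defaults are internally consistent: $1.00 per 1M tokens is 1e-6 dollars per token, so the warning text and the constants now agree. The arithmetic:

```python
price_per_million_usd = 1.00
price_per_token = price_per_million_usd / 1_000_000  # 0.000001, the new default
estimated_cost = 2_500 * price_per_token             # a 2,500-token prompt ≈ $0.0025
```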
edsl/jobs/async_interview_runner.py
CHANGED
@@ -7,6 +7,8 @@ from edsl.data_transfer_models import EDSLResultObjectInput
 
 from edsl.results.Result import Result
 from edsl.jobs.interviews.Interview import Interview
+from edsl.config import Config
+config = Config()
 
 if TYPE_CHECKING:
     from edsl.jobs.Jobs import Jobs
@@ -23,7 +25,7 @@ from edsl.jobs.data_structures import RunConfig
 
 
 class AsyncInterviewRunner:
-    MAX_CONCURRENT =
+    MAX_CONCURRENT = int(config.EDSL_MAX_CONCURRENT_TASKS)
 
     def __init__(self, jobs: "Jobs", run_config: RunConfig):
         self.jobs = jobs
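The concurrency ceiling is now configuration-driven rather than hard-coded. A sketch, assuming the usual edsl pattern where `Config` values can be supplied through the environment or a `.env` file:

```python
import os
os.environ["EDSL_MAX_CONCURRENT_TASKS"] = "25"  # must be set before Config() reads it

from edsl.config import Config

config = Config()
print(int(config.EDSL_MAX_CONCURRENT_TASKS))  # 25 concurrent interview tasks
```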
edsl/jobs/check_survey_scenario_compatibility.py
CHANGED
@@ -72,11 +72,11 @@ class CheckSurveyScenarioCompatibility:
         if warn:
             warnings.warn(message)
 
-        if self.scenarios.has_jinja_braces:
-            warnings.warn(
-                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
-            )
-            self.scenarios = self.scenarios._convert_jinja_braces()
+        # if self.scenarios.has_jinja_braces:
+        #     warnings.warn(
+        #         "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
+        #     )
+        #     self.scenarios = self.scenarios._convert_jinja_braces()
 
 
 if __name__ == "__main__":
edsl/jobs/interviews/InterviewExceptionEntry.py
CHANGED
@@ -166,6 +166,9 @@ class InterviewExceptionEntry:
         >>> entry = InterviewExceptionEntry.example()
         >>> _ = entry.to_dict()
         """
+        import json
+        from edsl.exceptions.questions import QuestionAnswerValidationError
+
         invigilator = (
             self.invigilator.to_dict() if self.invigilator is not None else None
         )
@@ -174,7 +177,16 @@ class InterviewExceptionEntry:
             "time": self.time,
             "traceback": self.traceback,
             "invigilator": invigilator,
+            "additional_data": {},
         }
+
+        if isinstance(self.exception, QuestionAnswerValidationError):
+            d["additional_data"]["edsl_response"] = json.dumps(self.exception.data)
+            d["additional_data"]["validating_model"] = json.dumps(
+                self.exception.model.model_json_schema()
+            )
+            d["additional_data"]["error_message"] = str(self.exception.message)
+
         return d
 
     @classmethod
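With these hunks, every serialized exception entry carries an `additional_data` dict, populated only for validation failures. A sketch of the resulting shape (keys other than those shown in the diff are assumptions):

```python
entry = InterviewExceptionEntry.example()
d = entry.to_dict()

d["additional_data"]  # {} for most exceptions
# For a QuestionAnswerValidationError it would instead hold:
# {
#     "edsl_response": "...json-encoded raw model response...",
#     "validating_model": "...json-encoded pydantic schema...",
#     "error_message": "...",
# }
```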
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -419,7 +419,7 @@ class TaskHistory(RepresentationMixin):
         filename: Optional[str] = None,
         return_link=False,
         css=None,
-        cta="
+        cta="<br><span style='font-size: 18px; font-weight: medium-bold; text-decoration: underline;'>Click to open the report in a new tab</span><br><br>",
         open_in_browser=False,
     ):
         """Return an HTML report."""
edsl/language_models/LanguageModel.py
CHANGED
@@ -394,7 +394,6 @@ class LanguageModel(
         from edsl.config import CONFIG
 
         TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
         new_cache_key = cache.store(
             **cache_call_params, response=response
@@ -518,7 +517,9 @@ class LanguageModel(
         """
         from edsl.language_models.model import get_model_class
 
-        model_class = get_model_class(
+        model_class = get_model_class(
+            data["model"], service_name=data.get("inference_service", None)
+        )
         return model_class(**data)
 
     def __repr__(self) -> str:
@@ -574,7 +575,6 @@ class LanguageModel(
         return Model(skip_api_key_check=True)
 
     def from_cache(self, cache: "Cache") -> LanguageModel:
-
        from copy import deepcopy
        from types import MethodType
        from edsl import Cache
edsl/language_models/PriceManager.py
CHANGED
@@ -30,19 +30,22 @@ class PriceManager:
         except Exception as e:
             print(f"Error fetching prices: {str(e)}")
 
-    def get_price(self, inference_service: str, model: str) ->
+    def get_price(self, inference_service: str, model: str) -> Dict:
         """
         Get the price information for a specific service and model combination.
+        If no specific price is found, returns a fallback price.
 
         Args:
             inference_service (str): The name of the inference service
             model (str): The model identifier
 
         Returns:
-
+            Dict: Price information (either actual or fallback prices)
         """
         key = (inference_service, model)
-        return self._price_lookup.get(key)
+        return self._price_lookup.get(key) or self._get_fallback_price(
+            inference_service
+        )
 
     def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
         """
@@ -53,6 +56,45 @@ class PriceManager:
         """
         return self._price_lookup.copy()
 
+    def _get_fallback_price(self, inference_service: str) -> Dict:
+        """
+        Get fallback prices for a service.
+        - First fallback: The highest input and output prices for that service from the price lookup.
+        - Second fallback: $1.00 per million tokens (for both input and output).
+
+        Args:
+            inference_service (str): The inference service name
+
+        Returns:
+            Dict: Price information
+        """
+        service_prices = [
+            prices
+            for (service, _), prices in self._price_lookup.items()
+            if service == inference_service
+        ]
+
+        input_tokens_per_usd = [
+            float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
+        ]
+        if input_tokens_per_usd:
+            min_input_tokens = min(input_tokens_per_usd)
+        else:
+            min_input_tokens = 1_000_000
+
+        output_tokens_per_usd = [
+            float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
+        ]
+        if output_tokens_per_usd:
+            min_output_tokens = min(output_tokens_per_usd)
+        else:
+            min_output_tokens = 1_000_000
+
+        return {
+            "input": {"one_usd_buys": min_input_tokens},
+            "output": {"one_usd_buys": min_output_tokens},
+        }
+
     def calculate_cost(
         self,
         inference_service: str,
@@ -75,8 +117,6 @@ class PriceManager:
             Union[float, str]: Total cost if calculation successful, error message string if not
         """
         relevant_prices = self.get_price(inference_service, model)
-        if relevant_prices is None:
-            return f"Could not find price for model {model} in the price lookup."
 
         # Extract token counts
         try:
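Note the inversion in the fallback: prices are stored as `one_usd_buys` (tokens per dollar), so taking the *minimum* token count selects the *most expensive* known price for the service, and a service with no entries falls through to 1,000,000 tokens per dollar, the same $1.00 / 1M default used in JobsPrompts. A worked example with made-up numbers:

```python
price_lookup = {
    ("openai", "model-a"): {"input": {"one_usd_buys": 4_000_000},
                            "output": {"one_usd_buys": 1_000_000}},
    ("openai", "model-b"): {"input": {"one_usd_buys": 500_000},
                            "output": {"one_usd_buys": 250_000}},
}

# Fallback for an unknown "openai" model:
#   input:  min(4_000_000, 500_000) = 500_000 tokens/USD  -> $2.00 / 1M tokens
#   output: min(1_000_000, 250_000) = 250_000 tokens/USD  -> $4.00 / 1M tokens
# For a service with no entries at all, both sides default to 1_000_000
# tokens/USD, i.e. $1.00 / 1M tokens.
```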