edsl 0.1.44__py3-none-any.whl → 0.1.45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/InvigilatorBase.py +3 -1
- edsl/agents/PromptConstructor.py +62 -34
- edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
- edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
- edsl/agents/question_option_processor.py +15 -6
- edsl/coop/CoopFunctionsMixin.py +3 -4
- edsl/coop/coop.py +23 -9
- edsl/enums.py +3 -3
- edsl/inference_services/AnthropicService.py +11 -9
- edsl/inference_services/AvailableModelFetcher.py +2 -0
- edsl/inference_services/AwsBedrock.py +1 -2
- edsl/inference_services/AzureAI.py +12 -9
- edsl/inference_services/GoogleService.py +9 -4
- edsl/inference_services/InferenceServicesCollection.py +2 -2
- edsl/inference_services/MistralAIService.py +1 -2
- edsl/inference_services/OpenAIService.py +9 -4
- edsl/inference_services/PerplexityService.py +2 -1
- edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
- edsl/inference_services/registry.py +2 -2
- edsl/jobs/Jobs.py +9 -0
- edsl/jobs/JobsChecks.py +10 -13
- edsl/jobs/async_interview_runner.py +3 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
- edsl/jobs/tasks/TaskHistory.py +1 -1
- edsl/language_models/LanguageModel.py +0 -3
- edsl/language_models/PriceManager.py +45 -5
- edsl/language_models/model.py +47 -26
- edsl/questions/QuestionBase.py +21 -0
- edsl/questions/QuestionBasePromptsMixin.py +103 -0
- edsl/questions/QuestionFreeText.py +22 -5
- edsl/questions/descriptors.py +4 -0
- edsl/questions/question_base_gen_mixin.py +94 -29
- edsl/results/Dataset.py +65 -0
- edsl/results/DatasetExportMixin.py +299 -32
- edsl/results/Result.py +27 -0
- edsl/results/Results.py +22 -2
- edsl/results/ResultsGGMixin.py +7 -3
- edsl/scenarios/DocumentChunker.py +2 -0
- edsl/scenarios/FileStore.py +10 -0
- edsl/scenarios/PdfExtractor.py +21 -1
- edsl/scenarios/Scenario.py +25 -9
- edsl/scenarios/ScenarioList.py +73 -3
- edsl/scenarios/handlers/__init__.py +1 -0
- edsl/scenarios/handlers/docx.py +5 -1
- edsl/scenarios/handlers/jpeg.py +39 -0
- edsl/surveys/Survey.py +5 -4
- edsl/surveys/SurveyFlowVisualization.py +91 -43
- edsl/templates/error_reporting/exceptions_table.html +7 -8
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/interviews.html +0 -1
- edsl/templates/error_reporting/overview.html +2 -7
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +1 -1
- edsl/utilities/PrettyList.py +14 -0
- edsl-0.1.45.dist-info/METADATA +246 -0
- {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/RECORD +60 -59
- edsl-0.1.44.dist-info/METADATA +0 -110
- {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
- {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
edsl/coop/coop.py
CHANGED
@@ -190,7 +190,7 @@ class Coop(CoopFunctionsMixin):
|
|
190
190
|
server_version_str=server_edsl_version,
|
191
191
|
):
|
192
192
|
print(
|
193
|
-
"Please upgrade your EDSL version to access our latest features.
|
193
|
+
"Please upgrade your EDSL version to access our latest features. Open your terminal and run `pip install --upgrade edsl`"
|
194
194
|
)
|
195
195
|
|
196
196
|
if response.status_code >= 400:
|
@@ -212,7 +212,7 @@ class Coop(CoopFunctionsMixin):
|
|
212
212
|
print("Your Expected Parrot API key is invalid.")
|
213
213
|
self._display_login_url(
|
214
214
|
edsl_auth_token=edsl_auth_token,
|
215
|
-
link_description="\n🔗 Use the link below to log in to
|
215
|
+
link_description="\n🔗 Use the link below to log in to your account and automatically update your API key.",
|
216
216
|
)
|
217
217
|
api_key = self._poll_for_api_key(edsl_auth_token)
|
218
218
|
|
@@ -870,7 +870,7 @@ class Coop(CoopFunctionsMixin):
|
|
870
870
|
def create_project(
|
871
871
|
self,
|
872
872
|
survey: Survey,
|
873
|
-
project_name: str,
|
873
|
+
project_name: str = "Project",
|
874
874
|
survey_description: Optional[str] = None,
|
875
875
|
survey_alias: Optional[str] = None,
|
876
876
|
survey_visibility: Optional[VisibilityType] = "unlisted",
|
@@ -895,7 +895,8 @@ class Coop(CoopFunctionsMixin):
|
|
895
895
|
return {
|
896
896
|
"name": response_json.get("project_name"),
|
897
897
|
"uuid": response_json.get("uuid"),
|
898
|
-
"
|
898
|
+
"admin_url": f"{self.url}/home/projects/{response_json.get('uuid')}",
|
899
|
+
"respondent_url": f"{self.url}/respond/{response_json.get('uuid')}",
|
899
900
|
}
|
900
901
|
|
901
902
|
################
|
@@ -1027,15 +1028,28 @@ class Coop(CoopFunctionsMixin):
|
|
1027
1028
|
- We need this function because URL detection with print() does not work alongside animations in VSCode.
|
1028
1029
|
"""
|
1029
1030
|
from rich import print as rich_print
|
1031
|
+
from rich.console import Console
|
1032
|
+
|
1033
|
+
console = Console()
|
1030
1034
|
|
1031
1035
|
url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
|
1032
1036
|
|
1033
|
-
if
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
+
if console.is_terminal:
|
1038
|
+
# Running in a standard terminal, show the full URL
|
1039
|
+
if link_description:
|
1040
|
+
rich_print("{link_description}\n[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
|
1041
|
+
else:
|
1042
|
+
rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
|
1037
1043
|
else:
|
1038
|
-
|
1044
|
+
# Running in an interactive environment (e.g., Jupyter Notebook), hide the URL
|
1045
|
+
if link_description:
|
1046
|
+
rich_print(f"{link_description}\n[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]")
|
1047
|
+
else:
|
1048
|
+
rich_print(f"[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]")
|
1049
|
+
|
1050
|
+
|
1051
|
+
|
1052
|
+
|
1039
1053
|
|
1040
1054
|
def _get_api_key(self, edsl_auth_token: str):
|
1041
1055
|
"""
|
edsl/enums.py
CHANGED
@@ -67,7 +67,7 @@ class InferenceServiceType(EnumWithChecks):
|
|
67
67
|
TOGETHER = "together"
|
68
68
|
PERPLEXITY = "perplexity"
|
69
69
|
DEEPSEEK = "deepseek"
|
70
|
-
|
70
|
+
XAI = "xai"
|
71
71
|
|
72
72
|
|
73
73
|
# unavoidable violation of the DRY principle but it is necessary
|
@@ -87,7 +87,7 @@ InferenceServiceLiteral = Literal[
|
|
87
87
|
"together",
|
88
88
|
"perplexity",
|
89
89
|
"deepseek",
|
90
|
-
"
|
90
|
+
"xai",
|
91
91
|
]
|
92
92
|
|
93
93
|
available_models_urls = {
|
@@ -111,7 +111,7 @@ service_to_api_keyname = {
|
|
111
111
|
InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
|
112
112
|
InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
|
113
113
|
InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY",
|
114
|
-
InferenceServiceType.
|
114
|
+
InferenceServiceType.XAI.value: "XAI_API_KEY",
|
115
115
|
}
|
116
116
|
|
117
117
|
|
@@ -17,11 +17,10 @@ class AnthropicService(InferenceServiceABC):
|
|
17
17
|
output_token_name = "output_tokens"
|
18
18
|
model_exclude_list = []
|
19
19
|
|
20
|
-
available_models_url =
|
20
|
+
available_models_url = "https://docs.anthropic.com/en/docs/about-claude/models"
|
21
21
|
|
22
22
|
@classmethod
|
23
23
|
def get_model_list(cls, api_key: str = None):
|
24
|
-
|
25
24
|
import requests
|
26
25
|
|
27
26
|
if api_key is None:
|
@@ -94,13 +93,16 @@ class AnthropicService(InferenceServiceABC):
|
|
94
93
|
# breakpoint()
|
95
94
|
client = AsyncAnthropic(api_key=self.api_token)
|
96
95
|
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
96
|
+
try:
|
97
|
+
response = await client.messages.create(
|
98
|
+
model=model_name,
|
99
|
+
max_tokens=self.max_tokens,
|
100
|
+
temperature=self.temperature,
|
101
|
+
system=system_prompt, # note that the Anthropic API uses "system" parameter rather than put it in the message
|
102
|
+
messages=messages,
|
103
|
+
)
|
104
|
+
except Exception as e:
|
105
|
+
return {"message": str(e)}
|
104
106
|
return response.model_dump()
|
105
107
|
|
106
108
|
LLM.__name__ = model_class_name
|
@@ -69,6 +69,8 @@ class AvailableModelFetcher:
|
|
69
69
|
|
70
70
|
Returns a list of [model, service_name, index] entries.
|
71
71
|
"""
|
72
|
+
if service == "azure":
|
73
|
+
force_refresh = True # Azure models are listed inside the .env AZURE_ENDPOINT_URL_AND_KEY variable
|
72
74
|
|
73
75
|
if service: # they passed a specific service
|
74
76
|
matching_models, _ = self.get_available_models_by_service(
|
@@ -179,15 +179,18 @@ class AzureAIService(InferenceServiceABC):
|
|
179
179
|
api_version=api_version,
|
180
180
|
api_key=api_key,
|
181
181
|
)
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
182
|
+
try:
|
183
|
+
response = await client.chat.completions.create(
|
184
|
+
model=model_name,
|
185
|
+
messages=[
|
186
|
+
{
|
187
|
+
"role": "user",
|
188
|
+
"content": user_prompt, # Your question can go here
|
189
|
+
},
|
190
|
+
],
|
191
|
+
)
|
192
|
+
except Exception as e:
|
193
|
+
return {"message": str(e)}
|
191
194
|
return response.model_dump()
|
192
195
|
|
193
196
|
# @staticmethod
|
@@ -39,7 +39,9 @@ class GoogleService(InferenceServiceABC):
|
|
39
39
|
|
40
40
|
model_exclude_list = []
|
41
41
|
|
42
|
-
available_models_url =
|
42
|
+
available_models_url = (
|
43
|
+
"https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
|
44
|
+
)
|
43
45
|
|
44
46
|
@classmethod
|
45
47
|
def get_model_list(cls):
|
@@ -132,9 +134,12 @@ class GoogleService(InferenceServiceABC):
|
|
132
134
|
)
|
133
135
|
combined_prompt.append(gen_ai_file)
|
134
136
|
|
135
|
-
|
136
|
-
|
137
|
-
|
137
|
+
try:
|
138
|
+
response = await self.generative_model.generate_content_async(
|
139
|
+
combined_prompt, generation_config=generation_config
|
140
|
+
)
|
141
|
+
except Exception as e:
|
142
|
+
return {"message": str(e)}
|
138
143
|
return response.to_dict()
|
139
144
|
|
140
145
|
LLM.__name__ = model_name
|
@@ -104,8 +104,9 @@ class InferenceServicesCollection:
|
|
104
104
|
def available(
|
105
105
|
self,
|
106
106
|
service: Optional[str] = None,
|
107
|
+
force_refresh: bool = False,
|
107
108
|
) -> List[Tuple[str, str, int]]:
|
108
|
-
return self.availability_fetcher.available(service)
|
109
|
+
return self.availability_fetcher.available(service, force_refresh=force_refresh)
|
109
110
|
|
110
111
|
def reset_cache(self) -> None:
|
111
112
|
self.availability_fetcher.reset_cache()
|
@@ -120,7 +121,6 @@ class InferenceServicesCollection:
|
|
120
121
|
def create_model_factory(
|
121
122
|
self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
|
122
123
|
) -> "LanguageModel":
|
123
|
-
|
124
124
|
if service_name is None: # we try to find the right service
|
125
125
|
service = self.resolver.resolve_model(model_name, service_name)
|
126
126
|
else: # if they passed a service, we'll use that
|
@@ -111,8 +111,7 @@ class MistralAIService(InferenceServiceABC):
|
|
111
111
|
],
|
112
112
|
)
|
113
113
|
except Exception as e:
|
114
|
-
|
115
|
-
|
114
|
+
return {"message": str(e)}
|
116
115
|
return res.model_dump()
|
117
116
|
|
118
117
|
LLM.__name__ = model_class_name
|
@@ -207,8 +207,10 @@ class OpenAIService(InferenceServiceABC):
|
|
207
207
|
{"role": "user", "content": content},
|
208
208
|
]
|
209
209
|
if (
|
210
|
-
system_prompt == "" and self.omit_system_prompt_if_empty
|
211
|
-
|
210
|
+
(system_prompt == "" and self.omit_system_prompt_if_empty)
|
211
|
+
or "o1" in self.model
|
212
|
+
or "o3" in self.model
|
213
|
+
):
|
212
214
|
messages = messages[1:]
|
213
215
|
|
214
216
|
params = {
|
@@ -222,14 +224,17 @@ class OpenAIService(InferenceServiceABC):
|
|
222
224
|
"logprobs": self.logprobs,
|
223
225
|
"top_logprobs": self.top_logprobs if self.logprobs else None,
|
224
226
|
}
|
225
|
-
if "o1" in self.model:
|
227
|
+
if "o1" in self.model or "o3" in self.model:
|
226
228
|
params.pop("max_tokens")
|
227
229
|
params["max_completion_tokens"] = self.max_tokens
|
228
230
|
params["temperature"] = 1
|
229
231
|
try:
|
230
232
|
response = await client.chat.completions.create(**params)
|
231
233
|
except Exception as e:
|
232
|
-
|
234
|
+
#breakpoint()
|
235
|
+
#print(e)
|
236
|
+
#raise e
|
237
|
+
return {'message': str(e)}
|
233
238
|
return response.model_dump()
|
234
239
|
|
235
240
|
LLM.__name__ = "LanguageModel"
|
@@ -152,7 +152,8 @@ class PerplexityService(OpenAIService):
|
|
152
152
|
try:
|
153
153
|
response = await client.chat.completions.create(**params)
|
154
154
|
except Exception as e:
|
155
|
-
|
155
|
+
return {"message": str(e)}
|
156
|
+
|
156
157
|
return response.model_dump()
|
157
158
|
|
158
159
|
LLM.__name__ = "LanguageModel"
|
@@ -2,10 +2,10 @@ from typing import Any, List
|
|
2
2
|
from edsl.inference_services.OpenAIService import OpenAIService
|
3
3
|
|
4
4
|
|
5
|
-
class
|
5
|
+
class XAIService(OpenAIService):
|
6
6
|
"""Openai service class."""
|
7
7
|
|
8
|
-
_inference_service_ = "
|
8
|
+
_inference_service_ = "xai"
|
9
9
|
_env_key_name_ = "XAI_API_KEY"
|
10
10
|
_base_url_ = "https://api.x.ai/v1"
|
11
11
|
_models_list_cache: List[str] = []
|
@@ -14,7 +14,7 @@ from edsl.inference_services.TestService import TestService
|
|
14
14
|
from edsl.inference_services.TogetherAIService import TogetherAIService
|
15
15
|
from edsl.inference_services.PerplexityService import PerplexityService
|
16
16
|
from edsl.inference_services.DeepSeekService import DeepSeekService
|
17
|
-
from edsl.inference_services.
|
17
|
+
from edsl.inference_services.XAIService import XAIService
|
18
18
|
|
19
19
|
try:
|
20
20
|
from edsl.inference_services.MistralAIService import MistralAIService
|
@@ -36,7 +36,7 @@ services = [
|
|
36
36
|
TogetherAIService,
|
37
37
|
PerplexityService,
|
38
38
|
DeepSeekService,
|
39
|
-
|
39
|
+
XAIService,
|
40
40
|
]
|
41
41
|
|
42
42
|
if mistral_available:
|
edsl/jobs/Jobs.py
CHANGED
@@ -364,6 +364,15 @@ class Jobs(Base):
|
|
364
364
|
self, cache=self.run_config.environment.cache
|
365
365
|
).create_interviews()
|
366
366
|
|
367
|
+
def show_flow(self, filename: Optional[str] = None) -> None:
|
368
|
+
"""Show the flow of the survey."""
|
369
|
+
from edsl.surveys.SurveyFlowVisualization import SurveyFlowVisualization
|
370
|
+
if self.scenarios:
|
371
|
+
scenario = self.scenarios[0]
|
372
|
+
else:
|
373
|
+
scenario = None
|
374
|
+
SurveyFlowVisualization(self.survey, scenario=scenario, agent=None).show_flow(filename=filename)
|
375
|
+
|
367
376
|
def interviews(self) -> list[Interview]:
|
368
377
|
"""
|
369
378
|
Return a list of :class:`edsl.jobs.interviews.Interview` objects.
|
edsl/jobs/JobsChecks.py
CHANGED
@@ -24,7 +24,7 @@ class JobsChecks:
|
|
24
24
|
|
25
25
|
def get_missing_api_keys(self) -> set:
|
26
26
|
"""
|
27
|
-
Returns a list of the
|
27
|
+
Returns a list of the API keys that a user needs to run this job, but does not currently have in their .env file.
|
28
28
|
"""
|
29
29
|
missing_api_keys = set()
|
30
30
|
|
@@ -134,22 +134,20 @@ class JobsChecks:
|
|
134
134
|
|
135
135
|
edsl_auth_token = secrets.token_urlsafe(16)
|
136
136
|
|
137
|
-
print("
|
137
|
+
print("\nThe following keys are needed to run this survey: \n")
|
138
138
|
for api_key in missing_api_keys:
|
139
|
-
print(f"
|
139
|
+
print(f"🔑 {api_key}")
|
140
140
|
print(
|
141
|
-
"
|
141
|
+
"""
|
142
|
+
\nYou can provide your own keys for language models or use an Expected Parrot key to access all available models.
|
143
|
+
\nClick the link below to create an account and run your survey with your Expected Parrot key:
|
144
|
+
"""
|
142
145
|
)
|
143
|
-
|
144
|
-
|
146
|
+
|
145
147
|
coop = Coop()
|
146
148
|
coop._display_login_url(
|
147
149
|
edsl_auth_token=edsl_auth_token,
|
148
|
-
link_description="
|
149
|
-
)
|
150
|
-
|
151
|
-
print(
|
152
|
-
"\nOnce you log in, your key will be stored on your computer and your survey will start running at the Expected Parrot server."
|
150
|
+
# link_description="",
|
153
151
|
)
|
154
152
|
|
155
153
|
api_key = coop._poll_for_api_key(edsl_auth_token)
|
@@ -159,8 +157,7 @@ class JobsChecks:
|
|
159
157
|
return
|
160
158
|
|
161
159
|
path_to_env = write_api_key_to_env(api_key)
|
162
|
-
print("\n✨ Your key has been stored at the following path: ")
|
163
|
-
print(f" {path_to_env}")
|
160
|
+
print(f"\n✨ Your Expected Parrot key has been stored at the following path: {path_to_env}\n")
|
164
161
|
|
165
162
|
# Retrieve API key so we can continue running the job
|
166
163
|
load_dotenv()
|
@@ -7,6 +7,8 @@ from edsl.data_transfer_models import EDSLResultObjectInput
|
|
7
7
|
|
8
8
|
from edsl.results.Result import Result
|
9
9
|
from edsl.jobs.interviews.Interview import Interview
|
10
|
+
from edsl.config import Config
|
11
|
+
config = Config()
|
10
12
|
|
11
13
|
if TYPE_CHECKING:
|
12
14
|
from edsl.jobs.Jobs import Jobs
|
@@ -23,7 +25,7 @@ from edsl.jobs.data_structures import RunConfig
|
|
23
25
|
|
24
26
|
|
25
27
|
class AsyncInterviewRunner:
|
26
|
-
MAX_CONCURRENT =
|
28
|
+
MAX_CONCURRENT = int(config.EDSL_MAX_CONCURRENT_TASKS)
|
27
29
|
|
28
30
|
def __init__(self, jobs: "Jobs", run_config: RunConfig):
|
29
31
|
self.jobs = jobs
|
@@ -72,11 +72,11 @@ class CheckSurveyScenarioCompatibility:
|
|
72
72
|
if warn:
|
73
73
|
warnings.warn(message)
|
74
74
|
|
75
|
-
if self.scenarios.has_jinja_braces:
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
75
|
+
# if self.scenarios.has_jinja_braces:
|
76
|
+
# warnings.warn(
|
77
|
+
# "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
|
78
|
+
# )
|
79
|
+
# self.scenarios = self.scenarios._convert_jinja_braces()
|
80
80
|
|
81
81
|
|
82
82
|
if __name__ == "__main__":
|
@@ -166,6 +166,9 @@ class InterviewExceptionEntry:
|
|
166
166
|
>>> entry = InterviewExceptionEntry.example()
|
167
167
|
>>> _ = entry.to_dict()
|
168
168
|
"""
|
169
|
+
import json
|
170
|
+
from edsl.exceptions.questions import QuestionAnswerValidationError
|
171
|
+
|
169
172
|
invigilator = (
|
170
173
|
self.invigilator.to_dict() if self.invigilator is not None else None
|
171
174
|
)
|
@@ -174,7 +177,16 @@ class InterviewExceptionEntry:
|
|
174
177
|
"time": self.time,
|
175
178
|
"traceback": self.traceback,
|
176
179
|
"invigilator": invigilator,
|
180
|
+
"additional_data": {},
|
177
181
|
}
|
182
|
+
|
183
|
+
if isinstance(self.exception, QuestionAnswerValidationError):
|
184
|
+
d["additional_data"]["edsl_response"] = json.dumps(self.exception.data)
|
185
|
+
d["additional_data"]["validating_model"] = json.dumps(
|
186
|
+
self.exception.model.model_json_schema()
|
187
|
+
)
|
188
|
+
d["additional_data"]["error_message"] = str(self.exception.message)
|
189
|
+
|
178
190
|
return d
|
179
191
|
|
180
192
|
@classmethod
|
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -419,7 +419,7 @@ class TaskHistory(RepresentationMixin):
|
|
419
419
|
filename: Optional[str] = None,
|
420
420
|
return_link=False,
|
421
421
|
css=None,
|
422
|
-
cta="
|
422
|
+
cta="<br><span style='font-size: 18px; font-weight: medium-bold; text-decoration: underline;'>Click to open the report in a new tab</span><br><br>",
|
423
423
|
open_in_browser=False,
|
424
424
|
):
|
425
425
|
"""Return an HTML report."""
|
@@ -394,7 +394,6 @@ class LanguageModel(
|
|
394
394
|
from edsl.config import CONFIG
|
395
395
|
|
396
396
|
TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
|
397
|
-
|
398
397
|
response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
|
399
398
|
new_cache_key = cache.store(
|
400
399
|
**cache_call_params, response=response
|
@@ -518,8 +517,6 @@ class LanguageModel(
|
|
518
517
|
"""
|
519
518
|
from edsl.language_models.model import get_model_class
|
520
519
|
|
521
|
-
# breakpoint()
|
522
|
-
|
523
520
|
model_class = get_model_class(
|
524
521
|
data["model"], service_name=data.get("inference_service", None)
|
525
522
|
)
|
@@ -30,19 +30,22 @@ class PriceManager:
|
|
30
30
|
except Exception as e:
|
31
31
|
print(f"Error fetching prices: {str(e)}")
|
32
32
|
|
33
|
-
def get_price(self, inference_service: str, model: str) ->
|
33
|
+
def get_price(self, inference_service: str, model: str) -> Dict:
|
34
34
|
"""
|
35
35
|
Get the price information for a specific service and model combination.
|
36
|
+
If no specific price is found, returns a fallback price.
|
36
37
|
|
37
38
|
Args:
|
38
39
|
inference_service (str): The name of the inference service
|
39
40
|
model (str): The model identifier
|
40
41
|
|
41
42
|
Returns:
|
42
|
-
|
43
|
+
Dict: Price information (either actual or fallback prices)
|
43
44
|
"""
|
44
45
|
key = (inference_service, model)
|
45
|
-
return self._price_lookup.get(key)
|
46
|
+
return self._price_lookup.get(key) or self._get_fallback_price(
|
47
|
+
inference_service
|
48
|
+
)
|
46
49
|
|
47
50
|
def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
|
48
51
|
"""
|
@@ -53,6 +56,45 @@ class PriceManager:
|
|
53
56
|
"""
|
54
57
|
return self._price_lookup.copy()
|
55
58
|
|
59
|
+
def _get_fallback_price(self, inference_service: str) -> Dict:
|
60
|
+
"""
|
61
|
+
Get fallback prices for a service.
|
62
|
+
- First fallback: The highest input and output prices for that service from the price lookup.
|
63
|
+
- Second fallback: $1.00 per million tokens (for both input and output).
|
64
|
+
|
65
|
+
Args:
|
66
|
+
inference_service (str): The inference service name
|
67
|
+
|
68
|
+
Returns:
|
69
|
+
Dict: Price information
|
70
|
+
"""
|
71
|
+
service_prices = [
|
72
|
+
prices
|
73
|
+
for (service, _), prices in self._price_lookup.items()
|
74
|
+
if service == inference_service
|
75
|
+
]
|
76
|
+
|
77
|
+
input_tokens_per_usd = [
|
78
|
+
float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
|
79
|
+
]
|
80
|
+
if input_tokens_per_usd:
|
81
|
+
min_input_tokens = min(input_tokens_per_usd)
|
82
|
+
else:
|
83
|
+
min_input_tokens = 1_000_000
|
84
|
+
|
85
|
+
output_tokens_per_usd = [
|
86
|
+
float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
|
87
|
+
]
|
88
|
+
if output_tokens_per_usd:
|
89
|
+
min_output_tokens = min(output_tokens_per_usd)
|
90
|
+
else:
|
91
|
+
min_output_tokens = 1_000_000
|
92
|
+
|
93
|
+
return {
|
94
|
+
"input": {"one_usd_buys": min_input_tokens},
|
95
|
+
"output": {"one_usd_buys": min_output_tokens},
|
96
|
+
}
|
97
|
+
|
56
98
|
def calculate_cost(
|
57
99
|
self,
|
58
100
|
inference_service: str,
|
@@ -75,8 +117,6 @@ class PriceManager:
|
|
75
117
|
Union[float, str]: Total cost if calculation successful, error message string if not
|
76
118
|
"""
|
77
119
|
relevant_prices = self.get_price(inference_service, model)
|
78
|
-
if relevant_prices is None:
|
79
|
-
return f"Could not find price for model {model} in the price lookup."
|
80
120
|
|
81
121
|
# Extract token counts
|
82
122
|
try:
|