edsl 0.1.44__py3-none-any.whl → 0.1.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/InvigilatorBase.py +3 -1
  3. edsl/agents/PromptConstructor.py +62 -34
  4. edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
  5. edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
  6. edsl/agents/question_option_processor.py +15 -6
  7. edsl/coop/CoopFunctionsMixin.py +3 -4
  8. edsl/coop/coop.py +23 -9
  9. edsl/enums.py +3 -3
  10. edsl/inference_services/AnthropicService.py +11 -9
  11. edsl/inference_services/AvailableModelFetcher.py +2 -0
  12. edsl/inference_services/AwsBedrock.py +1 -2
  13. edsl/inference_services/AzureAI.py +12 -9
  14. edsl/inference_services/GoogleService.py +9 -4
  15. edsl/inference_services/InferenceServicesCollection.py +2 -2
  16. edsl/inference_services/MistralAIService.py +1 -2
  17. edsl/inference_services/OpenAIService.py +9 -4
  18. edsl/inference_services/PerplexityService.py +2 -1
  19. edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
  20. edsl/inference_services/registry.py +2 -2
  21. edsl/jobs/Jobs.py +9 -0
  22. edsl/jobs/JobsChecks.py +10 -13
  23. edsl/jobs/async_interview_runner.py +3 -1
  24. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  25. edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
  26. edsl/jobs/tasks/TaskHistory.py +1 -1
  27. edsl/language_models/LanguageModel.py +0 -3
  28. edsl/language_models/PriceManager.py +45 -5
  29. edsl/language_models/model.py +47 -26
  30. edsl/questions/QuestionBase.py +21 -0
  31. edsl/questions/QuestionBasePromptsMixin.py +103 -0
  32. edsl/questions/QuestionFreeText.py +22 -5
  33. edsl/questions/descriptors.py +4 -0
  34. edsl/questions/question_base_gen_mixin.py +94 -29
  35. edsl/results/Dataset.py +65 -0
  36. edsl/results/DatasetExportMixin.py +299 -32
  37. edsl/results/Result.py +27 -0
  38. edsl/results/Results.py +22 -2
  39. edsl/results/ResultsGGMixin.py +7 -3
  40. edsl/scenarios/DocumentChunker.py +2 -0
  41. edsl/scenarios/FileStore.py +10 -0
  42. edsl/scenarios/PdfExtractor.py +21 -1
  43. edsl/scenarios/Scenario.py +25 -9
  44. edsl/scenarios/ScenarioList.py +73 -3
  45. edsl/scenarios/handlers/__init__.py +1 -0
  46. edsl/scenarios/handlers/docx.py +5 -1
  47. edsl/scenarios/handlers/jpeg.py +39 -0
  48. edsl/surveys/Survey.py +5 -4
  49. edsl/surveys/SurveyFlowVisualization.py +91 -43
  50. edsl/templates/error_reporting/exceptions_table.html +7 -8
  51. edsl/templates/error_reporting/interview_details.html +1 -1
  52. edsl/templates/error_reporting/interviews.html +0 -1
  53. edsl/templates/error_reporting/overview.html +2 -7
  54. edsl/templates/error_reporting/performance_plot.html +1 -1
  55. edsl/templates/error_reporting/report.css +1 -1
  56. edsl/utilities/PrettyList.py +14 -0
  57. edsl-0.1.45.dist-info/METADATA +246 -0
  58. {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/RECORD +60 -59
  59. edsl-0.1.44.dist-info/METADATA +0 -110
  60. {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
  61. {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
edsl/coop/coop.py CHANGED
@@ -190,7 +190,7 @@ class Coop(CoopFunctionsMixin):
              server_version_str=server_edsl_version,
          ):
              print(
-                 "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip install --upgrade edsl`"
+                 "Please upgrade your EDSL version to access our latest features. Open your terminal and run `pip install --upgrade edsl`"
              )

          if response.status_code >= 400:
@@ -212,7 +212,7 @@ class Coop(CoopFunctionsMixin):
              print("Your Expected Parrot API key is invalid.")
              self._display_login_url(
                  edsl_auth_token=edsl_auth_token,
-                 link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
+                 link_description="\n🔗 Use the link below to log in to your account and automatically update your API key.",
              )
              api_key = self._poll_for_api_key(edsl_auth_token)

@@ -870,7 +870,7 @@ class Coop(CoopFunctionsMixin):
      def create_project(
          self,
          survey: Survey,
-         project_name: str,
+         project_name: str = "Project",
          survey_description: Optional[str] = None,
          survey_alias: Optional[str] = None,
          survey_visibility: Optional[VisibilityType] = "unlisted",
@@ -895,7 +895,8 @@ class Coop(CoopFunctionsMixin):
          return {
              "name": response_json.get("project_name"),
              "uuid": response_json.get("uuid"),
-             "url": f"{self.url}/home/projects/{response_json.get('uuid')}",
+             "admin_url": f"{self.url}/home/projects/{response_json.get('uuid')}",
+             "respondent_url": f"{self.url}/respond/{response_json.get('uuid')}",
          }

      ################
@@ -1027,15 +1028,28 @@
          - We need this function because URL detection with print() does not work alongside animations in VSCode.
          """
          from rich import print as rich_print
+         from rich.console import Console
+
+         console = Console()

          url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"

-         if link_description:
-             rich_print(
-                 f"{link_description}\n [#38bdf8][link={url}]{url}[/link][/#38bdf8]"
-             )
+         if console.is_terminal:
+             # Running in a standard terminal, show the full URL
+             if link_description:
+                 rich_print(f"{link_description}\n[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
+             else:
+                 rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
          else:
-             rich_print(f" [#38bdf8][link={url}]{url}[/link][/#38bdf8]")
+             # Running in an interactive environment (e.g., Jupyter Notebook), hide the URL
+             if link_description:
+                 rich_print(f"{link_description}\n[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]")
+             else:
+                 rich_print(f"[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]")
+
+
+
+

      def _get_api_key(self, edsl_auth_token: str):
          """
edsl/enums.py CHANGED
@@ -67,7 +67,7 @@ class InferenceServiceType(EnumWithChecks):
      TOGETHER = "together"
      PERPLEXITY = "perplexity"
      DEEPSEEK = "deepseek"
-     GROK = "grok"
+     XAI = "xai"


  # unavoidable violation of the DRY principle but it is necessary
@@ -87,7 +87,7 @@ InferenceServiceLiteral = Literal[
      "together",
      "perplexity",
      "deepseek",
-     "grok",
+     "xai",
  ]

  available_models_urls = {
@@ -111,7 +111,7 @@ service_to_api_keyname = {
      InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
      InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
      InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY",
-     InferenceServiceType.GROK.value: "XAI_API_KEY",
+     InferenceServiceType.XAI.value: "XAI_API_KEY",
  }

edsl/inference_services/AnthropicService.py CHANGED
@@ -17,11 +17,10 @@ class AnthropicService(InferenceServiceABC):
      output_token_name = "output_tokens"
      model_exclude_list = []

-     available_models_url = 'https://docs.anthropic.com/en/docs/about-claude/models'
+     available_models_url = "https://docs.anthropic.com/en/docs/about-claude/models"

      @classmethod
      def get_model_list(cls, api_key: str = None):
-
          import requests

          if api_key is None:
@@ -94,13 +93,16 @@ class AnthropicService(InferenceServiceABC):
          # breakpoint()
          client = AsyncAnthropic(api_key=self.api_token)

-         response = await client.messages.create(
-             model=model_name,
-             max_tokens=self.max_tokens,
-             temperature=self.temperature,
-             system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
-             messages=messages,
-         )
+         try:
+             response = await client.messages.create(
+                 model=model_name,
+                 max_tokens=self.max_tokens,
+                 temperature=self.temperature,
+                 system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
+                 messages=messages,
+             )
+         except Exception as e:
+             return {"message": str(e)}
          return response.model_dump()

  LLM.__name__ = model_class_name
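The same pattern — catch the provider exception and return {"message": str(e)} as the raw response — recurs in the AWS Bedrock, Azure, Google, Mistral, OpenAI, and Perplexity diffs below, replacing the old mix of print(e), re-raising, and {"error": ...} payloads. A minimal standalone sketch of the convention (hypothetical helper, not part of edsl):

    async def call_provider(client, **params):
        # Provider errors become data rather than uncaught exceptions,
        # so downstream response handling sees a uniform dict shape.
        try:
            response = await client.messages.create(**params)
        except Exception as e:
            return {"message": str(e)}
        return response.model_dump()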
edsl/inference_services/AvailableModelFetcher.py CHANGED
@@ -69,6 +69,8 @@ class AvailableModelFetcher:

          Returns a list of [model, service_name, index] entries.
          """
+         if service == "azure":
+             force_refresh = True  # Azure models are listed inside the .env AZURE_ENDPOINT_URL_AND_KEY variable

          if service:  # they passed a specific service
              matching_models, _ = self.get_available_models_by_service(
edsl/inference_services/AwsBedrock.py CHANGED
@@ -110,8 +110,7 @@ class AwsBedrockService(InferenceServiceABC):
              )
              return response
          except (ClientError, Exception) as e:
-             print(e)
-             return {"error": str(e)}
+             return {"message": str(e)}

  LLM.__name__ = model_class_name

edsl/inference_services/AzureAI.py CHANGED
@@ -179,15 +179,18 @@ class AzureAIService(InferenceServiceABC):
              api_version=api_version,
              api_key=api_key,
          )
-         response = await client.chat.completions.create(
-             model=model_name,
-             messages=[
-                 {
-                     "role": "user",
-                     "content": user_prompt,  # Your question can go here
-                 },
-             ],
-         )
+         try:
+             response = await client.chat.completions.create(
+                 model=model_name,
+                 messages=[
+                     {
+                         "role": "user",
+                         "content": user_prompt,  # Your question can go here
+                     },
+                 ],
+             )
+         except Exception as e:
+             return {"message": str(e)}
          return response.model_dump()

      # @staticmethod
edsl/inference_services/GoogleService.py CHANGED
@@ -39,7 +39,9 @@ class GoogleService(InferenceServiceABC):

      model_exclude_list = []

-     available_models_url = 'https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models'
+     available_models_url = (
+         "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
+     )

      @classmethod
      def get_model_list(cls):
@@ -132,9 +134,12 @@ class GoogleService(InferenceServiceABC):
              )
              combined_prompt.append(gen_ai_file)

-         response = await self.generative_model.generate_content_async(
-             combined_prompt, generation_config=generation_config
-         )
+         try:
+             response = await self.generative_model.generate_content_async(
+                 combined_prompt, generation_config=generation_config
+             )
+         except Exception as e:
+             return {"message": str(e)}
          return response.to_dict()

  LLM.__name__ = model_name
edsl/inference_services/InferenceServicesCollection.py CHANGED
@@ -104,8 +104,9 @@ class InferenceServicesCollection:
      def available(
          self,
          service: Optional[str] = None,
+         force_refresh: bool = False,
      ) -> List[Tuple[str, str, int]]:
-         return self.availability_fetcher.available(service)
+         return self.availability_fetcher.available(service, force_refresh=force_refresh)

      def reset_cache(self) -> None:
          self.availability_fetcher.reset_cache()
@@ -120,7 +121,6 @@ class InferenceServicesCollection:
      def create_model_factory(
          self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
      ) -> "LanguageModel":
-
          if service_name is None:  # we try to find the right service
              service = self.resolver.resolve_model(model_name, service_name)
          else:  # if they passed a service, we'll use that
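Together with the AvailableModelFetcher change above, force_refresh now threads through the collection to the fetcher, so callers can bypass the cached model list (and Azure always does). A sketch of a cache-bypassing call, assuming the registry module still exposes the collection as default, as in prior releases:

    from edsl.inference_services.registry import default

    models = default.available(service="openai", force_refresh=True)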
edsl/inference_services/MistralAIService.py CHANGED
@@ -111,8 +111,7 @@ class MistralAIService(InferenceServiceABC):
                  ],
              )
          except Exception as e:
-             raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")
-
+             return {"message": str(e)}
          return res.model_dump()

  LLM.__name__ = model_class_name
edsl/inference_services/OpenAIService.py CHANGED
@@ -207,8 +207,10 @@ class OpenAIService(InferenceServiceABC):
              {"role": "user", "content": content},
          ]
          if (
-             system_prompt == "" and self.omit_system_prompt_if_empty
-         ) or "o1" in self.model:
+             (system_prompt == "" and self.omit_system_prompt_if_empty)
+             or "o1" in self.model
+             or "o3" in self.model
+         ):
              messages = messages[1:]

          params = {
@@ -222,14 +224,17 @@ class OpenAIService(InferenceServiceABC):
              "logprobs": self.logprobs,
              "top_logprobs": self.top_logprobs if self.logprobs else None,
          }
-         if "o1" in self.model:
+         if "o1" in self.model or "o3" in self.model:
              params.pop("max_tokens")
              params["max_completion_tokens"] = self.max_tokens
              params["temperature"] = 1
          try:
              response = await client.chat.completions.create(**params)
          except Exception as e:
-             print(e)
+             #breakpoint()
+             #print(e)
+             #raise e
+             return {'message': str(e)}
          return response.model_dump()

  LLM.__name__ = "LanguageModel"
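A condensed sketch of what the two new branches do to a request for an o1/o3 reasoning model (values illustrative):

    model = "o3-mini"
    messages = [
        {"role": "system", "content": ""},
        {"role": "user", "content": "Hello"},
    ]
    params = {"model": model, "max_tokens": 512, "temperature": 0.5}

    if "o1" in model or "o3" in model:
        messages = messages[1:]  # these models take no system message here
        params["max_completion_tokens"] = params.pop("max_tokens")  # renamed parameter
        params["temperature"] = 1  # only the default temperature is accepted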
edsl/inference_services/PerplexityService.py CHANGED
@@ -152,7 +152,8 @@ class PerplexityService(OpenAIService):
          try:
              response = await client.chat.completions.create(**params)
          except Exception as e:
-             print(e, flush=True)
+             return {"message": str(e)}
+
          return response.model_dump()

  LLM.__name__ = "LanguageModel"
edsl/inference_services/{GrokService.py → XAIService.py} CHANGED
@@ -2,10 +2,10 @@ from typing import Any, List
  from edsl.inference_services.OpenAIService import OpenAIService


- class GrokService(OpenAIService):
+ class XAIService(OpenAIService):
      """Openai service class."""

-     _inference_service_ = "grok"
+     _inference_service_ = "xai"
      _env_key_name_ = "XAI_API_KEY"
      _base_url_ = "https://api.x.ai/v1"
      _models_list_cache: List[str] = []
edsl/inference_services/registry.py CHANGED
@@ -14,7 +14,7 @@ from edsl.inference_services.TestService import TestService
  from edsl.inference_services.TogetherAIService import TogetherAIService
  from edsl.inference_services.PerplexityService import PerplexityService
  from edsl.inference_services.DeepSeekService import DeepSeekService
- from edsl.inference_services.GrokService import GrokService
+ from edsl.inference_services.XAIService import XAIService

  try:
      from edsl.inference_services.MistralAIService import MistralAIService
@@ -36,7 +36,7 @@ services = [
      TogetherAIService,
      PerplexityService,
      DeepSeekService,
-     GrokService,
+     XAIService,
  ]

  if mistral_available:
edsl/jobs/Jobs.py CHANGED
@@ -364,6 +364,15 @@ class Jobs(Base):
              self, cache=self.run_config.environment.cache
          ).create_interviews()

+     def show_flow(self, filename: Optional[str] = None) -> None:
+         """Show the flow of the survey."""
+         from edsl.surveys.SurveyFlowVisualization import SurveyFlowVisualization
+         if self.scenarios:
+             scenario = self.scenarios[0]
+         else:
+             scenario = None
+         SurveyFlowVisualization(self.survey, scenario=scenario, agent=None).show_flow(filename=filename)
+
      def interviews(self) -> list[Interview]:
          """
          Return a list of :class:`edsl.jobs.interviews.Interview` objects.
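A sketch of the new convenience method (question and scenario contents illustrative; by() chaining is assumed to produce the Jobs object, as elsewhere in edsl):

    from edsl import QuestionFreeText, ScenarioList

    q = QuestionFreeText(
        question_name="opinion",
        question_text="What do you think of {{ topic }}?",
    )
    job = q.by(ScenarioList.from_list("topic", ["pizza"]))
    job.show_flow()  # renders the survey flow using the first scenario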
edsl/jobs/JobsChecks.py CHANGED
@@ -24,7 +24,7 @@ class JobsChecks:

      def get_missing_api_keys(self) -> set:
          """
-         Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
+         Returns a list of the API keys that a user needs to run this job, but does not currently have in their .env file.
          """
          missing_api_keys = set()

@@ -134,22 +134,20 @@ class JobsChecks:

          edsl_auth_token = secrets.token_urlsafe(16)

-         print("API keys are required to run surveys with language models. The following keys are needed to run this survey: ")
+         print("\nThe following keys are needed to run this survey: \n")
          for api_key in missing_api_keys:
-             print(f" 🔑 {api_key}")
+             print(f"🔑 {api_key}")
          print(
-             "\nYou can provide your own keys or use an Expected Parrot key to access all available models."
+             """
+             \nYou can provide your own keys for language models or use an Expected Parrot key to access all available models.
+             \nClick the link below to create an account and run your survey with your Expected Parrot key:
+             """
          )
-         print("Please see the documentation page to learn about options for managing keys: https://docs.expectedparrot.com/en/latest/api_keys.html")
-
+
          coop = Coop()
          coop._display_login_url(
              edsl_auth_token=edsl_auth_token,
-             link_description="\n➡️ Click the link below to create an account and get an Expected Parrot key:\n",
-         )
-
-         print(
-             "\nOnce you log in, your key will be stored on your computer and your survey will start running at the Expected Parrot server."
+             # link_description="",
          )

          api_key = coop._poll_for_api_key(edsl_auth_token)
@@ -159,8 +157,7 @@ class JobsChecks:
              return

          path_to_env = write_api_key_to_env(api_key)
-         print("\n✨ Your key has been stored at the following path: ")
-         print(f" {path_to_env}")
+         print(f"\n✨ Your Expected Parrot key has been stored at the following path: {path_to_env}\n")

          # Retrieve API key so we can continue running the job
          load_dotenv()
edsl/jobs/async_interview_runner.py CHANGED
@@ -7,6 +7,8 @@ from edsl.data_transfer_models import EDSLResultObjectInput

  from edsl.results.Result import Result
  from edsl.jobs.interviews.Interview import Interview
+ from edsl.config import Config
+ config = Config()

  if TYPE_CHECKING:
      from edsl.jobs.Jobs import Jobs
@@ -23,7 +25,7 @@ from edsl.jobs.data_structures import RunConfig


  class AsyncInterviewRunner:
-     MAX_CONCURRENT = 5
+     MAX_CONCURRENT = int(config.EDSL_MAX_CONCURRENT_TASKS)

      def __init__(self, jobs: "Jobs", run_config: RunConfig):
          self.jobs = jobs
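Interview concurrency is now configurable rather than hardcoded at 5. A sketch of overriding it, assuming Config reads EDSL_MAX_CONCURRENT_TASKS from the environment or a .env file:

    import os

    # Must be set before edsl modules are imported, since MAX_CONCURRENT
    # is a class attribute evaluated when async_interview_runner loads.
    os.environ["EDSL_MAX_CONCURRENT_TASKS"] = "10"

    from edsl.jobs.async_interview_runner import AsyncInterviewRunner
    print(AsyncInterviewRunner.MAX_CONCURRENT)  # 10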
edsl/jobs/check_survey_scenario_compatibility.py CHANGED
@@ -72,11 +72,11 @@ class CheckSurveyScenarioCompatibility:
          if warn:
              warnings.warn(message)

-         if self.scenarios.has_jinja_braces:
-             warnings.warn(
-                 "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
-             )
-             self.scenarios = self.scenarios._convert_jinja_braces()
+         # if self.scenarios.has_jinja_braces:
+         #     warnings.warn(
+         #         "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
+         #     )
+         #     self.scenarios = self.scenarios._convert_jinja_braces()


  if __name__ == "__main__":
edsl/jobs/interviews/InterviewExceptionEntry.py CHANGED
@@ -166,6 +166,9 @@ class InterviewExceptionEntry:
          >>> entry = InterviewExceptionEntry.example()
          >>> _ = entry.to_dict()
          """
+         import json
+         from edsl.exceptions.questions import QuestionAnswerValidationError
+
          invigilator = (
              self.invigilator.to_dict() if self.invigilator is not None else None
          )
@@ -174,7 +177,16 @@ class InterviewExceptionEntry:
              "time": self.time,
              "traceback": self.traceback,
              "invigilator": invigilator,
+             "additional_data": {},
          }
+
+         if isinstance(self.exception, QuestionAnswerValidationError):
+             d["additional_data"]["edsl_response"] = json.dumps(self.exception.data)
+             d["additional_data"]["validating_model"] = json.dumps(
+                 self.exception.model.model_json_schema()
+             )
+             d["additional_data"]["error_message"] = str(self.exception.message)
+
          return d

      @classmethod
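For validation failures, the serialized entry now carries machine-readable context alongside the traceback. A sketch of the resulting additional_data shape (field values illustrative):

    {
        "additional_data": {
            "edsl_response": '{"answer": 7}',                # json.dumps(self.exception.data)
            "validating_model": '{"title": "Answer", ...}',  # the pydantic JSON schema, as a string
            "error_message": "answer must be a string",      # str(self.exception.message)
        }
    }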
edsl/jobs/tasks/TaskHistory.py CHANGED
@@ -419,7 +419,7 @@ class TaskHistory(RepresentationMixin):
          filename: Optional[str] = None,
          return_link=False,
          css=None,
-         cta="\nClick to open the report in a new tab\n",
+         cta="<br><span style='font-size: 18px; font-weight: medium-bold; text-decoration: underline;'>Click to open the report in a new tab</span><br><br>",
          open_in_browser=False,
      ):
          """Return an HTML report."""
edsl/language_models/LanguageModel.py CHANGED
@@ -394,7 +394,6 @@ class LanguageModel(
          from edsl.config import CONFIG

          TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
          response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
          new_cache_key = cache.store(
              **cache_call_params, response=response
@@ -518,8 +517,6 @@ class LanguageModel(
          """
          from edsl.language_models.model import get_model_class

-         # breakpoint()
-
          model_class = get_model_class(
              data["model"], service_name=data.get("inference_service", None)
          )
edsl/language_models/PriceManager.py CHANGED
@@ -30,19 +30,22 @@ class PriceManager:
          except Exception as e:
              print(f"Error fetching prices: {str(e)}")

-     def get_price(self, inference_service: str, model: str) -> Optional[Dict]:
+     def get_price(self, inference_service: str, model: str) -> Dict:
          """
          Get the price information for a specific service and model combination.
+         If no specific price is found, returns a fallback price.

          Args:
              inference_service (str): The name of the inference service
              model (str): The model identifier

          Returns:
-             Optional[Dict]: Price information if found, None otherwise
+             Dict: Price information (either actual or fallback prices)
          """
          key = (inference_service, model)
-         return self._price_lookup.get(key)
+         return self._price_lookup.get(key) or self._get_fallback_price(
+             inference_service
+         )

      def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
          """
@@ -53,6 +56,45 @@ class PriceManager:
          """
          return self._price_lookup.copy()

+     def _get_fallback_price(self, inference_service: str) -> Dict:
+         """
+         Get fallback prices for a service.
+         - First fallback: The highest input and output prices for that service from the price lookup.
+         - Second fallback: $1.00 per million tokens (for both input and output).
+
+         Args:
+             inference_service (str): The inference service name
+
+         Returns:
+             Dict: Price information
+         """
+         service_prices = [
+             prices
+             for (service, _), prices in self._price_lookup.items()
+             if service == inference_service
+         ]
+
+         input_tokens_per_usd = [
+             float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
+         ]
+         if input_tokens_per_usd:
+             min_input_tokens = min(input_tokens_per_usd)
+         else:
+             min_input_tokens = 1_000_000
+
+         output_tokens_per_usd = [
+             float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
+         ]
+         if output_tokens_per_usd:
+             min_output_tokens = min(output_tokens_per_usd)
+         else:
+             min_output_tokens = 1_000_000
+
+         return {
+             "input": {"one_usd_buys": min_input_tokens},
+             "output": {"one_usd_buys": min_output_tokens},
+         }
+
      def calculate_cost(
          self,
          inference_service: str,
@@ -75,8 +117,6 @@ class PriceManager:
          Union[float, str]: Total cost if calculation successful, error message string if not
          """
          relevant_prices = self.get_price(inference_service, model)
-         if relevant_prices is None:
-             return f"Could not find price for model {model} in the price lookup."

          # Extract token counts
          try:
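Worked example of the fallback arithmetic: one_usd_buys counts tokens per USD, so the minimum value corresponds to the highest price for the service. A sketch with two hypothetical models:

    price_lookup = {
        ("svc", "cheap-model"): {
            "input": {"one_usd_buys": 4_000_000},
            "output": {"one_usd_buys": 1_000_000},
        },
        ("svc", "pricey-model"): {
            "input": {"one_usd_buys": 400_000},
            "output": {"one_usd_buys": 100_000},
        },
    }

    service_prices = [p for (svc, _), p in price_lookup.items() if svc == "svc"]
    fallback = {
        "input": {"one_usd_buys": min(float(p["input"]["one_usd_buys"]) for p in service_prices)},
        "output": {"one_usd_buys": min(float(p["output"]["one_usd_buys"]) for p in service_prices)},
    }
    # An unknown ("svc", ...) model is now billed at the pricey-model rate,
    # and calculate_cost no longer short-circuits with an error string.
    assert fallback == {"input": {"one_usd_buys": 400000.0}, "output": {"one_usd_buys": 100000.0}}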