edsl 0.1.43__py3-none-any.whl → 0.1.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. edsl/Base.py +15 -6
  2. edsl/__version__.py +1 -1
  3. edsl/agents/InvigilatorBase.py +3 -1
  4. edsl/agents/PromptConstructor.py +62 -34
  5. edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
  6. edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
  7. edsl/agents/question_option_processor.py +15 -6
  8. edsl/coop/CoopFunctionsMixin.py +3 -4
  9. edsl/coop/coop.py +56 -10
  10. edsl/enums.py +4 -1
  11. edsl/inference_services/AnthropicService.py +12 -8
  12. edsl/inference_services/AvailableModelFetcher.py +2 -0
  13. edsl/inference_services/AwsBedrock.py +1 -2
  14. edsl/inference_services/AzureAI.py +12 -9
  15. edsl/inference_services/GoogleService.py +10 -3
  16. edsl/inference_services/InferenceServiceABC.py +1 -0
  17. edsl/inference_services/InferenceServicesCollection.py +2 -2
  18. edsl/inference_services/MistralAIService.py +1 -2
  19. edsl/inference_services/OpenAIService.py +10 -4
  20. edsl/inference_services/PerplexityService.py +2 -1
  21. edsl/inference_services/TestService.py +1 -0
  22. edsl/inference_services/XAIService.py +11 -0
  23. edsl/inference_services/registry.py +2 -0
  24. edsl/jobs/Jobs.py +9 -0
  25. edsl/jobs/JobsChecks.py +11 -14
  26. edsl/jobs/JobsPrompts.py +3 -3
  27. edsl/jobs/async_interview_runner.py +3 -1
  28. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  29. edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
  30. edsl/jobs/tasks/TaskHistory.py +1 -1
  31. edsl/language_models/LanguageModel.py +3 -3
  32. edsl/language_models/PriceManager.py +45 -5
  33. edsl/language_models/model.py +89 -36
  34. edsl/questions/QuestionBase.py +21 -0
  35. edsl/questions/QuestionBasePromptsMixin.py +103 -0
  36. edsl/questions/QuestionFreeText.py +22 -5
  37. edsl/questions/descriptors.py +4 -0
  38. edsl/questions/question_base_gen_mixin.py +94 -29
  39. edsl/results/Dataset.py +65 -0
  40. edsl/results/DatasetExportMixin.py +299 -32
  41. edsl/results/Result.py +27 -0
  42. edsl/results/Results.py +24 -3
  43. edsl/results/ResultsGGMixin.py +7 -3
  44. edsl/scenarios/DocumentChunker.py +2 -0
  45. edsl/scenarios/FileStore.py +29 -8
  46. edsl/scenarios/PdfExtractor.py +21 -1
  47. edsl/scenarios/Scenario.py +25 -9
  48. edsl/scenarios/ScenarioList.py +73 -3
  49. edsl/scenarios/handlers/__init__.py +1 -0
  50. edsl/scenarios/handlers/docx.py +5 -1
  51. edsl/scenarios/handlers/jpeg.py +39 -0
  52. edsl/surveys/Survey.py +28 -6
  53. edsl/surveys/SurveyFlowVisualization.py +91 -43
  54. edsl/templates/error_reporting/exceptions_table.html +7 -8
  55. edsl/templates/error_reporting/interview_details.html +1 -1
  56. edsl/templates/error_reporting/interviews.html +0 -1
  57. edsl/templates/error_reporting/overview.html +2 -7
  58. edsl/templates/error_reporting/performance_plot.html +1 -1
  59. edsl/templates/error_reporting/report.css +1 -1
  60. edsl/utilities/PrettyList.py +14 -0
  61. edsl-0.1.45.dist-info/METADATA +246 -0
  62. {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/RECORD +64 -62
  63. edsl-0.1.43.dist-info/METADATA +0 -110
  64. {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
  65. {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
edsl/coop/coop.py CHANGED
@@ -4,10 +4,8 @@ import requests
 
 from typing import Any, Optional, Union, Literal, TypedDict
 from uuid import UUID
-from collections import UserDict, defaultdict
 
 import edsl
-from pathlib import Path
 
 from edsl.config import CONFIG
 from edsl.data.CacheEntry import CacheEntry
@@ -192,7 +190,7 @@ class Coop(CoopFunctionsMixin):
             server_version_str=server_edsl_version,
         ):
             print(
-                "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip install --upgrade edsl`"
+                "Please upgrade your EDSL version to access our latest features. Open your terminal and run `pip install --upgrade edsl`"
             )
 
         if response.status_code >= 400:
@@ -214,7 +212,7 @@ class Coop(CoopFunctionsMixin):
             print("Your Expected Parrot API key is invalid.")
             self._display_login_url(
                 edsl_auth_token=edsl_auth_token,
-                link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
+                link_description="\n🔗 Use the link below to log in to your account and automatically update your API key.",
             )
             api_key = self._poll_for_api_key(edsl_auth_token)
 
@@ -340,7 +338,7 @@ class Coop(CoopFunctionsMixin):
 
         try:
             response = self._send_server_request(
-                uri="api/v0/edsl-settings", method="GET", timeout=5
+                uri="api/v0/edsl-settings", method="GET", timeout=20
            )
            self._resolve_server_response(response, check_api_key=False)
            return response.json()
@@ -866,6 +864,41 @@ class Coop(CoopFunctionsMixin):
             "usd": response_json.get("cost_in_usd"),
         }
 
+    ################
+    # PROJECTS
+    ################
+    def create_project(
+        self,
+        survey: Survey,
+        project_name: str = "Project",
+        survey_description: Optional[str] = None,
+        survey_alias: Optional[str] = None,
+        survey_visibility: Optional[VisibilityType] = "unlisted",
+    ):
+        """
+        Create a survey object on Coop, then create a project from the survey.
+        """
+        survey_details = self.create(
+            object=survey,
+            description=survey_description,
+            alias=survey_alias,
+            visibility=survey_visibility,
+        )
+        survey_uuid = survey_details.get("uuid")
+        response = self._send_server_request(
+            uri=f"api/v0/projects/create-from-survey",
+            method="POST",
+            payload={"project_name": project_name, "survey_uuid": str(survey_uuid)},
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        return {
+            "name": response_json.get("project_name"),
+            "uuid": response_json.get("uuid"),
+            "admin_url": f"{self.url}/home/projects/{response_json.get('uuid')}",
+            "respondent_url": f"{self.url}/respond/{response_json.get('uuid')}",
+        }
+
     ################
     # DUNDER METHODS
     ################
@@ -995,15 +1028,28 @@ class Coop(CoopFunctionsMixin):
         - We need this function because URL detection with print() does not work alongside animations in VSCode.
         """
         from rich import print as rich_print
+        from rich.console import Console
+
+        console = Console()
 
         url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
 
-        if link_description:
-            rich_print(
-                f"{link_description}\n [#38bdf8][link={url}]{url}[/link][/#38bdf8]"
-            )
+        if console.is_terminal:
+            # Running in a standard terminal, show the full URL
+            if link_description:
+                rich_print("{link_description}\n[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
+            else:
+                rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
         else:
-            rich_print(f" [#38bdf8][link={url}]{url}[/link][/#38bdf8]")
+            # Running in an interactive environment (e.g., Jupyter Notebook), hide the URL
+            if link_description:
+                rich_print(f"{link_description}\n[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]")
+            else:
+                rich_print(f"[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]")
+
+
+
+
 
     def _get_api_key(self, edsl_auth_token: str):
         """
edsl/enums.py CHANGED
@@ -67,6 +67,7 @@ class InferenceServiceType(EnumWithChecks):
     TOGETHER = "together"
     PERPLEXITY = "perplexity"
     DEEPSEEK = "deepseek"
+    XAI = "xai"
 
 
 # unavoidable violation of the DRY principle but it is necessary
@@ -86,6 +87,7 @@ InferenceServiceLiteral = Literal[
     "together",
     "perplexity",
     "deepseek",
+    "xai",
 ]
 
 available_models_urls = {
@@ -108,7 +110,8 @@ service_to_api_keyname = {
     InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
     InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
     InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
-    InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY"
+    InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY",
+    InferenceServiceType.XAI.value: "XAI_API_KEY",
 }
 
 
edsl/inference_services/AnthropicService.py CHANGED
@@ -17,9 +17,10 @@ class AnthropicService(InferenceServiceABC):
     output_token_name = "output_tokens"
     model_exclude_list = []
 
+    available_models_url = "https://docs.anthropic.com/en/docs/about-claude/models"
+
     @classmethod
     def get_model_list(cls, api_key: str = None):
-
         import requests
 
         if api_key is None:
@@ -92,13 +93,16 @@ class AnthropicService(InferenceServiceABC):
         # breakpoint()
         client = AsyncAnthropic(api_key=self.api_token)
 
-        response = await client.messages.create(
-            model=model_name,
-            max_tokens=self.max_tokens,
-            temperature=self.temperature,
-            system=system_prompt, # note that the Anthropic API uses "system" parameter rather than put it in the message
-            messages=messages,
-        )
+        try:
+            response = await client.messages.create(
+                model=model_name,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+                system=system_prompt, # note that the Anthropic API uses "system" parameter rather than put it in the message
+                messages=messages,
+            )
+        except Exception as e:
+            return {"message": str(e)}
         return response.model_dump()
 
     LLM.__name__ = model_class_name
edsl/inference_services/AvailableModelFetcher.py CHANGED
@@ -69,6 +69,8 @@ class AvailableModelFetcher:
 
         Returns a list of [model, service_name, index] entries.
         """
+        if service == "azure":
+            force_refresh = True  # Azure models are listed inside the .env AZURE_ENDPOINT_URL_AND_KEY variable
 
         if service: # they passed a specific service
             matching_models, _ = self.get_available_models_by_service(
edsl/inference_services/AwsBedrock.py CHANGED
@@ -110,8 +110,7 @@ class AwsBedrockService(InferenceServiceABC):
             )
             return response
         except (ClientError, Exception) as e:
-            print(e)
-            return {"error": str(e)}
+            return {"message": str(e)}
 
     LLM.__name__ = model_class_name
 
edsl/inference_services/AzureAI.py CHANGED
@@ -179,15 +179,18 @@ class AzureAIService(InferenceServiceABC):
             api_version=api_version,
             api_key=api_key,
         )
-        response = await client.chat.completions.create(
-            model=model_name,
-            messages=[
-                {
-                    "role": "user",
-                    "content": user_prompt, # Your question can go here
-                },
-            ],
-        )
+        try:
+            response = await client.chat.completions.create(
+                model=model_name,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": user_prompt, # Your question can go here
+                    },
+                ],
+            )
+        except Exception as e:
+            return {"message": str(e)}
         return response.model_dump()
 
     # @staticmethod
edsl/inference_services/GoogleService.py CHANGED
@@ -39,6 +39,10 @@ class GoogleService(InferenceServiceABC):
 
     model_exclude_list = []
 
+    available_models_url = (
+        "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
+    )
+
     @classmethod
     def get_model_list(cls):
         model_list = []
@@ -130,9 +134,12 @@ class GoogleService(InferenceServiceABC):
             )
             combined_prompt.append(gen_ai_file)
 
-        response = await self.generative_model.generate_content_async(
-            combined_prompt, generation_config=generation_config
-        )
+        try:
+            response = await self.generative_model.generate_content_async(
+                combined_prompt, generation_config=generation_config
+            )
+        except Exception as e:
+            return {"message": str(e)}
         return response.to_dict()
 
     LLM.__name__ = model_name
edsl/inference_services/InferenceServiceABC.py CHANGED
@@ -23,6 +23,7 @@ class InferenceServiceABC(ABC):
         "usage_sequence",
         "input_token_name",
         "output_token_name",
+        #"available_models_url",
     ]
     for attr in must_have_attributes:
         if not hasattr(cls, attr):
edsl/inference_services/InferenceServicesCollection.py CHANGED
@@ -104,8 +104,9 @@ class InferenceServicesCollection:
     def available(
         self,
         service: Optional[str] = None,
+        force_refresh: bool = False,
     ) -> List[Tuple[str, str, int]]:
-        return self.availability_fetcher.available(service)
+        return self.availability_fetcher.available(service, force_refresh=force_refresh)
 
     def reset_cache(self) -> None:
         self.availability_fetcher.reset_cache()
@@ -120,7 +121,6 @@ class InferenceServicesCollection:
     def create_model_factory(
         self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
    ) -> "LanguageModel":
-
        if service_name is None: # we try to find the right service
            service = self.resolver.resolve_model(model_name, service_name)
        else: # if they passed a service, we'll use that
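
A sketch of the new force_refresh flag in use. The `default` collection name is an assumption about the registry module; the available() signature itself is taken from the diff:

    from edsl.inference_services.registry import default  # assumed entry point

    cached = default.available(service="openai")                     # served from cache
    fresh = default.available(service="openai", force_refresh=True)  # re-query the service
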
edsl/inference_services/MistralAIService.py CHANGED
@@ -111,8 +111,7 @@ class MistralAIService(InferenceServiceABC):
                ],
            )
        except Exception as e:
-            raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")
-
+            return {"message": str(e)}
        return res.model_dump()
 
    LLM.__name__ = model_class_name
edsl/inference_services/OpenAIService.py CHANGED
@@ -84,6 +84,7 @@ class OpenAIService(InferenceServiceABC):
 
     @classmethod
     def get_model_list(cls, api_key=None):
+        # breakpoint()
         if api_key is None:
             api_key = os.getenv(cls._env_key_name_)
         raw_list = cls.sync_client(api_key).models.list()
@@ -206,8 +207,10 @@ class OpenAIService(InferenceServiceABC):
             {"role": "user", "content": content},
         ]
         if (
-            system_prompt == "" and self.omit_system_prompt_if_empty
-        ) or "o1" in self.model:
+            (system_prompt == "" and self.omit_system_prompt_if_empty)
+            or "o1" in self.model
+            or "o3" in self.model
+        ):
             messages = messages[1:]
 
         params = {
@@ -221,14 +224,17 @@ class OpenAIService(InferenceServiceABC):
             "logprobs": self.logprobs,
             "top_logprobs": self.top_logprobs if self.logprobs else None,
         }
-        if "o1" in self.model:
+        if "o1" in self.model or "o3" in self.model:
             params.pop("max_tokens")
             params["max_completion_tokens"] = self.max_tokens
             params["temperature"] = 1
         try:
             response = await client.chat.completions.create(**params)
         except Exception as e:
-            print(e)
+            #breakpoint()
+            #print(e)
+            #raise e
+            return {'message': str(e)}
         return response.model_dump()
 
     LLM.__name__ = "LanguageModel"
edsl/inference_services/PerplexityService.py CHANGED
@@ -152,7 +152,8 @@ class PerplexityService(OpenAIService):
         try:
             response = await client.chat.completions.create(**params)
         except Exception as e:
-            print(e, flush=True)
+            return {"message": str(e)}
+
         return response.model_dump()
 
     LLM.__name__ = "LanguageModel"
edsl/inference_services/TestService.py CHANGED
@@ -28,6 +28,7 @@ class TestService(InferenceServiceABC):
     model_exclude_list = []
     input_token_name = "prompt_tokens"
     output_token_name = "completion_tokens"
+    available_models_url = None
 
     @classmethod
     def available(cls) -> list[str]:
edsl/inference_services/XAIService.py ADDED
@@ -0,0 +1,11 @@
+from typing import Any, List
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class XAIService(OpenAIService):
+    """Openai service class."""
+
+    _inference_service_ = "xai"
+    _env_key_name_ = "XAI_API_KEY"
+    _base_url_ = "https://api.x.ai/v1"
+    _models_list_cache: List[str] = []
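
Because XAIService subclasses OpenAIService, xAI's OpenAI-compatible endpoint is reused wholesale. Usage would follow the standard EDSL pattern; a sketch, with the model name purely illustrative:

    # .env: XAI_API_KEY=xai-...
    from edsl.language_models.model import Model

    m = Model("grok-2", service_name="xai")  # routed through https://api.x.ai/v1
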
edsl/inference_services/registry.py CHANGED
@@ -14,6 +14,7 @@ from edsl.inference_services.TestService import TestService
 from edsl.inference_services.TogetherAIService import TogetherAIService
 from edsl.inference_services.PerplexityService import PerplexityService
 from edsl.inference_services.DeepSeekService import DeepSeekService
+from edsl.inference_services.XAIService import XAIService
 
 try:
     from edsl.inference_services.MistralAIService import MistralAIService
@@ -35,6 +36,7 @@ services = [
     TogetherAIService,
     PerplexityService,
     DeepSeekService,
+    XAIService,
 ]
 
 if mistral_available:
edsl/jobs/Jobs.py CHANGED
@@ -364,6 +364,15 @@ class Jobs(Base):
             self, cache=self.run_config.environment.cache
         ).create_interviews()
 
+    def show_flow(self, filename: Optional[str] = None) -> None:
+        """Show the flow of the survey."""
+        from edsl.surveys.SurveyFlowVisualization import SurveyFlowVisualization
+        if self.scenarios:
+            scenario = self.scenarios[0]
+        else:
+            scenario = None
+        SurveyFlowVisualization(self.survey, scenario=scenario, agent=None).show_flow(filename=filename)
+
     def interviews(self) -> list[Interview]:
         """
         Return a list of :class:`edsl.jobs.interviews.Interview` objects.
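
A sketch of the new Jobs.show_flow in use; it forwards the first scenario (if any) so templated prompts render with concrete values. The example objects below are assumptions:

    from edsl import Survey, ScenarioList

    job = Survey.example().by(ScenarioList.example())  # a Jobs object
    job.show_flow()                     # display the survey flow inline
    job.show_flow(filename="flow.png")  # or write the diagram to disk
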
edsl/jobs/JobsChecks.py CHANGED
@@ -24,14 +24,14 @@ class JobsChecks:
 
     def get_missing_api_keys(self) -> set:
         """
-        Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
+        Returns a list of the API keys that a user needs to run this job, but does not currently have in their .env file.
         """
         missing_api_keys = set()
 
         from edsl.language_models.model import Model
         from edsl.enums import service_to_api_keyname
 
-        for model in self.jobs.models + [Model()]:
+        for model in self.jobs.models: # + [Model()]:
             if not model.has_valid_api_key():
                 key_name = service_to_api_keyname.get(
                     model._inference_service_, "NOT FOUND"
@@ -134,22 +134,20 @@ class JobsChecks:
 
 
         edsl_auth_token = secrets.token_urlsafe(16)
 
-        print("You're missing some of the API keys needed to run this job:")
+        print("\nThe following keys are needed to run this survey: \n")
         for api_key in missing_api_keys:
-            print(f" 🔑 {api_key}")
+            print(f"🔑 {api_key}")
         print(
-            "\nYou can either add the missing keys to your .env file, or use remote inference."
+            """
+            \nYou can provide your own keys for language models or use an Expected Parrot key to access all available models.
+            \nClick the link below to create an account and run your survey with your Expected Parrot key:
+            """
         )
-        print("Remote inference allows you to run jobs on our server.")
-
+
         coop = Coop()
         coop._display_login_url(
             edsl_auth_token=edsl_auth_token,
-            link_description="\n🚀 To use remote inference, sign up at the following link:",
-        )
-
-        print(
-            "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
+            # link_description="",
         )
 
         api_key = coop._poll_for_api_key(edsl_auth_token)
@@ -159,8 +157,7 @@ class JobsChecks:
             return
 
         path_to_env = write_api_key_to_env(api_key)
-        print("\n✨ API key retrieved and written to .env file at the following path:")
-        print(f"  {path_to_env}")
+        print(f"\n✨ Your Expected Parrot key has been stored at the following path: {path_to_env}\n")
 
         # Retrieve API key so we can continue running the job
         load_dotenv()
edsl/jobs/JobsPrompts.py CHANGED
@@ -200,10 +200,10 @@ class JobsPrompts:
             import warnings
 
             warnings.warn(
-                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
+                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $1.00 / 1M tokens; Output: $1.00 / 1M tokens"
             )
-            input_price_per_token = 0.00000015 # $0.15 / 1M tokens
-            output_price_per_token = 0.00000060 # $0.60 / 1M tokens
+            input_price_per_token = 0.000001 # $1.00 / 1M tokens
+            output_price_per_token = 0.000001 # $1.00 / 1M tokens
 
         # Compute the number of characters (double if the question involves piping)
         user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
edsl/jobs/async_interview_runner.py CHANGED
@@ -7,6 +7,8 @@ from edsl.data_transfer_models import EDSLResultObjectInput
 
 from edsl.results.Result import Result
 from edsl.jobs.interviews.Interview import Interview
+from edsl.config import Config
+config = Config()
 
 if TYPE_CHECKING:
     from edsl.jobs.Jobs import Jobs
@@ -23,7 +25,7 @@ from edsl.jobs.data_structures import RunConfig
 
 
 class AsyncInterviewRunner:
-    MAX_CONCURRENT = 5
+    MAX_CONCURRENT = int(config.EDSL_MAX_CONCURRENT_TASKS)
 
     def __init__(self, jobs: "Jobs", run_config: RunConfig):
         self.jobs = jobs
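
The concurrency cap now comes from configuration instead of the hard-coded 5. Assuming EDSL's usual .env conventions, it can be tuned like so (set before edsl is imported, since the config is read at import time):

    import os
    os.environ["EDSL_MAX_CONCURRENT_TASKS"] = "10"

    from edsl.config import Config
    print(Config().EDSL_MAX_CONCURRENT_TASKS)  # "10"
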
edsl/jobs/check_survey_scenario_compatibility.py CHANGED
@@ -72,11 +72,11 @@ class CheckSurveyScenarioCompatibility:
         if warn:
             warnings.warn(message)
 
-        if self.scenarios.has_jinja_braces:
-            warnings.warn(
-                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
-            )
-            self.scenarios = self.scenarios._convert_jinja_braces()
+        # if self.scenarios.has_jinja_braces:
+        #     warnings.warn(
+        #         "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
+        #     )
+        #     self.scenarios = self.scenarios._convert_jinja_braces()
 
 
 if __name__ == "__main__":
edsl/jobs/interviews/InterviewExceptionEntry.py CHANGED
@@ -166,6 +166,9 @@ class InterviewExceptionEntry:
         >>> entry = InterviewExceptionEntry.example()
         >>> _ = entry.to_dict()
         """
+        import json
+        from edsl.exceptions.questions import QuestionAnswerValidationError
+
         invigilator = (
             self.invigilator.to_dict() if self.invigilator is not None else None
         )
@@ -174,7 +177,16 @@ class InterviewExceptionEntry:
             "time": self.time,
             "traceback": self.traceback,
             "invigilator": invigilator,
+            "additional_data": {},
         }
+
+        if isinstance(self.exception, QuestionAnswerValidationError):
+            d["additional_data"]["edsl_response"] = json.dumps(self.exception.data)
+            d["additional_data"]["validating_model"] = json.dumps(
+                self.exception.model.model_json_schema()
+            )
+            d["additional_data"]["error_message"] = str(self.exception.message)
+
         return d
 
     @classmethod
edsl/jobs/tasks/TaskHistory.py CHANGED
@@ -419,7 +419,7 @@ class TaskHistory(RepresentationMixin):
         filename: Optional[str] = None,
         return_link=False,
         css=None,
-        cta="\nClick to open the report in a new tab\n",
+        cta="<br><span style='font-size: 18px; font-weight: medium-bold; text-decoration: underline;'>Click to open the report in a new tab</span><br><br>",
         open_in_browser=False,
     ):
         """Return an HTML report."""
edsl/language_models/LanguageModel.py CHANGED
@@ -394,7 +394,6 @@ class LanguageModel(
         from edsl.config import CONFIG
 
         TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
         new_cache_key = cache.store(
             **cache_call_params, response=response
@@ -518,7 +517,9 @@ class LanguageModel(
         """
         from edsl.language_models.model import get_model_class
 
-        model_class = get_model_class(data["model"])
+        model_class = get_model_class(
+            data["model"], service_name=data.get("inference_service", None)
+        )
         return model_class(**data)
 
     def __repr__(self) -> str:
@@ -574,7 +575,6 @@ class LanguageModel(
         return Model(skip_api_key_check=True)
 
     def from_cache(self, cache: "Cache") -> LanguageModel:
-
         from copy import deepcopy
         from types import MethodType
         from edsl import Cache
edsl/language_models/PriceManager.py CHANGED
@@ -30,19 +30,22 @@ class PriceManager:
         except Exception as e:
             print(f"Error fetching prices: {str(e)}")
 
-    def get_price(self, inference_service: str, model: str) -> Optional[Dict]:
+    def get_price(self, inference_service: str, model: str) -> Dict:
         """
         Get the price information for a specific service and model combination.
+        If no specific price is found, returns a fallback price.
 
         Args:
             inference_service (str): The name of the inference service
             model (str): The model identifier
 
         Returns:
-            Optional[Dict]: Price information if found, None otherwise
+            Dict: Price information (either actual or fallback prices)
         """
         key = (inference_service, model)
-        return self._price_lookup.get(key)
+        return self._price_lookup.get(key) or self._get_fallback_price(
+            inference_service
+        )
 
     def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
         """
@@ -53,6 +56,45 @@ class PriceManager:
         """
         return self._price_lookup.copy()
 
+    def _get_fallback_price(self, inference_service: str) -> Dict:
+        """
+        Get fallback prices for a service.
+        - First fallback: The highest input and output prices for that service from the price lookup.
+        - Second fallback: $1.00 per million tokens (for both input and output).
+
+        Args:
+            inference_service (str): The inference service name
+
+        Returns:
+            Dict: Price information
+        """
+        service_prices = [
+            prices
+            for (service, _), prices in self._price_lookup.items()
+            if service == inference_service
+        ]
+
+        input_tokens_per_usd = [
+            float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
+        ]
+        if input_tokens_per_usd:
+            min_input_tokens = min(input_tokens_per_usd)
+        else:
+            min_input_tokens = 1_000_000
+
+        output_tokens_per_usd = [
+            float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
+        ]
+        if output_tokens_per_usd:
+            min_output_tokens = min(output_tokens_per_usd)
+        else:
+            min_output_tokens = 1_000_000
+
+        return {
+            "input": {"one_usd_buys": min_input_tokens},
+            "output": {"one_usd_buys": min_output_tokens},
+        }
+
     def calculate_cost(
         self,
         inference_service: str,
@@ -75,8 +117,6 @@ class PriceManager:
         Union[float, str]: Total cost if calculation successful, error message string if not
         """
         relevant_prices = self.get_price(inference_service, model)
-        if relevant_prices is None:
-            return f"Could not find price for model {model} in the price lookup."
 
         # Extract token counts
         try:
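
Note the inversion in _get_fallback_price: prices are stored as one_usd_buys (tokens per dollar), so the highest price per token corresponds to the smallest token count, which is why min() yields the most conservative fallback. A self-contained sketch with made-up numbers:

    # Two hypothetical models for one service, priced as tokens-per-USD
    service_prices = [
        {"input": {"one_usd_buys": 4_000_000}, "output": {"one_usd_buys": 2_000_000}},
        {"input": {"one_usd_buys": 500_000}, "output": {"one_usd_buys": 250_000}},
    ]
    fallback_input = min(float(p["input"]["one_usd_buys"]) for p in service_prices)
    fallback_output = min(float(p["output"]["one_usd_buys"]) for p in service_prices)
    # 500_000 tokens/USD == $2.00 per 1M input tokens: the priciest model wins
    assert (fallback_input, fallback_output) == (500_000, 250_000)
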