edsl 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. edsl/Base.py +7 -3
  2. edsl/__version__.py +1 -1
  3. edsl/agents/InvigilatorBase.py +3 -1
  4. edsl/agents/PromptConstructor.py +66 -91
  5. edsl/agents/QuestionInstructionPromptBuilder.py +160 -79
  6. edsl/agents/QuestionTemplateReplacementsBuilder.py +80 -17
  7. edsl/agents/question_option_processor.py +15 -6
  8. edsl/coop/CoopFunctionsMixin.py +3 -4
  9. edsl/coop/coop.py +171 -96
  10. edsl/data/RemoteCacheSync.py +10 -9
  11. edsl/enums.py +3 -3
  12. edsl/inference_services/AnthropicService.py +11 -9
  13. edsl/inference_services/AvailableModelFetcher.py +2 -0
  14. edsl/inference_services/AwsBedrock.py +1 -2
  15. edsl/inference_services/AzureAI.py +12 -9
  16. edsl/inference_services/GoogleService.py +9 -4
  17. edsl/inference_services/InferenceServicesCollection.py +2 -2
  18. edsl/inference_services/MistralAIService.py +1 -2
  19. edsl/inference_services/OpenAIService.py +9 -4
  20. edsl/inference_services/PerplexityService.py +2 -1
  21. edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
  22. edsl/inference_services/registry.py +2 -2
  23. edsl/jobs/AnswerQuestionFunctionConstructor.py +12 -1
  24. edsl/jobs/Jobs.py +24 -17
  25. edsl/jobs/JobsChecks.py +10 -13
  26. edsl/jobs/JobsPrompts.py +49 -26
  27. edsl/jobs/JobsRemoteInferenceHandler.py +4 -5
  28. edsl/jobs/async_interview_runner.py +3 -1
  29. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  30. edsl/jobs/data_structures.py +3 -0
  31. edsl/jobs/interviews/Interview.py +6 -3
  32. edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
  33. edsl/jobs/tasks/TaskHistory.py +1 -1
  34. edsl/language_models/LanguageModel.py +6 -3
  35. edsl/language_models/PriceManager.py +45 -5
  36. edsl/language_models/model.py +47 -26
  37. edsl/questions/QuestionBase.py +21 -0
  38. edsl/questions/QuestionBasePromptsMixin.py +103 -0
  39. edsl/questions/QuestionFreeText.py +22 -5
  40. edsl/questions/descriptors.py +4 -0
  41. edsl/questions/question_base_gen_mixin.py +96 -29
  42. edsl/results/Dataset.py +65 -0
  43. edsl/results/DatasetExportMixin.py +320 -32
  44. edsl/results/Result.py +27 -0
  45. edsl/results/Results.py +22 -2
  46. edsl/results/ResultsGGMixin.py +7 -3
  47. edsl/scenarios/DocumentChunker.py +2 -0
  48. edsl/scenarios/FileStore.py +10 -0
  49. edsl/scenarios/PdfExtractor.py +21 -1
  50. edsl/scenarios/Scenario.py +25 -9
  51. edsl/scenarios/ScenarioList.py +226 -24
  52. edsl/scenarios/handlers/__init__.py +1 -0
  53. edsl/scenarios/handlers/docx.py +5 -1
  54. edsl/scenarios/handlers/jpeg.py +39 -0
  55. edsl/surveys/Survey.py +5 -4
  56. edsl/surveys/SurveyFlowVisualization.py +91 -43
  57. edsl/templates/error_reporting/exceptions_table.html +7 -8
  58. edsl/templates/error_reporting/interview_details.html +1 -1
  59. edsl/templates/error_reporting/interviews.html +0 -1
  60. edsl/templates/error_reporting/overview.html +2 -7
  61. edsl/templates/error_reporting/performance_plot.html +1 -1
  62. edsl/templates/error_reporting/report.css +1 -1
  63. edsl/utilities/PrettyList.py +14 -0
  64. edsl-0.1.46.dist-info/METADATA +246 -0
  65. {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/RECORD +67 -66
  66. edsl-0.1.44.dist-info/METADATA +0 -110
  67. {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/LICENSE +0 -0
  68. {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/WHEEL +0 -0
edsl/data/RemoteCacheSync.py CHANGED
@@ -100,7 +100,7 @@ class RemoteCacheSync(AbstractContextManager):
 
     def _get_cache_difference(self) -> CacheDifference:
         """Retrieves differences between local and remote caches."""
-        diff = self.coop.remote_cache_get_diff(self.cache.keys())
+        diff = self.coop.legacy_remote_cache_get_diff(self.cache.keys())
         return CacheDifference(
             client_missing_entries=diff.get("client_missing_cacheentries", []),
             server_missing_keys=diff.get("server_missing_cacheentry_keys", []),
@@ -112,7 +112,7 @@ class RemoteCacheSync(AbstractContextManager):
         missing_count = len(diff.client_missing_entries)
 
         if missing_count == 0:
-            # self._output("No new entries to add to local cache.")
+            # self._output("No new entries to add to local cache.")
             return
 
         # self._output(
@@ -154,22 +154,23 @@ class RemoteCacheSync(AbstractContextManager):
         upload_count = len(entries_to_upload)
 
         if upload_count > 0:
+            pass
             # self._output(
             #     f"Updating remote cache with {upload_count:,} new "
             #     f"{'entry' if upload_count == 1 else 'entries'}..."
             # )
 
-            self.coop.remote_cache_create_many(
-                entries_to_upload,
-                visibility="private",
-                description=self.remote_cache_description,
-            )
+            # self.coop.remote_cache_create_many(
+            #     entries_to_upload,
+            #     visibility="private",
+            #     description=self.remote_cache_description,
+            # )
             # self._output("Remote cache updated!")
         # else:
-        # self._output("No new entries to add to remote cache.")
+        #     self._output("No new entries to add to remote cache.")
 
         # self._output(
-        # f"There are {len(self.cache.keys()):,} entries in the local cache."
+        #     f"There are {len(self.cache.keys()):,} entries in the local cache."
         # )
 
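Taken together, these RemoteCacheSync changes make the sync download-only: the diff call moves to the `legacy_remote_cache_get_diff` endpoint and the `remote_cache_create_many` upload is commented out. A minimal sketch of the payload shape this code consumes, with a stubbed response (the dictionary keys come from `_get_cache_difference` above; the values are hypothetical):

```python
from dataclasses import dataclass, field

@dataclass
class CacheDifference:
    # Mirrors the two fields populated by _get_cache_difference above.
    client_missing_entries: list = field(default_factory=list)
    server_missing_keys: list = field(default_factory=list)

# Stub of a coop.legacy_remote_cache_get_diff(...) response.
diff = {
    "client_missing_cacheentries": [],           # entries to pull into the local cache
    "server_missing_cacheentry_keys": ["abc1"],  # keys the server lacks; no longer uploaded
}

difference = CacheDifference(
    client_missing_entries=diff.get("client_missing_cacheentries", []),
    server_missing_keys=diff.get("server_missing_cacheentry_keys", []),
)
assert difference.server_missing_keys == ["abc1"]
```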
 
edsl/enums.py CHANGED
@@ -67,7 +67,7 @@ class InferenceServiceType(EnumWithChecks):
     TOGETHER = "together"
     PERPLEXITY = "perplexity"
     DEEPSEEK = "deepseek"
-    GROK = "grok"
+    XAI = "xai"
 
 
 # unavoidable violation of the DRY principle but it is necessary
@@ -87,7 +87,7 @@ InferenceServiceLiteral = Literal[
     "together",
     "perplexity",
     "deepseek",
-    "grok",
+    "xai",
 ]
 
 available_models_urls = {
@@ -111,7 +111,7 @@ service_to_api_keyname = {
     InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
     InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
     InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY",
-    InferenceServiceType.GROK.value: "XAI_API_KEY",
+    InferenceServiceType.XAI.value: "XAI_API_KEY",
 }
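Callers that pinned the old service literal must migrate: "grok" is no longer a valid `InferenceServiceType`, while the key name `XAI_API_KEY` is unchanged. A hedged migration sketch, assuming the top-level `Model` constructor accepts a `service_name` keyword (the model name below is illustrative):

```python
from edsl import Model

# Before 0.1.46: Model("grok-beta", service_name="grok")
# From 0.1.46 on, the same service is addressed as "xai".
# The XAI_API_KEY environment variable still supplies the credential.
m = Model("grok-beta", service_name="xai")
```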
 
edsl/inference_services/AnthropicService.py CHANGED
@@ -17,11 +17,10 @@ class AnthropicService(InferenceServiceABC):
     output_token_name = "output_tokens"
     model_exclude_list = []
 
-    available_models_url = 'https://docs.anthropic.com/en/docs/about-claude/models'
+    available_models_url = "https://docs.anthropic.com/en/docs/about-claude/models"
 
     @classmethod
     def get_model_list(cls, api_key: str = None):
-
         import requests
 
         if api_key is None:
@@ -94,13 +93,16 @@ class AnthropicService(InferenceServiceABC):
             # breakpoint()
             client = AsyncAnthropic(api_key=self.api_token)
 
-            response = await client.messages.create(
-                model=model_name,
-                max_tokens=self.max_tokens,
-                temperature=self.temperature,
-                system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
-                messages=messages,
-            )
+            try:
+                response = await client.messages.create(
+                    model=model_name,
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
+                    system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
+                    messages=messages,
+                )
+            except Exception as e:
+                return {"message": str(e)}
             return response.model_dump()
 
         LLM.__name__ = model_class_name
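This try/except is the first instance of a pattern applied across the backends in this release (Bedrock, Azure, Google, Mistral, OpenAI, and Perplexity below): provider exceptions are caught and returned as a `{"message": str(e)}` dict instead of being printed or re-raised. A standalone sketch of that contract, with stand-in names, showing what the response-parsing layer receives on failure:

```python
import asyncio

async def call_with_error_contract(provider_call, **params) -> dict:
    # Stand-in for the backends' async model-call bodies:
    # success -> response.model_dump(); failure -> {"message": ...}.
    try:
        response = await provider_call(**params)
    except Exception as e:
        return {"message": str(e)}
    return response.model_dump()

async def demo() -> None:
    async def failing_call(**params):
        raise RuntimeError("overloaded_error")

    result = await call_with_error_contract(failing_call)
    assert result == {"message": "overloaded_error"}

asyncio.run(demo())
```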
edsl/inference_services/AvailableModelFetcher.py CHANGED
@@ -69,6 +69,8 @@ class AvailableModelFetcher:
 
         Returns a list of [model, service_name, index] entries.
         """
+        if service == "azure" or service == "bedrock":
+            force_refresh = True  # Azure models are listed inside the .env AZURE_ENDPOINT_URL_AND_KEY variable
 
         if service:  # they passed a specific service
             matching_models, _ = self.get_available_models_by_service(
edsl/inference_services/AwsBedrock.py CHANGED
@@ -110,8 +110,7 @@ class AwsBedrockService(InferenceServiceABC):
             )
             return response
         except (ClientError, Exception) as e:
-            print(e)
-            return {"error": str(e)}
+            return {"message": str(e)}
 
         LLM.__name__ = model_class_name
 
edsl/inference_services/AzureAI.py CHANGED
@@ -179,15 +179,18 @@ class AzureAIService(InferenceServiceABC):
                 api_version=api_version,
                 api_key=api_key,
             )
-            response = await client.chat.completions.create(
-                model=model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": user_prompt,  # Your question can go here
-                    },
-                ],
-            )
+            try:
+                response = await client.chat.completions.create(
+                    model=model_name,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": user_prompt,  # Your question can go here
+                        },
+                    ],
+                )
+            except Exception as e:
+                return {"message": str(e)}
             return response.model_dump()
 
         # @staticmethod
edsl/inference_services/GoogleService.py CHANGED
@@ -39,7 +39,9 @@ class GoogleService(InferenceServiceABC):
 
     model_exclude_list = []
 
-    available_models_url = 'https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models'
+    available_models_url = (
+        "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
+    )
 
     @classmethod
     def get_model_list(cls):
@@ -132,9 +134,12 @@ class GoogleService(InferenceServiceABC):
             )
             combined_prompt.append(gen_ai_file)
 
-        response = await self.generative_model.generate_content_async(
-            combined_prompt, generation_config=generation_config
-        )
+        try:
+            response = await self.generative_model.generate_content_async(
+                combined_prompt, generation_config=generation_config
+            )
+        except Exception as e:
+            return {"message": str(e)}
         return response.to_dict()
 
         LLM.__name__ = model_name
edsl/inference_services/InferenceServicesCollection.py CHANGED
@@ -104,8 +104,9 @@ class InferenceServicesCollection:
     def available(
         self,
         service: Optional[str] = None,
+        force_refresh: bool = False,
     ) -> List[Tuple[str, str, int]]:
-        return self.availability_fetcher.available(service)
+        return self.availability_fetcher.available(service, force_refresh=force_refresh)
 
     def reset_cache(self) -> None:
         self.availability_fetcher.reset_cache()
@@ -120,7 +121,6 @@ class InferenceServicesCollection:
     def create_model_factory(
         self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
     ) -> "LanguageModel":
-
         if service_name is None:  # we try to find the right service
             service = self.resolver.resolve_model(model_name, service_name)
         else:  # if they passed a service, we'll use that
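The new `force_refresh` flag flows from the collection into `AvailableModelFetcher.available`, which forces it on for the azure and bedrock services (see the hunk above). A hedged usage sketch, assuming `Model.available` forwards the keyword down to this method:

```python
from edsl import Model

# Re-query the service instead of serving the cached model list.
# (For "azure" and "bedrock", 0.1.46 forces the refresh regardless.)
models = Model.available(service="openai", force_refresh=True)
```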
edsl/inference_services/MistralAIService.py CHANGED
@@ -111,8 +111,7 @@ class MistralAIService(InferenceServiceABC):
                 ],
             )
         except Exception as e:
-            raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")
-
+            return {"message": str(e)}
         return res.model_dump()
 
         LLM.__name__ = model_class_name
edsl/inference_services/OpenAIService.py CHANGED
@@ -207,8 +207,10 @@ class OpenAIService(InferenceServiceABC):
             {"role": "user", "content": content},
         ]
         if (
-            system_prompt == "" and self.omit_system_prompt_if_empty
-        ) or "o1" in self.model:
+            (system_prompt == "" and self.omit_system_prompt_if_empty)
+            or "o1" in self.model
+            or "o3" in self.model
+        ):
             messages = messages[1:]
 
         params = {
@@ -222,14 +224,17 @@ class OpenAIService(InferenceServiceABC):
             "logprobs": self.logprobs,
             "top_logprobs": self.top_logprobs if self.logprobs else None,
         }
-        if "o1" in self.model:
+        if "o1" in self.model or "o3" in self.model:
             params.pop("max_tokens")
             params["max_completion_tokens"] = self.max_tokens
             params["temperature"] = 1
         try:
             response = await client.chat.completions.create(**params)
         except Exception as e:
-            print(e)
+            #breakpoint()
+            #print(e)
+            #raise e
+            return {'message': str(e)}
         return response.model_dump()
 
         LLM.__name__ = "LanguageModel"
edsl/inference_services/PerplexityService.py CHANGED
@@ -152,7 +152,8 @@ class PerplexityService(OpenAIService):
         try:
             response = await client.chat.completions.create(**params)
         except Exception as e:
-            print(e, flush=True)
+            return {"message": str(e)}
+
         return response.model_dump()
 
         LLM.__name__ = "LanguageModel"
edsl/inference_services/{GrokService.py → XAIService.py} RENAMED
@@ -2,10 +2,10 @@ from typing import Any, List
 from edsl.inference_services.OpenAIService import OpenAIService
 
 
-class GrokService(OpenAIService):
+class XAIService(OpenAIService):
     """Openai service class."""
 
-    _inference_service_ = "grok"
+    _inference_service_ = "xai"
     _env_key_name_ = "XAI_API_KEY"
     _base_url_ = "https://api.x.ai/v1"
     _models_list_cache: List[str] = []
edsl/inference_services/registry.py CHANGED
@@ -14,7 +14,7 @@ from edsl.inference_services.TestService import TestService
 from edsl.inference_services.TogetherAIService import TogetherAIService
 from edsl.inference_services.PerplexityService import PerplexityService
 from edsl.inference_services.DeepSeekService import DeepSeekService
-from edsl.inference_services.GrokService import GrokService
+from edsl.inference_services.XAIService import XAIService
 
 try:
     from edsl.inference_services.MistralAIService import MistralAIService
@@ -36,7 +36,7 @@ services = [
     TogetherAIService,
     PerplexityService,
     DeepSeekService,
-    GrokService,
+    XAIService,
 ]
 
 if mistral_available:
edsl/jobs/AnswerQuestionFunctionConstructor.py CHANGED
@@ -66,10 +66,14 @@ class SkipHandler:
                 )
             )
 
+
         def cancel_between(start, end):
             """Cancel the tasks for questions between the start and end indices."""
             for i in range(start, end):
-                self.interview.tasks[i].cancel()
+                #print(f"Cancelling task {i}")
+                #self.interview.tasks[i].cancel()
+                #self.interview.tasks[i].set_result("skipped")
+                self.interview.skip_flags[self.interview.survey.questions[i].question_name] = True
 
         if (next_question_index := next_question.next_q) == EndOfSurvey:
             cancel_between(
@@ -80,6 +84,8 @@ class SkipHandler:
             if next_question_index > (current_question_index + 1):
                 cancel_between(current_question_index + 1, next_question_index)
 
+
+
 
 class AnswerQuestionFunctionConstructor:
     """Constructs a function that answers a question and records the answer."""
@@ -161,6 +167,11 @@ class AnswerQuestionFunctionConstructor:
         async def attempt_answer():
             invigilator = self.invigilator_fetcher(question)
 
+            if self.interview.skip_flags.get(question.question_name, False):
+                return invigilator.get_failed_task_result(
+                    failure_reason="Question skipped."
+                )
+
             if self.skip_handler.should_skip(question):
                 return invigilator.get_failed_task_result(
                     failure_reason="Question skipped."
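The skip strategy changes here: rather than cancelling the asyncio tasks of bypassed questions, `SkipHandler` records their names in `interview.skip_flags`, and `attempt_answer` consults that flag before calling a model. A minimal standalone sketch of the flag-based pattern (names are stand-ins for the edsl objects):

```python
# Flag-based skipping, isolated from edsl's Interview/Invigilator classes.
skip_flags: dict[str, bool] = {}
question_names = ["q0", "q1", "q2"]

def cancel_between(start: int, end: int) -> None:
    # Replaces task cancellation: just record which questions to bypass.
    for i in range(start, end):
        skip_flags[question_names[i]] = True

def attempt_answer(question_name: str) -> str:
    if skip_flags.get(question_name, False):
        return "skipped"  # stands in for invigilator.get_failed_task_result(...)
    return "answered"

cancel_between(1, 3)
assert [attempt_answer(q) for q in question_names] == ["answered", "skipped", "skipped"]
```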
edsl/jobs/Jobs.py CHANGED
@@ -277,7 +277,7 @@ class Jobs(Base):
 
         return JobsComponentConstructor(self).by(*args)
 
-    def prompts(self) -> "Dataset":
+    def prompts(self, iterations=1) -> "Dataset":
         """Return a Dataset of prompts that will be used.
 
 
@@ -285,7 +285,7 @@ class Jobs(Base):
         >>> Jobs.example().prompts()
         Dataset(...)
         """
-        return JobsPrompts(self).prompts()
+        return JobsPrompts(self).prompts(iterations=iterations)
 
     def show_prompts(self, all: bool = False) -> None:
         """Print the prompts."""
@@ -364,6 +364,15 @@ class Jobs(Base):
             self, cache=self.run_config.environment.cache
         ).create_interviews()
 
+    def show_flow(self, filename: Optional[str] = None) -> None:
+        """Show the flow of the survey."""
+        from edsl.surveys.SurveyFlowVisualization import SurveyFlowVisualization
+        if self.scenarios:
+            scenario = self.scenarios[0]
+        else:
+            scenario = None
+        SurveyFlowVisualization(self.survey, scenario=scenario, agent=None).show_flow(filename=filename)
+
     def interviews(self) -> list[Interview]:
         """
         Return a list of :class:`edsl.jobs.interviews.Interview` objects.
@@ -409,11 +418,9 @@ class Jobs(Base):
         BucketCollection(...)
         """
         bc = BucketCollection.from_models(self.models)
-
+
         if self.run_config.environment.key_lookup is not None:
-            bc.update_from_key_lookup(
-                self.run_config.environment.key_lookup
-            )
+            bc.update_from_key_lookup(self.run_config.environment.key_lookup)
         return bc
 
     def html(self):
@@ -475,25 +482,24 @@ class Jobs(Base):
     def _start_remote_inference_job(
         self, job_handler: Optional[JobsRemoteInferenceHandler] = None
     ) -> Union["Results", None]:
-
         if job_handler is None:
             job_handler = self._create_remote_inference_handler()
-
+
         job_info = job_handler.create_remote_inference_job(
-            iterations=self.run_config.parameters.n,
-            remote_inference_description=self.run_config.parameters.remote_inference_description,
-            remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
+            iterations=self.run_config.parameters.n,
+            remote_inference_description=self.run_config.parameters.remote_inference_description,
+            remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
+            fresh=self.run_config.parameters.fresh,
         )
         return job_info
-
-    def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:
 
+    def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:
         from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
-
+
         return JobsRemoteInferenceHandler(
             self, verbose=self.run_config.parameters.verbose
         )
-
+
     def _remote_results(
         self,
         config: RunConfig,
@@ -507,7 +513,8 @@ class Jobs(Base):
         if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
             job_info: RemoteJobInfo = self._start_remote_inference_job(jh)
             if background:
-                from edsl.results.Results import Results
+                from edsl.results.Results import Results
+
                 results = Results.from_job_info(job_info)
                 return results
             else:
@@ -594,7 +601,7 @@ class Jobs(Base):
         # first try to run the job remotely
         if (results := self._remote_results(config)) is not None:
             return results
-
+
         self._check_if_local_keys_ok()
 
         if config.environment.bucket_collection is None:
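Two caller-visible additions land in Jobs: `prompts()` gains an `iterations` argument (implemented in JobsPrompts below, which emits one cache key per iteration), and `show_flow()` renders the survey flow, applying the first scenario when one exists. A hedged usage sketch based on the signatures above:

```python
from edsl.jobs.Jobs import Jobs

job = Jobs.example()

# With iterations > 1, the returned Dataset carries one cache key
# per prompt per iteration (see the JobsPrompts loop below).
prompt_dataset = job.prompts(iterations=2)

# Renders the survey flow; the filename argument is optional and is
# assumed here to select an output file for the rendered diagram.
job.show_flow(filename="survey_flow.png")
```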
edsl/jobs/JobsChecks.py CHANGED
@@ -24,7 +24,7 @@ class JobsChecks:
 
     def get_missing_api_keys(self) -> set:
         """
-        Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
+        Returns a list of the API keys that a user needs to run this job, but does not currently have in their .env file.
         """
         missing_api_keys = set()
 
@@ -134,22 +134,20 @@ class JobsChecks:
 
         edsl_auth_token = secrets.token_urlsafe(16)
 
-        print("API keys are required to run surveys with language models. The following keys are needed to run this survey: ")
+        print("\nThe following keys are needed to run this survey: \n")
         for api_key in missing_api_keys:
-            print(f" 🔑 {api_key}")
+            print(f"🔑 {api_key}")
         print(
-            "\nYou can provide your own keys or use an Expected Parrot key to access all available models."
+            """
+            \nYou can provide your own keys for language models or use an Expected Parrot key to access all available models.
+            \nClick the link below to create an account and run your survey with your Expected Parrot key:
+            """
         )
-        print("Please see the documentation page to learn about options for managing keys: https://docs.expectedparrot.com/en/latest/api_keys.html")
-
+
         coop = Coop()
         coop._display_login_url(
             edsl_auth_token=edsl_auth_token,
-            link_description="\n➡️ Click the link below to create an account and get an Expected Parrot key:\n",
-        )
-
-        print(
-            "\nOnce you log in, your key will be stored on your computer and your survey will start running at the Expected Parrot server."
+            # link_description="",
         )
 
         api_key = coop._poll_for_api_key(edsl_auth_token)
@@ -159,8 +157,7 @@ class JobsChecks:
             return
 
         path_to_env = write_api_key_to_env(api_key)
-        print("\n✨ Your key has been stored at the following path: ")
-        print(f"    {path_to_env}")
+        print(f"\n✨ Your Expected Parrot key has been stored at the following path: {path_to_env}\n")
 
         # Retrieve API key so we can continue running the job
         load_dotenv()
edsl/jobs/JobsPrompts.py CHANGED
@@ -18,6 +18,7 @@ from edsl.data.CacheEntry import CacheEntry
 
 logger = logging.getLogger(__name__)
 
+
 class JobsPrompts:
     def __init__(self, jobs: "Jobs"):
         self.interviews = jobs.interviews()
@@ -26,7 +27,9 @@ class JobsPrompts:
         self.survey = jobs.survey
         self._price_lookup = None
         self._agent_lookup = {agent: idx for idx, agent in enumerate(self.agents)}
-        self._scenario_lookup = {scenario: idx for idx, scenario in enumerate(self.scenarios)}
+        self._scenario_lookup = {
+            scenario: idx for idx, scenario in enumerate(self.scenarios)
+        }
 
     @property
     def price_lookup(self):
@@ -37,7 +40,7 @@ class JobsPrompts:
             self._price_lookup = c.fetch_prices()
         return self._price_lookup
 
-    def prompts(self) -> "Dataset":
+    def prompts(self, iterations=1) -> "Dataset":
         """Return a Dataset of prompts that will be used.
 
         >>> from edsl.jobs import Jobs
@@ -54,11 +57,11 @@ class JobsPrompts:
         models = []
         costs = []
         cache_keys = []
-
+
         for interview_index, interview in enumerate(interviews):
             logger.info(f"Processing interview {interview_index} of {len(interviews)}")
             interview_start = time.time()
-
+
             # Fetch invigilators timing
             invig_start = time.time()
             invigilators = [
@@ -66,8 +69,10 @@ class JobsPrompts:
                 for question in interview.survey.questions
             ]
             invig_end = time.time()
-            logger.debug(f"Time taken to fetch invigilators: {invig_end - invig_start:.4f}s")
-
+            logger.debug(
+                f"Time taken to fetch invigilators: {invig_end - invig_start:.4f}s"
+            )
+
             # Process prompts timing
             prompts_start = time.time()
             for _, invigilator in enumerate(invigilators):
@@ -75,13 +80,15 @@ class JobsPrompts:
                 get_prompts_start = time.time()
                 prompts = invigilator.get_prompts()
                 get_prompts_end = time.time()
-                logger.debug(f"Time taken to get prompts: {get_prompts_end - get_prompts_start:.4f}s")
-
+                logger.debug(
+                    f"Time taken to get prompts: {get_prompts_end - get_prompts_start:.4f}s"
+                )
+
                 user_prompt = prompts["user_prompt"]
                 system_prompt = prompts["system_prompt"]
                 user_prompts.append(user_prompt)
                 system_prompts.append(system_prompt)
-
+
                 # Index lookups timing
                 index_start = time.time()
                 agent_index = self._agent_lookup[invigilator.agent]
@@ -90,14 +97,18 @@ class JobsPrompts:
                 scenario_index = self._scenario_lookup[invigilator.scenario]
                 scenario_indices.append(scenario_index)
                 index_end = time.time()
-                logger.debug(f"Time taken for index lookups: {index_end - index_start:.4f}s")
-
+                logger.debug(
+                    f"Time taken for index lookups: {index_end - index_start:.4f}s"
+                )
+
                 # Model and question name assignment timing
                 assign_start = time.time()
                 models.append(invigilator.model.model)
                 question_names.append(invigilator.question.question_name)
                 assign_end = time.time()
-                logger.debug(f"Time taken for assignments: {assign_end - assign_start:.4f}s")
+                logger.debug(
+                    f"Time taken for assignments: {assign_end - assign_start:.4f}s"
+                )
 
                 # Cost estimation timing
                 cost_start = time.time()
@@ -109,32 +120,44 @@ class JobsPrompts:
                     model=invigilator.model.model,
                 )
                 cost_end = time.time()
-                logger.debug(f"Time taken to estimate prompt cost: {cost_end - cost_start:.4f}s")
+                logger.debug(
+                    f"Time taken to estimate prompt cost: {cost_end - cost_start:.4f}s"
+                )
                 costs.append(prompt_cost["cost_usd"])
 
                 # Cache key generation timing
                 cache_key_gen_start = time.time()
-                cache_key = CacheEntry.gen_key(
-                    model=invigilator.model.model,
-                    parameters=invigilator.model.parameters,
-                    system_prompt=system_prompt,
-                    user_prompt=user_prompt,
-                    iteration=0,
-                )
+                for iteration in range(iterations):
+                    cache_key = CacheEntry.gen_key(
+                        model=invigilator.model.model,
+                        parameters=invigilator.model.parameters,
+                        system_prompt=system_prompt,
+                        user_prompt=user_prompt,
+                        iteration=iteration,
+                    )
+                    cache_keys.append(cache_key)
+
                 cache_key_gen_end = time.time()
-                cache_keys.append(cache_key)
-                logger.debug(f"Time taken to generate cache key: {cache_key_gen_end - cache_key_gen_start:.4f}s")
+                logger.debug(
+                    f"Time taken to generate cache key: {cache_key_gen_end - cache_key_gen_start:.4f}s"
+                )
                 logger.debug("-" * 50)  # Separator between iterations
 
             prompts_end = time.time()
-            logger.info(f"Time taken to process prompts: {prompts_end - prompts_start:.4f}s")
-
+            logger.info(
+                f"Time taken to process prompts: {prompts_end - prompts_start:.4f}s"
+            )
+
             interview_end = time.time()
-            logger.info(f"Overall time taken for interview: {interview_end - interview_start:.4f}s")
+            logger.info(
+                f"Overall time taken for interview: {interview_end - interview_start:.4f}s"
+            )
             logger.info("Time breakdown:")
             logger.info(f"  Invigilators: {invig_end - invig_start:.4f}s")
             logger.info(f"  Prompts processing: {prompts_end - prompts_start:.4f}s")
-            logger.info(f"  Other overhead: {(interview_end - interview_start) - ((invig_end - invig_start) + (prompts_end - prompts_start)):.4f}s")
+            logger.info(
+                f"  Other overhead: {(interview_end - interview_start) - ((invig_end - invig_start) + (prompts_end - prompts_start)):.4f}s"
+            )
 
         d = Dataset(
             [
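The per-iteration loop above is the substance of the new `iterations` parameter: the iteration index is an input to `CacheEntry.gen_key`, so an n-iteration run maps each prompt to n distinct cache keys. A small sketch of that dependency, calling `gen_key` with the same keyword set used above (the values are illustrative):

```python
from edsl.data.CacheEntry import CacheEntry

# Same model/parameters/prompts; only the iteration index varies.
keys = [
    CacheEntry.gen_key(
        model="gpt-4o",
        parameters={"temperature": 0.5},
        system_prompt="You are a helpful agent.",
        user_prompt="What is 2 + 2?",
        iteration=i,
    )
    for i in range(2)
]

# Distinct keys per iteration is what makes the loop above necessary.
assert keys[0] != keys[1]
```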