edsl 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +7 -3
- edsl/__version__.py +1 -1
- edsl/agents/InvigilatorBase.py +3 -1
- edsl/agents/PromptConstructor.py +66 -91
- edsl/agents/QuestionInstructionPromptBuilder.py +160 -79
- edsl/agents/QuestionTemplateReplacementsBuilder.py +80 -17
- edsl/agents/question_option_processor.py +15 -6
- edsl/coop/CoopFunctionsMixin.py +3 -4
- edsl/coop/coop.py +171 -96
- edsl/data/RemoteCacheSync.py +10 -9
- edsl/enums.py +3 -3
- edsl/inference_services/AnthropicService.py +11 -9
- edsl/inference_services/AvailableModelFetcher.py +2 -0
- edsl/inference_services/AwsBedrock.py +1 -2
- edsl/inference_services/AzureAI.py +12 -9
- edsl/inference_services/GoogleService.py +9 -4
- edsl/inference_services/InferenceServicesCollection.py +2 -2
- edsl/inference_services/MistralAIService.py +1 -2
- edsl/inference_services/OpenAIService.py +9 -4
- edsl/inference_services/PerplexityService.py +2 -1
- edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
- edsl/inference_services/registry.py +2 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +12 -1
- edsl/jobs/Jobs.py +24 -17
- edsl/jobs/JobsChecks.py +10 -13
- edsl/jobs/JobsPrompts.py +49 -26
- edsl/jobs/JobsRemoteInferenceHandler.py +4 -5
- edsl/jobs/async_interview_runner.py +3 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/data_structures.py +3 -0
- edsl/jobs/interviews/Interview.py +6 -3
- edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
- edsl/jobs/tasks/TaskHistory.py +1 -1
- edsl/language_models/LanguageModel.py +6 -3
- edsl/language_models/PriceManager.py +45 -5
- edsl/language_models/model.py +47 -26
- edsl/questions/QuestionBase.py +21 -0
- edsl/questions/QuestionBasePromptsMixin.py +103 -0
- edsl/questions/QuestionFreeText.py +22 -5
- edsl/questions/descriptors.py +4 -0
- edsl/questions/question_base_gen_mixin.py +96 -29
- edsl/results/Dataset.py +65 -0
- edsl/results/DatasetExportMixin.py +320 -32
- edsl/results/Result.py +27 -0
- edsl/results/Results.py +22 -2
- edsl/results/ResultsGGMixin.py +7 -3
- edsl/scenarios/DocumentChunker.py +2 -0
- edsl/scenarios/FileStore.py +10 -0
- edsl/scenarios/PdfExtractor.py +21 -1
- edsl/scenarios/Scenario.py +25 -9
- edsl/scenarios/ScenarioList.py +226 -24
- edsl/scenarios/handlers/__init__.py +1 -0
- edsl/scenarios/handlers/docx.py +5 -1
- edsl/scenarios/handlers/jpeg.py +39 -0
- edsl/surveys/Survey.py +5 -4
- edsl/surveys/SurveyFlowVisualization.py +91 -43
- edsl/templates/error_reporting/exceptions_table.html +7 -8
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/interviews.html +0 -1
- edsl/templates/error_reporting/overview.html +2 -7
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +1 -1
- edsl/utilities/PrettyList.py +14 -0
- edsl-0.1.46.dist-info/METADATA +246 -0
- {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/RECORD +67 -66
- edsl-0.1.44.dist-info/METADATA +0 -110
- {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/LICENSE +0 -0
- {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/WHEEL +0 -0
edsl/data/RemoteCacheSync.py
CHANGED
@@ -100,7 +100,7 @@ class RemoteCacheSync(AbstractContextManager):

    def _get_cache_difference(self) -> CacheDifference:
        """Retrieves differences between local and remote caches."""
-        diff = self.coop.
+        diff = self.coop.legacy_remote_cache_get_diff(self.cache.keys())
        return CacheDifference(
            client_missing_entries=diff.get("client_missing_cacheentries", []),
            server_missing_keys=diff.get("server_missing_cacheentry_keys", []),
@@ -112,7 +112,7 @@ class RemoteCacheSync(AbstractContextManager):
        missing_count = len(diff.client_missing_entries)

        if missing_count == 0:
-
+            # self._output("No new entries to add to local cache.")
            return

        # self._output(
@@ -154,22 +154,23 @@ class RemoteCacheSync(AbstractContextManager):
        upload_count = len(entries_to_upload)

        if upload_count > 0:
+            pass
            # self._output(
            #     f"Updating remote cache with {upload_count:,} new "
            #     f"{'entry' if upload_count == 1 else 'entries'}..."
            # )

-            self.coop.remote_cache_create_many(
-
-
-
-            )
+            # self.coop.remote_cache_create_many(
+            #     entries_to_upload,
+            #     visibility="private",
+            #     description=self.remote_cache_description,
+            # )
            # self._output("Remote cache updated!")
            # else:
-
+            #     self._output("No new entries to add to remote cache.")

            # self._output(
-
+            #     f"There are {len(self.cache.keys()):,} entries in the local cache."
            # )

edsl/enums.py
CHANGED
@@ -67,7 +67,7 @@ class InferenceServiceType(EnumWithChecks):
    TOGETHER = "together"
    PERPLEXITY = "perplexity"
    DEEPSEEK = "deepseek"
-
+    XAI = "xai"


# unavoidable violation of the DRY principle but it is necessary
@@ -87,7 +87,7 @@ InferenceServiceLiteral = Literal[
    "together",
    "perplexity",
    "deepseek",
-    "
+    "xai",
]

available_models_urls = {
@@ -111,7 +111,7 @@ service_to_api_keyname = {
    InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
    InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
    InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY",
-    InferenceServiceType.
+    InferenceServiceType.XAI.value: "XAI_API_KEY",
}

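The new "xai" service is registered both in the enum and in the API-key mapping. A minimal sketch of how the two additions fit together, assuming only the names shown in the hunks above (InferenceServiceType and service_to_api_keyname in edsl/enums.py); the key value is a placeholder:

import os

from edsl.enums import InferenceServiceType, service_to_api_keyname

# The new enum member maps to the XAI_API_KEY environment variable.
key_name = service_to_api_keyname[InferenceServiceType.XAI.value]
os.environ.setdefault(key_name, "xai-...")  # placeholder value for illustration

print(InferenceServiceType.XAI.value, key_name)  # xai XAI_API_KEY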
edsl/inference_services/AnthropicService.py
CHANGED
@@ -17,11 +17,10 @@ class AnthropicService(InferenceServiceABC):
    output_token_name = "output_tokens"
    model_exclude_list = []

-    available_models_url =
+    available_models_url = "https://docs.anthropic.com/en/docs/about-claude/models"

    @classmethod
    def get_model_list(cls, api_key: str = None):
-
        import requests

        if api_key is None:
@@ -94,13 +93,16 @@ class AnthropicService(InferenceServiceABC):
            # breakpoint()
            client = AsyncAnthropic(api_key=self.api_token)

-
-
-
-
-
-
-
+            try:
+                response = await client.messages.create(
+                    model=model_name,
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
+                    system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
+                    messages=messages,
+                )
+            except Exception as e:
+                return {"message": str(e)}
            return response.model_dump()

        LLM.__name__ = model_class_name
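Several service backends in this release (Anthropic, Azure, Google, Mistral, OpenAI, Perplexity) now wrap the provider call in a try/except and return the error text as a payload instead of letting the exception propagate. A self-contained sketch of that pattern, with an illustrative helper name that is not part of edsl:

import asyncio


async def call_with_error_capture(make_request):
    # Pattern adopted by the updated service classes: on any provider error,
    # return {"message": <error text>} so the job records a failed task
    # instead of crashing the whole run.
    try:
        return await make_request()
    except Exception as e:
        return {"message": str(e)}


async def main():
    async def failing_request():
        raise RuntimeError("simulated provider outage")

    print(await call_with_error_capture(failing_request))
    # {'message': 'simulated provider outage'}


asyncio.run(main())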
edsl/inference_services/AvailableModelFetcher.py
CHANGED
@@ -69,6 +69,8 @@ class AvailableModelFetcher:

        Returns a list of [model, service_name, index] entries.
        """
+        if service == "azure" or service == "bedrock":
+            force_refresh = True  # Azure models are listed inside the .env AZURE_ENDPOINT_URL_AND_KEY variable

        if service:  # they passed a specific service
            matching_models, _ = self.get_available_models_by_service(
edsl/inference_services/AzureAI.py
CHANGED
@@ -179,15 +179,18 @@ class AzureAIService(InferenceServiceABC):
                api_version=api_version,
                api_key=api_key,
            )
-
-
-
-
-
-
-
-
-
+            try:
+                response = await client.chat.completions.create(
+                    model=model_name,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": user_prompt,  # Your question can go here
+                        },
+                    ],
+                )
+            except Exception as e:
+                return {"message": str(e)}
            return response.model_dump()

    # @staticmethod
edsl/inference_services/GoogleService.py
CHANGED
@@ -39,7 +39,9 @@ class GoogleService(InferenceServiceABC):

    model_exclude_list = []

-    available_models_url =
+    available_models_url = (
+        "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
+    )

    @classmethod
    def get_model_list(cls):
@@ -132,9 +134,12 @@ class GoogleService(InferenceServiceABC):
                )
                combined_prompt.append(gen_ai_file)

-
-
-
+        try:
+            response = await self.generative_model.generate_content_async(
+                combined_prompt, generation_config=generation_config
+            )
+        except Exception as e:
+            return {"message": str(e)}
        return response.to_dict()

    LLM.__name__ = model_name
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -104,8 +104,9 @@ class InferenceServicesCollection:
    def available(
        self,
        service: Optional[str] = None,
+        force_refresh: bool = False,
    ) -> List[Tuple[str, str, int]]:
-        return self.availability_fetcher.available(service)
+        return self.availability_fetcher.available(service, force_refresh=force_refresh)

    def reset_cache(self) -> None:
        self.availability_fetcher.reset_cache()
@@ -120,7 +121,6 @@ class InferenceServicesCollection:
    def create_model_factory(
        self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
    ) -> "LanguageModel":
-
        if service_name is None:  # we try to find the right service
            service = self.resolver.resolve_model(model_name, service_name)
        else:  # if they passed a service, we'll use that
edsl/inference_services/MistralAIService.py
CHANGED
@@ -111,8 +111,7 @@ class MistralAIService(InferenceServiceABC):
                    ],
                )
            except Exception as e:
-
-
+                return {"message": str(e)}
            return res.model_dump()

        LLM.__name__ = model_class_name
edsl/inference_services/OpenAIService.py
CHANGED
@@ -207,8 +207,10 @@ class OpenAIService(InferenceServiceABC):
                {"role": "user", "content": content},
            ]
            if (
-                system_prompt == "" and self.omit_system_prompt_if_empty
-
+                (system_prompt == "" and self.omit_system_prompt_if_empty)
+                or "o1" in self.model
+                or "o3" in self.model
+            ):
                messages = messages[1:]

            params = {
@@ -222,14 +224,17 @@ class OpenAIService(InferenceServiceABC):
                "logprobs": self.logprobs,
                "top_logprobs": self.top_logprobs if self.logprobs else None,
            }
-            if "o1" in self.model:
+            if "o1" in self.model or "o3" in self.model:
                params.pop("max_tokens")
                params["max_completion_tokens"] = self.max_tokens
                params["temperature"] = 1
            try:
                response = await client.chat.completions.create(**params)
            except Exception as e:
-
+                #breakpoint()
+                #print(e)
+                #raise e
+                return {'message': str(e)}
            return response.model_dump()

        LLM.__name__ = "LanguageModel"
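The o1/o3 handling above rewrites the request parameters before the API call. A standalone sketch of that adjustment (the function name is illustrative, not part of edsl):

def adjust_params_for_reasoning_models(model: str, params: dict) -> dict:
    # o1/o3 models reject `max_tokens` and a non-default temperature, so the
    # updated service swaps in `max_completion_tokens` and pins temperature;
    # the system message is dropped earlier in the same method.
    if "o1" in model or "o3" in model:
        params = dict(params)
        params["max_completion_tokens"] = params.pop("max_tokens")
        params["temperature"] = 1
    return params


print(adjust_params_for_reasoning_models("o3-mini", {"max_tokens": 100, "temperature": 0.5}))
# {'temperature': 1, 'max_completion_tokens': 100}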
edsl/inference_services/PerplexityService.py
CHANGED
@@ -152,7 +152,8 @@ class PerplexityService(OpenAIService):
            try:
                response = await client.chat.completions.create(**params)
            except Exception as e:
-
+                return {"message": str(e)}
+
            return response.model_dump()

        LLM.__name__ = "LanguageModel"
edsl/inference_services/{GrokService.py → XAIService.py}
CHANGED
@@ -2,10 +2,10 @@ from typing import Any, List
from edsl.inference_services.OpenAIService import OpenAIService


-class
+class XAIService(OpenAIService):
    """Openai service class."""

-    _inference_service_ = "
+    _inference_service_ = "xai"
    _env_key_name_ = "XAI_API_KEY"
    _base_url_ = "https://api.x.ai/v1"
    _models_list_cache: List[str] = []
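A hedged usage sketch of the renamed service: XAIService is an OpenAI-compatible subclass that reads XAI_API_KEY and points at https://api.x.ai/v1. Selecting it through the public Model constructor with a service_name argument is an assumption inferred from create_model_factory(model_name, service_name=...) above, and the model name is a placeholder:

import os

from edsl import Model  # assumes the top-level Model export available in edsl 0.1.x

os.environ.setdefault("XAI_API_KEY", "xai-...")  # placeholder key

# Assumed calling convention; the diff only confirms the service identifier "xai".
m = Model("grok-beta", service_name="xai")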
edsl/inference_services/registry.py
CHANGED
@@ -14,7 +14,7 @@ from edsl.inference_services.TestService import TestService
from edsl.inference_services.TogetherAIService import TogetherAIService
from edsl.inference_services.PerplexityService import PerplexityService
from edsl.inference_services.DeepSeekService import DeepSeekService
-from edsl.inference_services.
+from edsl.inference_services.XAIService import XAIService

try:
    from edsl.inference_services.MistralAIService import MistralAIService
@@ -36,7 +36,7 @@ services = [
    TogetherAIService,
    PerplexityService,
    DeepSeekService,
-
+    XAIService,
]

if mistral_available:
edsl/jobs/AnswerQuestionFunctionConstructor.py
CHANGED
@@ -66,10 +66,14 @@ class SkipHandler:
                )
            )

+
        def cancel_between(start, end):
            """Cancel the tasks for questions between the start and end indices."""
            for i in range(start, end):
-
+                #print(f"Cancelling task {i}")
+                #self.interview.tasks[i].cancel()
+                #self.interview.tasks[i].set_result("skipped")
+                self.interview.skip_flags[self.interview.survey.questions[i].question_name] = True

        if (next_question_index := next_question.next_q) == EndOfSurvey:
            cancel_between(
@@ -80,6 +84,8 @@ class SkipHandler:
        if next_question_index > (current_question_index + 1):
            cancel_between(current_question_index + 1, next_question_index)

+
+

class AnswerQuestionFunctionConstructor:
    """Constructs a function that answers a question and records the answer."""
@@ -161,6 +167,11 @@ class AnswerQuestionFunctionConstructor:
        async def attempt_answer():
            invigilator = self.invigilator_fetcher(question)

+            if self.interview.skip_flags.get(question.question_name, False):
+                return invigilator.get_failed_task_result(
+                    failure_reason="Question skipped."
+                )
+
            if self.skip_handler.should_skip(question):
                return invigilator.get_failed_task_result(
                    failure_reason="Question skipped."
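The skip logic now records skipped questions in a skip_flags dict on the interview instead of cancelling asyncio tasks, and attempt_answer checks the flag before answering. A self-contained sketch of the idea, using plain dicts rather than the edsl classes:

# Illustrative names only; the real state lives on Interview / SkipHandler.
skip_flags: dict[str, bool] = {}
question_names = ["q0", "q1", "q2", "q3"]


def cancel_between(start: int, end: int) -> None:
    # Mark the questions in [start, end) as skipped instead of cancelling tasks.
    for i in range(start, end):
        skip_flags[question_names[i]] = True


def attempt_answer(question_name: str) -> str:
    if skip_flags.get(question_name, False):
        return "skipped"
    return "answered"


cancel_between(1, 3)
print([attempt_answer(name) for name in question_names])
# ['answered', 'skipped', 'skipped', 'answered']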
edsl/jobs/Jobs.py
CHANGED
@@ -277,7 +277,7 @@ class Jobs(Base):

        return JobsComponentConstructor(self).by(*args)

-    def prompts(self) -> "Dataset":
+    def prompts(self, iterations=1) -> "Dataset":
        """Return a Dataset of prompts that will be used.


@@ -285,7 +285,7 @@ class Jobs(Base):
        >>> Jobs.example().prompts()
        Dataset(...)
        """
-        return JobsPrompts(self).prompts()
+        return JobsPrompts(self).prompts(iterations=iterations)

    def show_prompts(self, all: bool = False) -> None:
        """Print the prompts."""
@@ -364,6 +364,15 @@ class Jobs(Base):
            self, cache=self.run_config.environment.cache
        ).create_interviews()

+    def show_flow(self, filename: Optional[str] = None) -> None:
+        """Show the flow of the survey."""
+        from edsl.surveys.SurveyFlowVisualization import SurveyFlowVisualization
+        if self.scenarios:
+            scenario = self.scenarios[0]
+        else:
+            scenario = None
+        SurveyFlowVisualization(self.survey, scenario=scenario, agent=None).show_flow(filename=filename)
+
    def interviews(self) -> list[Interview]:
        """
        Return a list of :class:`edsl.jobs.interviews.Interview` objects.
@@ -409,11 +418,9 @@ class Jobs(Base):
        BucketCollection(...)
        """
        bc = BucketCollection.from_models(self.models)
-
+
        if self.run_config.environment.key_lookup is not None:
-            bc.update_from_key_lookup(
-                self.run_config.environment.key_lookup
-            )
+            bc.update_from_key_lookup(self.run_config.environment.key_lookup)
        return bc

    def html(self):
@@ -475,25 +482,24 @@ class Jobs(Base):
    def _start_remote_inference_job(
        self, job_handler: Optional[JobsRemoteInferenceHandler] = None
    ) -> Union["Results", None]:
-
        if job_handler is None:
            job_handler = self._create_remote_inference_handler()
-
+
        job_info = job_handler.create_remote_inference_job(
-
-
-
+            iterations=self.run_config.parameters.n,
+            remote_inference_description=self.run_config.parameters.remote_inference_description,
+            remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
+            fresh=self.run_config.parameters.fresh,
        )
        return job_info
-
-    def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:

+    def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:
        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
-
+
        return JobsRemoteInferenceHandler(
            self, verbose=self.run_config.parameters.verbose
        )
-
+
    def _remote_results(
        self,
        config: RunConfig,
@@ -507,7 +513,8 @@ class Jobs(Base):
        if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
            job_info: RemoteJobInfo = self._start_remote_inference_job(jh)
            if background:
-                from edsl.results.Results import Results
+                from edsl.results.Results import Results
+
                results = Results.from_job_info(job_info)
                return results
            else:
@@ -594,7 +601,7 @@ class Jobs(Base):
        # first try to run the job remotely
        if (results := self._remote_results(config)) is not None:
            return results
-
+
        self._check_if_local_keys_ok()

        if config.environment.bucket_collection is None:
edsl/jobs/JobsChecks.py
CHANGED
@@ -24,7 +24,7 @@ class JobsChecks:

    def get_missing_api_keys(self) -> set:
        """
-        Returns a list of the
+        Returns a list of the API keys that a user needs to run this job, but does not currently have in their .env file.
        """
        missing_api_keys = set()

@@ -134,22 +134,20 @@ class JobsChecks:

        edsl_auth_token = secrets.token_urlsafe(16)

-        print("
+        print("\nThe following keys are needed to run this survey: \n")
        for api_key in missing_api_keys:
-            print(f"
+            print(f"🔑 {api_key}")
        print(
-            "
+            """
+            \nYou can provide your own keys for language models or use an Expected Parrot key to access all available models.
+            \nClick the link below to create an account and run your survey with your Expected Parrot key:
+            """
        )
-
-
+
        coop = Coop()
        coop._display_login_url(
            edsl_auth_token=edsl_auth_token,
-            link_description="
-        )
-
-        print(
-            "\nOnce you log in, your key will be stored on your computer and your survey will start running at the Expected Parrot server."
+            # link_description="",
        )

        api_key = coop._poll_for_api_key(edsl_auth_token)
@@ -159,8 +157,7 @@ class JobsChecks:
            return

        path_to_env = write_api_key_to_env(api_key)
-        print("\n✨ Your key has been stored at the following path: ")
-        print(f" {path_to_env}")
+        print(f"\n✨ Your Expected Parrot key has been stored at the following path: {path_to_env}\n")

        # Retrieve API key so we can continue running the job
        load_dotenv()
edsl/jobs/JobsPrompts.py
CHANGED
@@ -18,6 +18,7 @@ from edsl.data.CacheEntry import CacheEntry

logger = logging.getLogger(__name__)

+
class JobsPrompts:
    def __init__(self, jobs: "Jobs"):
        self.interviews = jobs.interviews()
@@ -26,7 +27,9 @@ class JobsPrompts:
        self.survey = jobs.survey
        self._price_lookup = None
        self._agent_lookup = {agent: idx for idx, agent in enumerate(self.agents)}
-        self._scenario_lookup = {
+        self._scenario_lookup = {
+            scenario: idx for idx, scenario in enumerate(self.scenarios)
+        }

    @property
    def price_lookup(self):
@@ -37,7 +40,7 @@ class JobsPrompts:
            self._price_lookup = c.fetch_prices()
        return self._price_lookup

-    def prompts(self) -> "Dataset":
+    def prompts(self, iterations=1) -> "Dataset":
        """Return a Dataset of prompts that will be used.

        >>> from edsl.jobs import Jobs
@@ -54,11 +57,11 @@ class JobsPrompts:
        models = []
        costs = []
        cache_keys = []
-
+
        for interview_index, interview in enumerate(interviews):
            logger.info(f"Processing interview {interview_index} of {len(interviews)}")
            interview_start = time.time()
-
+
            # Fetch invigilators timing
            invig_start = time.time()
            invigilators = [
@@ -66,8 +69,10 @@ class JobsPrompts:
                for question in interview.survey.questions
            ]
            invig_end = time.time()
-            logger.debug(
-
+            logger.debug(
+                f"Time taken to fetch invigilators: {invig_end - invig_start:.4f}s"
+            )
+
            # Process prompts timing
            prompts_start = time.time()
            for _, invigilator in enumerate(invigilators):
@@ -75,13 +80,15 @@ class JobsPrompts:
                get_prompts_start = time.time()
                prompts = invigilator.get_prompts()
                get_prompts_end = time.time()
-                logger.debug(
-
+                logger.debug(
+                    f"Time taken to get prompts: {get_prompts_end - get_prompts_start:.4f}s"
+                )
+
                user_prompt = prompts["user_prompt"]
                system_prompt = prompts["system_prompt"]
                user_prompts.append(user_prompt)
                system_prompts.append(system_prompt)
-
+
                # Index lookups timing
                index_start = time.time()
                agent_index = self._agent_lookup[invigilator.agent]
@@ -90,14 +97,18 @@ class JobsPrompts:
                scenario_index = self._scenario_lookup[invigilator.scenario]
                scenario_indices.append(scenario_index)
                index_end = time.time()
-                logger.debug(
-
+                logger.debug(
+                    f"Time taken for index lookups: {index_end - index_start:.4f}s"
+                )
+
                # Model and question name assignment timing
                assign_start = time.time()
                models.append(invigilator.model.model)
                question_names.append(invigilator.question.question_name)
                assign_end = time.time()
-                logger.debug(
+                logger.debug(
+                    f"Time taken for assignments: {assign_end - assign_start:.4f}s"
+                )

                # Cost estimation timing
                cost_start = time.time()
@@ -109,32 +120,44 @@ class JobsPrompts:
                    model=invigilator.model.model,
                )
                cost_end = time.time()
-                logger.debug(
+                logger.debug(
+                    f"Time taken to estimate prompt cost: {cost_end - cost_start:.4f}s"
+                )
                costs.append(prompt_cost["cost_usd"])

                # Cache key generation timing
                cache_key_gen_start = time.time()
-
-
-
-
-
-
-
+                for iteration in range(iterations):
+                    cache_key = CacheEntry.gen_key(
+                        model=invigilator.model.model,
+                        parameters=invigilator.model.parameters,
+                        system_prompt=system_prompt,
+                        user_prompt=user_prompt,
+                        iteration=iteration,
+                    )
+                    cache_keys.append(cache_key)
+
                cache_key_gen_end = time.time()
-
-
+                logger.debug(
+                    f"Time taken to generate cache key: {cache_key_gen_end - cache_key_gen_start:.4f}s"
+                )
                logger.debug("-" * 50)  # Separator between iterations

            prompts_end = time.time()
-            logger.info(
-
+            logger.info(
+                f"Time taken to process prompts: {prompts_end - prompts_start:.4f}s"
+            )
+
            interview_end = time.time()
-            logger.info(
+            logger.info(
+                f"Overall time taken for interview: {interview_end - interview_start:.4f}s"
+            )
            logger.info("Time breakdown:")
            logger.info(f"  Invigilators: {invig_end - invig_start:.4f}s")
            logger.info(f"  Prompts processing: {prompts_end - prompts_start:.4f}s")
-            logger.info(
+            logger.info(
+                f"  Other overhead: {(interview_end - interview_start) - ((invig_end - invig_start) + (prompts_end - prompts_start)):.4f}s"
+            )

        d = Dataset(
            [