edsl 0.1.38__py3-none-any.whl → 0.1.38.dev2__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- edsl/Base.py +31 -60
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +9 -18
- edsl/agents/AgentList.py +8 -59
- edsl/agents/Invigilator.py +7 -18
- edsl/agents/InvigilatorBase.py +19 -0
- edsl/agents/PromptConstructor.py +4 -5
- edsl/config.py +0 -8
- edsl/coop/coop.py +7 -74
- edsl/data/Cache.py +2 -27
- edsl/data/CacheEntry.py +3 -8
- edsl/data/RemoteCacheSync.py +19 -0
- edsl/enums.py +0 -2
- edsl/inference_services/GoogleService.py +15 -7
- edsl/inference_services/registry.py +0 -2
- edsl/jobs/Jobs.py +548 -88
- edsl/jobs/interviews/Interview.py +11 -11
- edsl/jobs/runners/JobsRunnerAsyncio.py +35 -140
- edsl/jobs/runners/JobsRunnerStatus.py +2 -0
- edsl/jobs/tasks/TaskHistory.py +16 -15
- edsl/language_models/LanguageModel.py +84 -44
- edsl/language_models/ModelList.py +1 -47
- edsl/language_models/registry.py +4 -57
- edsl/prompts/Prompt.py +3 -8
- edsl/questions/QuestionBase.py +16 -20
- edsl/questions/QuestionExtract.py +4 -3
- edsl/questions/question_registry.py +6 -36
- edsl/results/Dataset.py +15 -146
- edsl/results/DatasetExportMixin.py +217 -231
- edsl/results/DatasetTree.py +4 -134
- edsl/results/Result.py +9 -18
- edsl/results/Results.py +51 -145
- edsl/scenarios/FileStore.py +13 -187
- edsl/scenarios/Scenario.py +4 -61
- edsl/scenarios/ScenarioList.py +62 -237
- edsl/surveys/Survey.py +2 -16
- edsl/surveys/SurveyFlowVisualizationMixin.py +9 -67
- edsl/surveys/instructions/Instruction.py +0 -12
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +9 -18
- edsl/utilities/utilities.py +0 -15
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/METADATA +1 -2
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/RECORD +45 -53
- edsl/inference_services/PerplexityService.py +0 -163
- edsl/jobs/JobsChecks.py +0 -147
- edsl/jobs/JobsPrompts.py +0 -268
- edsl/jobs/JobsRemoteInferenceHandler.py +0 -239
- edsl/results/CSSParameterizer.py +0 -108
- edsl/results/TableDisplay.py +0 -198
- edsl/results/table_display.css +0 -78
- edsl/scenarios/ScenarioJoin.py +0 -127
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import warnings
 import requests
 from itertools import product
-from typing import Literal, Optional, Union, Sequence, Generator, TYPE_CHECKING
+from typing import Literal, Optional, Union, Sequence, Generator
 
 from edsl.Base import Base
 
@@ -11,20 +11,11 @@ from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
-from edsl.utilities.decorators import remove_edsl_version
+from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
 
 from edsl.data.RemoteCacheSync import RemoteCacheSync
 from edsl.exceptions.coop import CoopServerResponseError
 
-if TYPE_CHECKING:
-    from edsl.agents.Agent import Agent
-    from edsl.agents.AgentList import AgentList
-    from edsl.language_models.LanguageModel import LanguageModel
-    from edsl.scenarios.Scenario import Scenario
-    from edsl.surveys.Survey import Survey
-    from edsl.results.Results import Results
-    from edsl.results.Dataset import Dataset
-
 
 class Jobs(Base):
     """
@@ -33,8 +24,6 @@ class Jobs(Base):
     The `JobsRunner` is chosen by the user, and is stored in the `jobs_runner_name` attribute.
     """
 
-    __documentation__ = "https://docs.expectedparrot.com/en/latest/jobs.html"
-
     def __init__(
         self,
         survey: "Survey",
@@ -97,14 +86,8 @@ class Jobs(Base):
     @scenarios.setter
     def scenarios(self, value):
         from edsl import ScenarioList
-        from edsl.results.Dataset import Dataset
 
         if value:
-            if isinstance(
-                value, Dataset
-            ):  # if the user passes in a Dataset, convert it to a ScenarioList
-                value = value.to_scenario_list()
-
             if not isinstance(value, ScenarioList):
                 self._scenarios = ScenarioList(value)
             else:
@@ -144,13 +127,6 @@ class Jobs(Base):
         - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
         - models: new models overwrite old models.
         """
-        from edsl.results.Dataset import Dataset
-
-        if isinstance(
-            args[0], Dataset
-        ):  # let the user user a Dataset as if it were a ScenarioList
-            args = args[0].to_scenario_list()
-
         passed_objects = self._turn_args_to_list(
             args
         )  # objects can also be passed comma-separated
@@ -175,19 +151,73 @@ class Jobs(Base):
         >>> Jobs.example().prompts()
         Dataset(...)
         """
-        from edsl …
+        from edsl import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()
+
+        interviews = self.interviews()
+        # data = []
+        interview_indices = []
+        question_names = []
+        user_prompts = []
+        system_prompts = []
+        scenario_indices = []
+        agent_indices = []
+        models = []
+        costs = []
+        from edsl.results.Dataset import Dataset
+
+        for interview_index, interview in enumerate(interviews):
+            invigilators = [
+                interview._get_invigilator(question)
+                for question in self.survey.questions
+            ]
+            for _, invigilator in enumerate(invigilators):
+                prompts = invigilator.get_prompts()
+                user_prompt = prompts["user_prompt"]
+                system_prompt = prompts["system_prompt"]
+                user_prompts.append(user_prompt)
+                system_prompts.append(system_prompt)
+                agent_index = self.agents.index(invigilator.agent)
+                agent_indices.append(agent_index)
+                interview_indices.append(interview_index)
+                scenario_index = self.scenarios.index(invigilator.scenario)
+                scenario_indices.append(scenario_index)
+                models.append(invigilator.model.model)
+                question_names.append(invigilator.question.question_name)
+
+                prompt_cost = self.estimate_prompt_cost(
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    price_lookup=price_lookup,
+                    inference_service=invigilator.model._inference_service_,
+                    model=invigilator.model.model,
+                )
+                costs.append(prompt_cost["cost_usd"])
 
-        …
-        …
+        d = Dataset(
+            [
+                {"user_prompt": user_prompts},
+                {"system_prompt": system_prompts},
+                {"interview_index": interview_indices},
+                {"question_name": question_names},
+                {"scenario_index": scenario_indices},
+                {"agent_index": agent_indices},
+                {"model": models},
+                {"estimated_cost": costs},
+            ]
+        )
+        return d
 
-    def show_prompts(self, all=False) -> None:
+    def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
         """Print the prompts."""
         if all:
-            …
+            self.prompts().to_scenario_list().print(format="rich", max_rows=max_rows)
         else:
-            …
-            …
-            )
+            self.prompts().select(
+                "user_prompt", "system_prompt"
+            ).to_scenario_list().print(format="rich", max_rows=max_rows)
 
     @staticmethod
     def estimate_prompt_cost(
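As rewritten above, prompts() now assembles a Dataset with one row per (interview, question) pair: user and system prompts, interview/scenario/agent indices, model name, and a per-prompt cost estimate, while show_prompts() gains a max_rows cap. A minimal usage sketch (note that prompts() calls Coop().fetch_prices(), so it needs network access; Jobs.example() is the same fixture used in the doctest above):

from edsl.jobs.Jobs import Jobs

job = Jobs.example()

# Dataset with user_prompt, system_prompt, indices, model, and estimated_cost
prompt_data = job.prompts()

# Print only the user and system prompts, capped at 10 rows
job.show_prompts(max_rows=10)

# Print every column, including indices and estimated costs
job.show_prompts(all=True, max_rows=10)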
@@ -196,42 +226,201 @@ class Jobs(Base):
         price_lookup: dict,
         inference_service: str,
         model: str,
+    ) -> dict:
+        """Estimates the cost of a prompt. Takes piping into account."""
+        import math
+
+        def get_piping_multiplier(prompt: str):
+            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
+
+            if "{{" in prompt and "}}" in prompt:
+                return 2
+            return 1
+
+        # Look up prices per token
+        key = (inference_service, model)
+
+        try:
+            relevant_prices = price_lookup[key]
+
+            service_input_token_price = float(
+                relevant_prices["input"]["service_stated_token_price"]
+            )
+            service_input_token_qty = float(
+                relevant_prices["input"]["service_stated_token_qty"]
+            )
+            input_price_per_token = service_input_token_price / service_input_token_qty
+
+            service_output_token_price = float(
+                relevant_prices["output"]["service_stated_token_price"]
+            )
+            service_output_token_qty = float(
+                relevant_prices["output"]["service_stated_token_qty"]
+            )
+            output_price_per_token = (
+                service_output_token_price / service_output_token_qty
+            )
+
+        except KeyError:
+            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
+            # Use a sensible default
+
+            import warnings
+
+            warnings.warn(
+                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
+            )
+            input_price_per_token = 0.00000015  # $0.15 / 1M tokens
+            output_price_per_token = 0.00000060  # $0.60 / 1M tokens
+
+        # Compute the number of characters (double if the question involves piping)
+        user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
+            str(user_prompt)
+        )
+        system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
+            str(system_prompt)
+        )
+
+        # Convert into tokens (1 token approx. equals 4 characters)
+        input_tokens = (user_prompt_chars + system_prompt_chars) // 4
+
+        output_tokens = math.ceil(0.75 * input_tokens)
+
+        cost = (
+            input_tokens * input_price_per_token
+            + output_tokens * output_price_per_token
+        )
+
+        return {
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
+            "cost_usd": cost,
+        }
+
+    def estimate_job_cost_from_external_prices(
+        self, price_lookup: dict, iterations: int = 1
     ) -> dict:
         """
-        …
-        …
+        Estimates the cost of a job according to the following assumptions:
+
+        - 1 token = 4 characters.
+        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
+
+        price_lookup is an external pricing dictionary.
         """
-        from edsl.jobs.JobsPrompts import JobsPrompts
 
-        …
-        …
+        import pandas as pd
+
+        interviews = self.interviews()
+        data = []
+        for interview in interviews:
+            invigilators = [
+                interview._get_invigilator(question)
+                for question in self.survey.questions
+            ]
+            for invigilator in invigilators:
+                prompts = invigilator.get_prompts()
+
+                # By this point, agent and scenario data has already been added to the prompts
+                user_prompt = prompts["user_prompt"]
+                system_prompt = prompts["system_prompt"]
+                inference_service = invigilator.model._inference_service_
+                model = invigilator.model.model
+
+                prompt_cost = self.estimate_prompt_cost(
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    price_lookup=price_lookup,
+                    inference_service=inference_service,
+                    model=model,
+                )
+
+                data.append(
+                    {
+                        "user_prompt": user_prompt,
+                        "system_prompt": system_prompt,
+                        "estimated_input_tokens": prompt_cost["input_tokens"],
+                        "estimated_output_tokens": prompt_cost["output_tokens"],
+                        "estimated_cost_usd": prompt_cost["cost_usd"],
+                        "inference_service": inference_service,
+                        "model": model,
+                    }
+                )
+
+        df = pd.DataFrame.from_records(data)
+
+        df = (
+            df.groupby(["inference_service", "model"])
+            .agg(
+                {
+                    "estimated_cost_usd": "sum",
+                    "estimated_input_tokens": "sum",
+                    "estimated_output_tokens": "sum",
+                }
+            )
+            .reset_index()
+        )
+        df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
+        df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
+        df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
+
+        estimated_costs_by_model = df.to_dict("records")
+
+        estimated_total_cost = sum(
+            model["estimated_cost_usd"] for model in estimated_costs_by_model
+        )
+        estimated_total_input_tokens = sum(
+            model["estimated_input_tokens"] for model in estimated_costs_by_model
+        )
+        estimated_total_output_tokens = sum(
+            model["estimated_output_tokens"] for model in estimated_costs_by_model
         )
 
+        output = {
+            "estimated_total_cost_usd": estimated_total_cost,
+            "estimated_total_input_tokens": estimated_total_input_tokens,
+            "estimated_total_output_tokens": estimated_total_output_tokens,
+            "model_costs": estimated_costs_by_model,
+        }
+
+        return output
+
     def estimate_job_cost(self, iterations: int = 1) -> dict:
         """
-        …
+        Estimates the cost of a job according to the following assumptions:
 
-        …
-        …
-        from edsl.jobs.JobsPrompts import JobsPrompts
+        - 1 token = 4 characters.
+        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
 
-        …
-        …
+        Fetches prices from Coop.
+        """
+        from edsl import Coop
 
-        …
-        …
-    ) -> dict:
-        from edsl.jobs.JobsPrompts import JobsPrompts
+        c = Coop()
+        price_lookup = c.fetch_prices()
 
-        …
-        …
+        return self.estimate_job_cost_from_external_prices(
+            price_lookup=price_lookup, iterations=iterations
+        )
 
     @staticmethod
-    def compute_job_cost(job_results: Results) -> float:
+    def compute_job_cost(job_results: "Results") -> float:
         """
         Computes the cost of a completed job in USD.
         """
-        …
+        total_cost = 0
+        for result in job_results:
+            for key in result.raw_model_response:
+                if key.endswith("_cost"):
+                    result_cost = result.raw_model_response[key]
+
+                    question_name = key.removesuffix("_cost")
+                    cache_used = result.cache_used_dict[question_name]
+
+                    if isinstance(result_cost, (int, float)) and not cache_used:
+                        total_cost += result_cost
+
+        return total_cost
 
     @staticmethod
     def _get_container_class(object):
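The heuristic in estimate_prompt_cost is simple enough to trace by hand: characters are doubled when the prompt contains Jinja piping braces, input tokens are characters // 4, output tokens are ceil(0.75 * input_tokens), and each side is priced per token. A standalone sketch of the same arithmetic, using the fallback prices from the except branch (the example prompts are illustrative, not from the package):

import math

INPUT_PRICE = 0.15 / 1_000_000   # fallback: $0.15 per 1M input tokens
OUTPUT_PRICE = 0.60 / 1_000_000  # fallback: $0.60 per 1M output tokens

def piping_multiplier(prompt: str) -> int:
    # Piped prompts are counted twice, since "{{ ... }}" expands at runtime
    return 2 if "{{" in prompt and "}}" in prompt else 1

user_prompt = "What is your favorite color, {{ agent.name }}?"
system_prompt = "You are answering a survey."

chars = len(user_prompt) * piping_multiplier(user_prompt)
chars += len(system_prompt) * piping_multiplier(system_prompt)

input_tokens = chars // 4                       # 1 token ~= 4 characters
output_tokens = math.ceil(0.75 * input_tokens)  # assumed output/input ratio

cost_usd = input_tokens * INPUT_PRICE + output_tokens * OUTPUT_PRICE
print(input_tokens, output_tokens, f"${cost_usd:.8f}")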
@@ -315,12 +504,17 @@ class Jobs(Base):
 
     @staticmethod
     def _get_empty_container_object(object):
-        from edsl …
-        from edsl …
+        from edsl import AgentList
+        from edsl import Agent
+        from edsl import Scenario
+        from edsl import ScenarioList
 
-        …
-        …
-        )
+        if isinstance(object, Agent):
+            return AgentList([])
+        elif isinstance(object, Scenario):
+            return ScenarioList([])
+        else:
+            return []
 
     @staticmethod
     def _merge_objects(passed_objects, current_objects) -> list:
@@ -528,6 +722,110 @@ class Jobs(Base):
             return False
         return self._raise_validation_errors
 
+    def create_remote_inference_job(
+        self,
+        iterations: int = 1,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
+        verbose=False,
+    ):
+        """ """
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        self._output("Remote inference activated. Sending job to server...")
+        remote_job_creation_data = coop.remote_inference_create(
+            self,
+            description=remote_inference_description,
+            status="queued",
+            iterations=iterations,
+            initial_results_visibility=remote_inference_results_visibility,
+        )
+        job_uuid = remote_job_creation_data.get("uuid")
+        if self.verbose:
+            print(f"Job sent to server. (Job uuid={job_uuid}).")
+        return remote_job_creation_data
+
+    @staticmethod
+    def check_status(job_uuid):
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        return coop.remote_inference_get(job_uuid)
+
+    def poll_remote_inference_job(
+        self, remote_job_creation_data: dict, verbose=False, poll_interval=5
+    ) -> Union[Results, None]:
+        from edsl.coop.coop import Coop
+        import time
+        from datetime import datetime
+        from edsl.config import CONFIG
+
+        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
+
+        job_uuid = remote_job_creation_data.get("uuid")
+
+        coop = Coop()
+        job_in_queue = True
+        while job_in_queue:
+            remote_job_data = coop.remote_inference_get(job_uuid)
+            status = remote_job_data.get("status")
+            if status == "cancelled":
+                if self.verbose:
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job cancelled by the user.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
+                return None
+            elif status == "failed":
+                if self.verbose:
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job failed.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
+                return None
+            elif status == "completed":
+                results_uuid = remote_job_data.get("results_uuid")
+                results = coop.get(results_uuid, expected_object_type="results")
+                if self.verbose:
+                    print("\r" + " " * 80 + "\r", end="")
+                    url = f"{expected_parrot_url}/content/{results_uuid}"
+                    print(f"Job completed and Results stored on Coop: {url}.")
+                return results
+            else:
+                duration = poll_interval
+                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
+                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
+                start_time = time.time()
+                i = 0
+                while time.time() - start_time < duration:
+                    if self.verbose:
+                        print(
+                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                            end="",
+                            flush=True,
+                        )
+                    time.sleep(0.1)
+                    i += 1
+
+    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
+        if disable_remote_inference:
+            return False
+        if not disable_remote_inference:
+            try:
+                from edsl import Coop
+
+                user_edsl_settings = Coop().edsl_settings
+                return user_edsl_settings.get("remote_inference", False)
+            except requests.ConnectionError:
+                pass
+            except CoopServerResponseError as e:
+                pass
+
+        return False
+
     def use_remote_cache(self, disable_remote_cache: bool) -> bool:
         if disable_remote_cache:
             return False
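Taken together, the methods added above give Jobs a self-contained remote-inference loop: create the job on Coop, optionally check its status by uuid, and poll until it reaches cancelled, failed, or completed. A hedged sketch of driving them directly (this assumes an Expected Parrot account and API key are configured; run() normally wires these calls together for you):

from edsl.jobs.Jobs import Jobs

job = Jobs.example()
job.verbose = True  # the new methods read self.verbose for progress output

creation_data = job.create_remote_inference_job(
    iterations=1,
    remote_inference_description="smoke test",
    remote_inference_results_visibility="unlisted",
)

# One-off status lookup by uuid, without blocking
status = Jobs.check_status(creation_data.get("uuid"))

# Blocks until a terminal state; returns Results, or None on cancel/failure
results = job.poll_remote_inference_job(creation_data)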
@@ -544,6 +842,96 @@ class Jobs(Base):
 
         return False
 
+    def check_api_keys(self) -> None:
+        from edsl import Model
+
+        for model in self.models + [Model()]:
+            if not model.has_valid_api_key():
+                raise MissingAPIKeyError(
+                    model_name=str(model.model),
+                    inference_service=model._inference_service_,
+                )
+
+    def get_missing_api_keys(self) -> set:
+        """
+        Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
+        """
+
+        missing_api_keys = set()
+
+        from edsl import Model
+        from edsl.enums import service_to_api_keyname
+
+        for model in self.models + [Model()]:
+            if not model.has_valid_api_key():
+                key_name = service_to_api_keyname.get(
+                    model._inference_service_, "NOT FOUND"
+                )
+                missing_api_keys.add(key_name)
+
+        return missing_api_keys
+
+    def user_has_all_model_keys(self):
+        """
+        Returns True if the user has all model keys required to run their job.
+
+        Otherwise, returns False.
+        """
+
+        try:
+            self.check_api_keys()
+            return True
+        except MissingAPIKeyError:
+            return False
+        except Exception:
+            raise
+
+    def user_has_ep_api_key(self) -> bool:
+        """
+        Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
+
+        Otherwise, returns False.
+        """
+
+        import os
+
+        coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
+
+        if coop_api_key is not None:
+            return True
+        else:
+            return False
+
+    def needs_external_llms(self) -> bool:
+        """
+        Returns True if the job needs external LLMs to run.
+
+        Otherwise, returns False.
+        """
+        # These cases are necessary to skip the API key check during doctests
+
+        # Accounts for Results.example()
+        all_agents_answer_questions_directly = len(self.agents) > 0 and all(
+            [hasattr(a, "answer_question_directly") for a in self.agents]
+        )
+
+        # Accounts for InterviewExceptionEntry.example()
+        only_model_is_test = set([m.model for m in self.models]) == set(["test"])
+
+        # Accounts for Survey.__call__
+        all_questions_are_functional = set(
+            [q.question_type for q in self.survey.questions]
+        ) == set(["functional"])
+
+        if (
+            all_agents_answer_questions_directly
+            or only_model_is_test
+            or all_questions_are_functional
+        ):
+            return False
+        else:
+            return True
+
     def run(
         self,
         n: int = 1,
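These helpers make run()'s pre-flight check explicit: the interactive login flow shown further below only triggers when the user lacks model keys, lacks an Expected Parrot key, and the job actually needs external LLMs. A short sketch of how they compose (the printed results depend on the local .env):

from edsl.jobs.Jobs import Jobs

job = Jobs.example()

if not job.user_has_all_model_keys():
    # Key names come from edsl.enums.service_to_api_keyname
    print("Missing keys:", job.get_missing_api_keys())

needs_login = (
    not job.user_has_all_model_keys()
    and not job.user_has_ep_api_key()
    and job.needs_external_llms()
)
print("Would prompt for Coop login:", needs_login)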
@@ -552,7 +940,7 @@ class Jobs(Base):
         cache: Union[Cache, bool] = None,
         check_api_keys: bool = False,
         sidecar_model: Optional[LanguageModel] = None,
-        verbose: bool = …
+        verbose: bool = False,
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
@@ -587,28 +975,62 @@ class Jobs(Base):
 
         self.verbose = verbose
 
-        …
+        if (
+            not self.user_has_all_model_keys()
+            and not self.user_has_ep_api_key()
+            and self.needs_external_llms()
+        ):
+            import secrets
+            from dotenv import load_dotenv
+            from edsl import CONFIG
+            from edsl.coop.coop import Coop
+            from edsl.utilities.utilities import write_api_key_to_env
+
+            missing_api_keys = self.get_missing_api_keys()
 
-        …
+            edsl_auth_token = secrets.token_urlsafe(16)
 
-        …
-        …
-        …
+            print("You're missing some of the API keys needed to run this job:")
+            for api_key in missing_api_keys:
+                print(f"    🔑 {api_key}")
+            print(
+                "\nYou can either add the missing keys to your .env file, or use remote inference."
+            )
+            print("Remote inference allows you to run jobs on our server.")
+            print("\n🚀 To use remote inference, sign up at the following link:")
+
+            coop = Coop()
+            coop._display_login_url(edsl_auth_token=edsl_auth_token)
 
-        …
+            print(
+                "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
+            )
 
-        …
-        …
-        …
+            api_key = coop._poll_for_api_key(edsl_auth_token)
+
+            if api_key is None:
+                print("\nTimed out waiting for login. Please try again.")
+                return
+
+            write_api_key_to_env(api_key)
+            print("✨ API key retrieved and written to .env file.\n")
+
+            # Retrieve API key so we can continue running the job
+            load_dotenv()
+
+        if remote_inference := self.use_remote_inference(disable_remote_inference):
+            remote_job_creation_data = self.create_remote_inference_job(
                 iterations=n,
                 remote_inference_description=remote_inference_description,
                 remote_inference_results_visibility=remote_inference_results_visibility,
             )
-            results = …
+            results = self.poll_remote_inference_job(remote_job_creation_data)
+            if results is None:
+                self._output("Job failed.")
             return results
 
         if check_api_keys:
-            …
+            self.check_api_keys()
 
         # handle cache
         if cache is None or cache is True:
@@ -638,9 +1060,46 @@ class Jobs(Base):
             raise_validation_errors=raise_validation_errors,
         )
 
-        …
+        results.cache = cache.new_entries_cache()
        return results
 
+    async def create_and_poll_remote_job(
+        self,
+        iterations: int = 1,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[
+            Literal["private", "public", "unlisted"]
+        ] = "unlisted",
+    ) -> Union[Results, None]:
+        """
+        Creates and polls a remote inference job asynchronously.
+        Reuses existing synchronous methods but runs them in an async context.
+
+        :param iterations: Number of times to run each interview
+        :param remote_inference_description: Optional description for the remote job
+        :param remote_inference_results_visibility: Visibility setting for results
+        :return: Results object if successful, None if job fails or is cancelled
+        """
+        import asyncio
+        from functools import partial
+
+        # Create job using existing method
+        loop = asyncio.get_event_loop()
+        remote_job_creation_data = await loop.run_in_executor(
+            None,
+            partial(
+                self.create_remote_inference_job,
+                iterations=iterations,
+                remote_inference_description=remote_inference_description,
+                remote_inference_results_visibility=remote_inference_results_visibility,
+            ),
+        )
+
+        # Poll using existing method but with async sleep
+        return await loop.run_in_executor(
+            None, partial(self.poll_remote_inference_job, remote_job_creation_data)
+        )
+
     async def run_async(
         self,
         cache=None,
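create_and_poll_remote_job does not duplicate the remote-inference logic; it wraps the blocking create and poll calls in functools.partial and hands them to the event loop's default executor, so they run in a worker thread while the loop stays responsive. The same pattern in isolation, with a stand-in blocking function (the names here are illustrative, not edsl APIs; get_running_loop() is used, which behaves like the diff's get_event_loop() inside a running loop):

import asyncio
import time
from functools import partial

def blocking_poll(job_id: str, poll_interval: int = 1) -> str:
    # Stand-in for a synchronous call such as poll_remote_inference_job
    time.sleep(poll_interval)
    return f"job {job_id} completed"

async def main() -> None:
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; the event loop stays free
    result = await loop.run_in_executor(
        None, partial(blocking_poll, "abc123", poll_interval=2)
    )
    print(result)

asyncio.run(main())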
@@ -663,15 +1122,14 @@ class Jobs(Base):
         :return: Results object
         """
         # Check if we should use remote inference
-        …
-        …
-        jh = JobsRemoteInferenceHandler(self, verbose=False)
-        if jh.use_remote_inference(disable_remote_inference):
-            results = await jh.create_and_poll_remote_job(
+        if remote_inference := self.use_remote_inference(disable_remote_inference):
+            results = await self.create_and_poll_remote_job(
                 iterations=n,
                 remote_inference_description=remote_inference_description,
                 remote_inference_results_visibility=remote_inference_results_visibility,
             )
+            if results is None:
+                self._output("Job failed.")
             return results
 
         # If not using remote inference, run locally with async
@@ -691,22 +1149,24 @@ class Jobs(Base):
         """
         return set.union(*[question.parameters for question in self.survey.questions])
 
+    #######################
+    # Dunder methods
+    #######################
+    def print(self):
+        from rich import print_json
+        import json
+
+        print_json(json.dumps(self.to_dict()))
+
     def __repr__(self) -> str:
         """Return an eval-able string representation of the Jobs instance."""
         return f"Jobs(survey={repr(self.survey)}, agents={repr(self.agents)}, models={repr(self.models)}, scenarios={repr(self.scenarios)})"
 
-    def _summary(self):
-        return {
-            "EDSL Class": "Jobs",
-            "Number of questions": len(self.survey),
-            "Number of agents": len(self.agents),
-            "Number of models": len(self.models),
-            "Number of scenarios": len(self.scenarios),
-        }
-
     def _repr_html_(self) -> str:
-        …
-        …
+        from rich import print_json
+        import json
+
+        print_json(json.dumps(self.to_dict()))
 
     def __len__(self) -> int:
         """Return the maximum number of questions that will be asked while running this job.
@@ -776,7 +1236,7 @@ class Jobs(Base):
         True
 
         """
-        return …
+        return self.to_dict() == other.to_dict()
 
     #######################
     # Example methods
|