edsl 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +5 -0
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +37 -9
- edsl/agents/Invigilator.py +2 -1
- edsl/agents/InvigilatorBase.py +5 -1
- edsl/agents/PromptConstructor.py +31 -67
- edsl/conversation/Conversation.py +1 -1
- edsl/coop/PriceFetcher.py +14 -18
- edsl/coop/coop.py +42 -8
- edsl/data/RemoteCacheSync.py +97 -0
- edsl/exceptions/coop.py +8 -0
- edsl/inference_services/InferenceServiceABC.py +28 -0
- edsl/inference_services/InferenceServicesCollection.py +10 -4
- edsl/inference_services/models_available_cache.py +25 -1
- edsl/inference_services/registry.py +24 -16
- edsl/jobs/Jobs.py +327 -206
- edsl/jobs/interviews/Interview.py +65 -10
- edsl/jobs/interviews/InterviewExceptionCollection.py +9 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +31 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +8 -13
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -5
- edsl/jobs/tasks/TaskHistory.py +23 -7
- edsl/language_models/LanguageModel.py +3 -0
- edsl/prompts/Prompt.py +24 -38
- edsl/prompts/__init__.py +1 -1
- edsl/questions/QuestionBasePromptsMixin.py +18 -18
- edsl/questions/QuestionFunctional.py +7 -3
- edsl/questions/descriptors.py +24 -24
- edsl/results/Dataset.py +12 -0
- edsl/results/Result.py +2 -0
- edsl/results/Results.py +13 -1
- edsl/scenarios/FileStore.py +20 -5
- edsl/scenarios/Scenario.py +15 -1
- edsl/scenarios/__init__.py +2 -0
- edsl/surveys/Survey.py +3 -0
- edsl/surveys/instructions/Instruction.py +20 -3
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/METADATA +1 -1
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/RECORD +41 -57
- edsl/jobs/FailedQuestion.py +0 -78
- edsl/jobs/interviews/InterviewStatusMixin.py +0 -33
- edsl/jobs/tasks/task_management.py +0 -13
- edsl/prompts/QuestionInstructionsBase.py +0 -10
- edsl/prompts/library/agent_instructions.py +0 -38
- edsl/prompts/library/agent_persona.py +0 -21
- edsl/prompts/library/question_budget.py +0 -30
- edsl/prompts/library/question_checkbox.py +0 -38
- edsl/prompts/library/question_extract.py +0 -23
- edsl/prompts/library/question_freetext.py +0 -18
- edsl/prompts/library/question_linear_scale.py +0 -24
- edsl/prompts/library/question_list.py +0 -26
- edsl/prompts/library/question_multiple_choice.py +0 -54
- edsl/prompts/library/question_numerical.py +0 -35
- edsl/prompts/library/question_rank.py +0 -25
- edsl/prompts/prompt_config.py +0 -37
- edsl/prompts/registry.py +0 -202
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/LICENSE +0 -0
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/WHEEL +0 -0
@@ -65,7 +65,31 @@ models_available = {
         "meta-llama/Meta-Llama-3-70B-Instruct",
         "openchat/openchat_3.5",
     ],
-    "google": [
+    "google": [
+        "gemini-1.0-pro",
+        "gemini-1.0-pro-001",
+        "gemini-1.0-pro-latest",
+        "gemini-1.0-pro-vision-latest",
+        "gemini-1.5-flash",
+        "gemini-1.5-flash-001",
+        "gemini-1.5-flash-001-tuning",
+        "gemini-1.5-flash-002",
+        "gemini-1.5-flash-8b",
+        "gemini-1.5-flash-8b-001",
+        "gemini-1.5-flash-8b-exp-0827",
+        "gemini-1.5-flash-8b-exp-0924",
+        "gemini-1.5-flash-8b-latest",
+        "gemini-1.5-flash-exp-0827",
+        "gemini-1.5-flash-latest",
+        "gemini-1.5-pro",
+        "gemini-1.5-pro-001",
+        "gemini-1.5-pro-002",
+        "gemini-1.5-pro-exp-0801",
+        "gemini-1.5-pro-exp-0827",
+        "gemini-1.5-pro-latest",
+        "gemini-pro",
+        "gemini-pro-vision",
+    ],
     "bedrock": [
         "amazon.titan-tg1-large",
         "amazon.titan-text-lite-v1",
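The hunk above is from edsl/inference_services/models_available_cache.py (+25 -1 in the file list): the cached Google entry now enumerates the Gemini model ids. As a quick orientation, a membership check against that cached dictionary might look like the following sketch (assumes edsl 0.1.36 is installed; not part of the diff):

# Sketch: inspecting the expanded model cache shown above.
from edsl.inference_services.models_available_cache import models_available

print("gemini-1.5-pro" in models_available["google"])  # True with the 0.1.36 cache
print(len(models_available["google"]))                 # 23 cached Gemini ids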
edsl/inference_services/registry.py
CHANGED
@@ -11,21 +11,29 @@ from edsl.inference_services.AwsBedrock import AwsBedrockService
 from edsl.inference_services.AzureAI import AzureAIService
 from edsl.inference_services.OllamaService import OllamaService
 from edsl.inference_services.TestService import TestService
-from edsl.inference_services.MistralAIService import MistralAIService
 from edsl.inference_services.TogetherAIService import TogetherAIService
 
- … (15 removed lines not shown in this view)
+try:
+    from edsl.inference_services.MistralAIService import MistralAIService
+
+    mistral_available = True
+except Exception as e:
+    mistral_available = False
+
+services = [
+    OpenAIService,
+    AnthropicService,
+    DeepInfraService,
+    GoogleService,
+    GroqService,
+    AwsBedrockService,
+    AzureAIService,
+    OllamaService,
+    TestService,
+    TogetherAIService,
+]
+
+if mistral_available:
+    services.append(MistralAIService)
+
+default = InferenceServicesCollection(services)
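The apparent effect of this hunk is that a failing Mistral import no longer breaks the registry module at import time: MistralAIService is appended only when its import succeeds. The same guard pattern, reduced to a self-contained sketch with a hypothetical optional dependency (the names below are placeholders, not edsl API):

# Sketch of the optional-import guard used above.
registry = []

try:
    from some_optional_sdk import OptionalService  # hypothetical extra dependency

    optional_available = True
except Exception:
    optional_available = False

if optional_available:
    registry.append(OptionalService)

print(f"{len(registry)} optional service(s) registered")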
edsl/jobs/Jobs.py
CHANGED
@@ -1,8 +1,10 @@
 # """The Jobs class is a collection of agents, scenarios and models and one survey."""
 from __future__ import annotations
 import warnings
+import requests
 from itertools import product
 from typing import Optional, Union, Sequence, Generator
+
 from edsl.Base import Base
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
@@ -10,6 +12,9 @@ from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
 
+from edsl.data.RemoteCacheSync import RemoteCacheSync
+from edsl.exceptions.coop import CoopServerResponseError
+
 
 class Jobs(Base):
     """
@@ -180,17 +185,15 @@
             scenario_indices.append(scenario_index)
             models.append(invigilator.model.model)
             question_names.append(invigilator.question.question_name)
-            … (9 removed lines not shown in this view)
-            ) + output_tokens / float(inverse_output_price)
-            costs.append(cost)
+
+            prompt_cost = self.estimate_prompt_cost(
+                system_prompt=system_prompt,
+                user_prompt=user_prompt,
+                price_lookup=price_lookup,
+                inference_service=invigilator.model._inference_service_,
+                model=invigilator.model.model,
+            )
+            costs.append(prompt_cost["cost"])
 
         d = Dataset(
             [
@@ -205,59 +208,195 @@
             ]
         )
         return d
-        # if table:
-        #     d.to_scenario_list().print(format="rich")
-        # else:
-        #     return d
 
-    def show_prompts(self) -> None:
+    def show_prompts(self, all=False) -> None:
         """Print the prompts."""
-
+        if all:
+            self.prompts().to_scenario_list().print(format="rich")
+        else:
+            self.prompts().select(
+                "user_prompt", "system_prompt"
+            ).to_scenario_list().print(format="rich")
+
+    @staticmethod
+    def estimate_prompt_cost(
+        system_prompt: str,
+        user_prompt: str,
+        price_lookup: dict,
+        inference_service: str,
+        model: str,
+    ) -> dict:
+        """Estimates the cost of a prompt. Takes piping into account."""
+
+        def get_piping_multiplier(prompt: str):
+            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
+
+            if "{{" in prompt and "}}" in prompt:
+                return 2
+            return 1
+
+        # Look up prices per token
+        key = (inference_service, model)
+
+        try:
+            relevant_prices = price_lookup[key]
+            output_price_per_token = 1 / float(
+                relevant_prices["output"]["one_usd_buys"]
+            )
+            input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
+        except KeyError:
+            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
+            # Use a sensible default
+
+            import warnings
+
+            warnings.warn(
+                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
+            )
+
+            output_price_per_token = 0.00000015  # $0.15 / 1M tokens
+            input_price_per_token = 0.00000060  # $0.60 / 1M tokens
+
+        # Compute the number of characters (double if the question involves piping)
+        user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
+            str(user_prompt)
+        )
+        system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
+            str(system_prompt)
+        )
+
+        # Convert into tokens (1 token approx. equals 4 characters)
+        input_tokens = (user_prompt_chars + system_prompt_chars) // 4
+        output_tokens = input_tokens
+
+        cost = (
+            input_tokens * input_price_per_token
+            + output_tokens * output_price_per_token
+        )
+
+        return {
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
+            "cost": cost,
+        }
+
+    def estimate_job_cost_from_external_prices(self, price_lookup: dict) -> dict:
+        """
+        Estimates the cost of a job according to the following assumptions:
+
+        - 1 token = 4 characters.
+        - Input tokens = output tokens.
+
+        price_lookup is an external pricing dictionary.
+        """
+
+        import pandas as pd
+
+        interviews = self.interviews()
+        data = []
+        for interview in interviews:
+            invigilators = [
+                interview._get_invigilator(question)
+                for question in self.survey.questions
+            ]
+            for invigilator in invigilators:
+                prompts = invigilator.get_prompts()
 
-
+                # By this point, agent and scenario data has already been added to the prompts
+                user_prompt = prompts["user_prompt"]
+                system_prompt = prompts["system_prompt"]
+                inference_service = invigilator.model._inference_service_
+                model = invigilator.model.model
+
+                prompt_cost = self.estimate_prompt_cost(
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    price_lookup=price_lookup,
+                    inference_service=inference_service,
+                    model=model,
+                )
+
+                data.append(
+                    {
+                        "user_prompt": user_prompt,
+                        "system_prompt": system_prompt,
+                        "estimated_input_tokens": prompt_cost["input_tokens"],
+                        "estimated_output_tokens": prompt_cost["output_tokens"],
+                        "estimated_cost": prompt_cost["cost"],
+                        "inference_service": inference_service,
+                        "model": model,
+                    }
+                )
+
+        df = pd.DataFrame.from_records(data)
+
+        df = (
+            df.groupby(["inference_service", "model"])
+            .agg(
+                {
+                    "estimated_cost": "sum",
+                    "estimated_input_tokens": "sum",
+                    "estimated_output_tokens": "sum",
+                }
+            )
+            .reset_index()
+        )
+
+        estimated_costs_by_model = df.to_dict("records")
+
+        estimated_total_cost = sum(
+            model["estimated_cost"] for model in estimated_costs_by_model
+        )
+        estimated_total_input_tokens = sum(
+            model["estimated_input_tokens"] for model in estimated_costs_by_model
+        )
+        estimated_total_output_tokens = sum(
+            model["estimated_output_tokens"] for model in estimated_costs_by_model
+        )
+
+        output = {
+            "estimated_total_cost": estimated_total_cost,
+            "estimated_total_input_tokens": estimated_total_input_tokens,
+            "estimated_total_output_tokens": estimated_total_output_tokens,
+            "model_costs": estimated_costs_by_model,
+        }
+
+        return output
+
+    def estimate_job_cost(self) -> dict:
+        """
+        Estimates the cost of a job according to the following assumptions:
+
+        - 1 token = 4 characters.
+        - Input tokens = output tokens.
+
+        Fetches prices from Coop.
+        """
         from edsl import Coop
 
         c = Coop()
         price_lookup = c.fetch_prices()
 
-
+        return self.estimate_job_cost_from_external_prices(price_lookup=price_lookup)
 
-        … (3 removed lines not shown in this view)
+    @staticmethod
+    def compute_job_cost(job_results: "Results") -> float:
+        """
+        Computes the cost of a completed job in USD.
+        """
+        total_cost = 0
+        for result in job_results:
+            for key in result.raw_model_response:
+                if key.endswith("_cost"):
+                    result_cost = result.raw_model_response[key]
 
-
+                    question_name = key.removesuffix("_cost")
+                    cache_used = result.cache_used_dict[question_name]
 
-
-
-        for model in self.models:
-            key = (model._inference_service_, model.model)
-            relevant_prices = price_lookup[key]
-            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
-            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
-
-            aproximation_cost[key] = {
-                "input": input_token_aproximations / float(inverse_input_price),
-                "output": input_token_aproximations / float(inverse_output_price),
-            }
-            ##TODO curenlty we approximate the number of output tokens with the number
-            # of input tokens. A better solution will be to compute the quesiton answer options length and sum them
-            # to compute the output tokens
-
-            total_cost += input_token_aproximations / float(inverse_input_price)
-            total_cost += input_token_aproximations / float(inverse_output_price)
-
-        # multiply_factor = len(self.agents or [1]) * len(self.scenarios or [1])
-        multiply_factor = 1
-        out = {
-            "input_token_aproximations": input_token_aproximations,
-            "models_costs": aproximation_cost,
-            "estimated_total_cost": total_cost * multiply_factor,
-            "multiply_factor": multiply_factor,
-            "single_config_cost": total_cost,
-        }
+                    if isinstance(result_cost, (int, float)) and not cache_used:
+                        total_cost += result_cost
 
-        return
+        return total_cost
 
     @staticmethod
     def _get_container_class(object):
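For reference, a rough usage sketch of the new cost-estimation helper added above. The price_lookup shape mirrors what estimate_prompt_cost reads, with (inference_service, model) keys mapping to input/output one_usd_buys figures; the service, model and prices below are invented for illustration and are not part of the diff:

# Sketch: estimating the cost of a single prompt with the static helper above.
from edsl.jobs.Jobs import Jobs

price_lookup = {
    ("openai", "gpt-4o"): {
        "input": {"one_usd_buys": 400_000},   # tokens per USD, made-up figure
        "output": {"one_usd_buys": 100_000},  # tokens per USD, made-up figure
    }
}

estimate = Jobs.estimate_prompt_cost(
    system_prompt="You are a helpful agent.",
    user_prompt="What do you think of {{ topic }}?",  # Jinja braces double the estimate
    price_lookup=price_lookup,
    inference_service="openai",
    model="gpt-4o",
)
print(estimate)  # {'input_tokens': ..., 'output_tokens': ..., 'cost': ...}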
@@ -482,7 +621,7 @@ class Jobs(Base):
 
     def _output(self, message) -> None:
         """Check if a Job is verbose. If so, print the message."""
-        if self.verbose:
+        if hasattr(self, "verbose") and self.verbose:
             print(message)
 
     def _check_parameters(self, strict=False, warn=False) -> None:
@@ -559,6 +698,123 @@
             return False
         return self._raise_validation_errors
 
+    def create_remote_inference_job(
+        self, iterations: int = 1, remote_inference_description: Optional[str] = None
+    ):
+        """ """
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        self._output("Remote inference activated. Sending job to server...")
+        remote_job_creation_data = coop.remote_inference_create(
+            self,
+            description=remote_inference_description,
+            status="queued",
+            iterations=iterations,
+        )
+        job_uuid = remote_job_creation_data.get("uuid")
+        print(f"Job sent to server. (Job uuid={job_uuid}).")
+        return remote_job_creation_data
+
+    @staticmethod
+    def check_status(job_uuid):
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        return coop.remote_inference_get(job_uuid)
+
+    def poll_remote_inference_job(
+        self, remote_job_creation_data: dict
+    ) -> Union[Results, None]:
+        from edsl.coop.coop import Coop
+        import time
+        from datetime import datetime
+        from edsl.config import CONFIG
+
+        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
+
+        job_uuid = remote_job_creation_data.get("uuid")
+
+        coop = Coop()
+        job_in_queue = True
+        while job_in_queue:
+            remote_job_data = coop.remote_inference_get(job_uuid)
+            status = remote_job_data.get("status")
+            if status == "cancelled":
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job cancelled by the user.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
+                return None
+            elif status == "failed":
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job failed.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
+                return None
+            elif status == "completed":
+                results_uuid = remote_job_data.get("results_uuid")
+                results = coop.get(results_uuid, expected_object_type="results")
+                print("\r" + " " * 80 + "\r", end="")
+                url = f"{expected_parrot_url}/content/{results_uuid}"
+                print(f"Job completed and Results stored on Coop: {url}.")
+                return results
+            else:
+                duration = 5
+                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
+                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
+                start_time = time.time()
+                i = 0
+                while time.time() - start_time < duration:
+                    print(
+                        f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                        end="",
+                        flush=True,
+                    )
+                    time.sleep(0.1)
+                    i += 1
+
+    def use_remote_inference(self, disable_remote_inference: bool):
+        if disable_remote_inference:
+            return False
+        if not disable_remote_inference:
+            try:
+                from edsl import Coop
+
+                user_edsl_settings = Coop().edsl_settings
+                return user_edsl_settings.get("remote_inference", False)
+            except requests.ConnectionError:
+                pass
+            except CoopServerResponseError as e:
+                pass
+
+        return False
+
+    def use_remote_cache(self):
+        try:
+            from edsl import Coop
+
+            user_edsl_settings = Coop().edsl_settings
+            return user_edsl_settings.get("remote_caching", False)
+        except requests.ConnectionError:
+            pass
+        except CoopServerResponseError as e:
+            pass
+
+        return False
+
+    def check_api_keys(self):
+        from edsl import Model
+
+        for model in self.models + [Model()]:
+            if not model.has_valid_api_key():
+                raise MissingAPIKeyError(
+                    model_name=str(model.model),
+                    inference_service=model._inference_service_,
+                )
+
     def run(
         self,
         n: int = 1,
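For orientation, the block above turns the previously inline remote-inference handling in run() into named helpers on Jobs. A small usage sketch of the status check, assuming a job was already submitted with create_remote_inference_job (the uuid below is a dummy placeholder):

# Sketch: checking on a previously submitted remote inference job.
from edsl.jobs.Jobs import Jobs

remote_job_data = Jobs.check_status(job_uuid="00000000-0000-0000-0000-000000000000")
print(remote_job_data.get("status"))  # e.g. "queued", "completed", "failed" or "cancelled"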
@@ -596,91 +852,17 @@
 
         self.verbose = verbose
 
-        … (3 removed lines not shown in this view)
-        if not disable_remote_inference:
-            try:
-                coop = Coop()
-                user_edsl_settings = Coop().edsl_settings
-                remote_cache = user_edsl_settings.get("remote_caching", False)
-                remote_inference = user_edsl_settings.get("remote_inference", False)
-            except Exception:
-                pass
-
-        if remote_inference:
-            import time
-            from datetime import datetime
-            from edsl.config import CONFIG
-
-            expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
-
-            self._output("Remote inference activated. Sending job to server...")
-            if remote_cache:
-                self._output(
-                    "Remote caching activated. The remote cache will be used for this job."
-                )
-
-            remote_job_creation_data = coop.remote_inference_create(
-                self,
-                description=remote_inference_description,
-                status="queued",
-                iterations=n,
+        if remote_inference := self.use_remote_inference(disable_remote_inference):
+            remote_job_creation_data = self.create_remote_inference_job(
+                iterations=n, remote_inference_description=remote_inference_description
             )
-            … (4 removed lines not shown in this view)
-            job_in_queue = True
-            while job_in_queue:
-                remote_job_data = coop.remote_inference_get(job_uuid)
-                status = remote_job_data.get("status")
-                if status == "cancelled":
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job cancelled by the user.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                    return None
-                elif status == "failed":
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job failed.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                    return None
-                elif status == "completed":
-                    results_uuid = remote_job_data.get("results_uuid")
-                    results = coop.get(results_uuid, expected_object_type="results")
-                    print("\r" + " " * 80 + "\r", end="")
-                    print(
-                        f"Job completed and Results stored on Coop (Results uuid={results_uuid})."
-                    )
-                    return results
-                else:
-                    duration = 5
-                    time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
-                    frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
-                    start_time = time.time()
-                    i = 0
-                    while time.time() - start_time < duration:
-                        print(
-                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
-                            end="",
-                            flush=True,
-                        )
-                        time.sleep(0.1)
-                        i += 1
-        else:
-            if check_api_keys:
-                from edsl import Model
+            results = self.poll_remote_inference_job(remote_job_creation_data)
+            if results is None:
+                self._output("Job failed.")
+            return results
 
-
-
-                raise MissingAPIKeyError(
-                    model_name=str(model.model),
-                    inference_service=model._inference_service_,
-                )
+        if check_api_keys:
+            self.check_api_keys()
 
         # handle cache
         if cache is None or cache is True:
@@ -692,51 +874,14 @@
 
             cache = Cache()
 
-            … (8 removed lines not shown in this view)
-                raise_validation_errors=raise_validation_errors,
-            )
-
-            results.cache = cache.new_entries_cache()
-
-            self._output(f"There are {len(cache.keys()):,} entries in the local cache.")
-        else:
-            cache_difference = coop.remote_cache_get_diff(cache.keys())
-
-            client_missing_cacheentries = cache_difference.get(
-                "client_missing_cacheentries", []
-            )
-
-            missing_entry_count = len(client_missing_cacheentries)
-            if missing_entry_count > 0:
-                self._output(
-                    f"Updating local cache with {missing_entry_count:,} new "
-                    f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
-                )
-                cache.add_from_dict(
-                    {entry.key: entry for entry in client_missing_cacheentries}
-                )
-                self._output("Local cache updated!")
-            else:
-                self._output("No new entries to add to local cache.")
-
-            server_missing_cacheentry_keys = cache_difference.get(
-                "server_missing_cacheentry_keys", []
-            )
-            server_missing_cacheentries = [
-                entry
-                for key in server_missing_cacheentry_keys
-                if (entry := cache.data.get(key)) is not None
-            ]
-            old_entry_keys = [key for key in cache.keys()]
-
-            self._output("Running job...")
+        remote_cache = self.use_remote_cache()
+        with RemoteCacheSync(
+            coop=Coop(),
+            cache=cache,
+            output_func=self._output,
+            remote_cache=remote_cache,
+            remote_cache_description=remote_cache_description,
+        ) as r:
             results = self._run_local(
                 n=n,
                 progress_bar=progress_bar,
@@ -746,32 +891,8 @@
                 print_exceptions=print_exceptions,
                 raise_validation_errors=raise_validation_errors,
             )
-        self._output("Job completed!")
-
-        new_cache_entries = list(
-            [entry for entry in cache.values() if entry.key not in old_entry_keys]
-        )
-        server_missing_cacheentries.extend(new_cache_entries)
-
-        new_entry_count = len(server_missing_cacheentries)
-        if new_entry_count > 0:
-            self._output(
-                f"Updating remote cache with {new_entry_count:,} new "
-                f"{'entry' if new_entry_count == 1 else 'entries'}..."
-            )
-            coop.remote_cache_create_many(
-                server_missing_cacheentries,
-                visibility="private",
-                description=remote_cache_description,
-            )
-            self._output("Remote cache updated!")
-        else:
-            self._output("No new entries to add to remote cache.")
-
-        results.cache = cache.new_entries_cache()
-
-        self._output(f"There are {len(cache.keys()):,} entries in the local cache.")
 
+        results.cache = cache.new_entries_cache()
         return results
 
     def _run_local(self, *args, **kwargs):
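The cache synchronization that used to be spelled out inline in run() now lives in the new RemoteCacheSync context manager (edsl/data/RemoteCacheSync.py, +97 lines in the file list). A minimal sketch of the call pattern run() uses, assuming the constructor arguments shown in the hunk above; presumably the manager pulls missing entries from the remote cache on entry and pushes new local entries on exit, mirroring the removed inline logic:

# Sketch of the RemoteCacheSync usage introduced in run(); argument names match
# the call site above. output_func=print stands in for Jobs._output here.
from edsl import Coop
from edsl.data.Cache import Cache
from edsl.data.RemoteCacheSync import RemoteCacheSync

cache = Cache()

with RemoteCacheSync(
    coop=Coop(),
    cache=cache,
    output_func=print,
    remote_cache=True,                    # as returned by Jobs.use_remote_cache()
    remote_cache_description="demo job",  # free-form description for the remote cache
) as sync:
    # the job would run here against the locally synced cache
    pass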