edsl 0.1.37__py3-none-any.whl → 0.1.37.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +35 -86
- edsl/agents/AgentList.py +0 -5
- edsl/agents/InvigilatorBase.py +23 -2
- edsl/agents/PromptConstructor.py +105 -148
- edsl/agents/descriptors.py +4 -17
- edsl/conjure/AgentConstructionMixin.py +3 -11
- edsl/conversation/Conversation.py +14 -66
- edsl/coop/coop.py +14 -148
- edsl/data/Cache.py +1 -1
- edsl/exceptions/__init__.py +3 -7
- edsl/exceptions/agents.py +19 -17
- edsl/exceptions/results.py +8 -11
- edsl/exceptions/surveys.py +10 -13
- edsl/inference_services/AwsBedrock.py +2 -7
- edsl/inference_services/InferenceServicesCollection.py +9 -32
- edsl/jobs/Jobs.py +71 -306
- edsl/jobs/interviews/InterviewExceptionEntry.py +1 -5
- edsl/jobs/tasks/TaskHistory.py +0 -1
- edsl/language_models/LanguageModel.py +59 -47
- edsl/language_models/__init__.py +0 -1
- edsl/prompts/Prompt.py +4 -11
- edsl/questions/QuestionBase.py +13 -53
- edsl/questions/QuestionBasePromptsMixin.py +33 -1
- edsl/questions/QuestionFreeText.py +0 -1
- edsl/questions/QuestionFunctional.py +2 -2
- edsl/questions/descriptors.py +28 -23
- edsl/results/DatasetExportMixin.py +1 -25
- edsl/results/Result.py +1 -16
- edsl/results/Results.py +120 -31
- edsl/results/ResultsDBMixin.py +1 -1
- edsl/results/Selector.py +1 -18
- edsl/scenarios/Scenario.py +12 -48
- edsl/scenarios/ScenarioHtmlMixin.py +2 -7
- edsl/scenarios/ScenarioList.py +1 -12
- edsl/surveys/Rule.py +4 -10
- edsl/surveys/Survey.py +77 -100
- edsl/utilities/utilities.py +0 -18
- {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/METADATA +1 -1
- {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/RECORD +42 -46
- edsl/conversation/chips.py +0 -95
- edsl/exceptions/BaseException.py +0 -21
- edsl/exceptions/scenarios.py +0 -22
- edsl/language_models/KeyLookup.py +0 -30
- {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/WHEEL +0 -0
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -16,48 +16,25 @@ class InferenceServicesCollection:
 
     @staticmethod
     def _get_service_available(service, warn: bool = False) -> list[str]:
+        from_api = True
         try:
             service_models = service.available()
-        except Exception:
+        except Exception as e:
             if warn:
                 warnings.warn(
                     f"""Error getting models for {service._inference_service_}.
                     Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
                     See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
-                    Relying on
+                    Relying on cache.""",
                     UserWarning,
                 )
+            from edsl.inference_services.models_available_cache import models_available
 
-
-
-
-
-
-                models_from_coop = c.fetch_models()
-                service_models = models_from_coop.get(service._inference_service_, [])
-
-                # cache results
-                service._models_list_cache = service_models
-
-            # Finally, use the available models cache from the Python file
-            except Exception:
-                if warn:
-                    warnings.warn(
-                        f"""Error getting models for {service._inference_service_}.
-                        Relying on EDSL cache.""",
-                        UserWarning,
-                    )
-
-                from edsl.inference_services.models_available_cache import (
-                    models_available,
-                )
-
-                service_models = models_available.get(service._inference_service_, [])
-
-                # cache results
-                service._models_list_cache = service_models
-
-        return service_models
+            service_models = models_available.get(service._inference_service_, [])
+            # cache results
+            service._models_list_cache = service_models
+            from_api = False
+        return service_models  # , from_api
 
     def available(self):
         total_models = []
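The net effect of this hunk is that a failed live model listing now falls back in a single step to the static `models_available` cache bundled with the package, instead of first trying Coop. A minimal sketch of the new behavior, using a hypothetical stand-in service (the `models_available_cache` import is taken from the diff itself):

```python
import warnings

from edsl.inference_services.models_available_cache import models_available


class BrokenService:
    """Hypothetical stand-in for a provider whose live listing fails."""

    _inference_service_ = "openai"

    def available(self):
        raise ConnectionError("no network")


# Mirrors the single-step fallback now used by _get_service_available.
service = BrokenService()
try:
    models = service.available()
except Exception:
    warnings.warn("Relying on cache.", UserWarning)
    models = models_available.get(service._inference_service_, [])

print(len(models))  # however many models the bundled cache lists for "openai"
```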
edsl/jobs/Jobs.py
CHANGED
@@ -3,10 +3,9 @@ from __future__ import annotations
 import warnings
 import requests
 from itertools import product
-from typing import
+from typing import Optional, Union, Sequence, Generator
 
 from edsl.Base import Base
-
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
@@ -194,7 +193,7 @@ class Jobs(Base):
                 inference_service=invigilator.model._inference_service_,
                 model=invigilator.model.model,
             )
-            costs.append(prompt_cost["
+            costs.append(prompt_cost["cost"])
 
         d = Dataset(
             [
@@ -210,14 +209,14 @@ class Jobs(Base):
         )
         return d
 
-    def show_prompts(self, all=False
+    def show_prompts(self, all=False) -> None:
         """Print the prompts."""
         if all:
-            self.prompts().to_scenario_list().print(format="rich"
+            self.prompts().to_scenario_list().print(format="rich")
         else:
             self.prompts().select(
                 "user_prompt", "system_prompt"
-            ).to_scenario_list().print(format="rich"
+            ).to_scenario_list().print(format="rich")
 
     @staticmethod
     def estimate_prompt_cost(
@@ -228,7 +227,6 @@ class Jobs(Base):
         model: str,
     ) -> dict:
         """Estimates the cost of a prompt. Takes piping into account."""
-        import math
 
         def get_piping_multiplier(prompt: str):
             """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
@@ -242,25 +240,10 @@ class Jobs(Base):
 
         try:
             relevant_prices = price_lookup[key]
-
-            service_input_token_price = float(
-                relevant_prices["input"]["service_stated_token_price"]
-            )
-            service_input_token_qty = float(
-                relevant_prices["input"]["service_stated_token_qty"]
-            )
-            input_price_per_token = service_input_token_price / service_input_token_qty
-
-            service_output_token_price = float(
-                relevant_prices["output"]["service_stated_token_price"]
-            )
-            service_output_token_qty = float(
-                relevant_prices["output"]["service_stated_token_qty"]
-            )
-            output_price_per_token = (
-                service_output_token_price / service_output_token_qty
+            output_price_per_token = 1 / float(
+                relevant_prices["output"]["one_usd_buys"]
             )
-
+            input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
         except KeyError:
             # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
             # Use a sensible default
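The new price schema stores `one_usd_buys` (how many tokens one US dollar purchases), so the per-token price is simply its reciprocal, replacing the old stated-price divided by stated-quantity computation. A quick worked example with made-up figures:

```python
# Hypothetical entry in the new one_usd_buys price schema.
relevant_prices = {
    "input": {"one_usd_buys": "6666666"},   # ~6.67M input tokens per USD
    "output": {"one_usd_buys": "1666666"},  # ~1.67M output tokens per USD
}

# Per-token price is the reciprocal of tokens-per-dollar.
input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
output_price_per_token = 1 / float(relevant_prices["output"]["one_usd_buys"])

print(f"{input_price_per_token:.2e}")   # ~1.50e-07 USD, i.e. ~$0.15 per 1M tokens
print(f"{output_price_per_token:.2e}")  # ~6.00e-07 USD, i.e. ~$0.60 per 1M tokens
```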
@@ -270,8 +253,9 @@ class Jobs(Base):
             warnings.warn(
                 "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
             )
-
-            output_price_per_token = 0.
+
+            output_price_per_token = 0.00000015 # $0.15 / 1M tokens
+            input_price_per_token = 0.00000060 # $0.60 / 1M tokens
 
         # Compute the number of characters (double if the question involves piping)
         user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
@@ -283,8 +267,7 @@ class Jobs(Base):
 
         # Convert into tokens (1 token approx. equals 4 characters)
         input_tokens = (user_prompt_chars + system_prompt_chars) // 4
-
-        output_tokens = math.ceil(0.75 * input_tokens)
+        output_tokens = input_tokens
 
         cost = (
             input_tokens * input_price_per_token
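Under the revised assumptions (1 token ≈ 4 characters; output tokens now equal input tokens, rather than `ceil(0.75 * input)`), the whole estimate reduces to simple arithmetic. A worked example with hypothetical prompt sizes, using the fallback prices exactly as the diff assigns them:

```python
# Hypothetical prompt sizes; no Jinja braces, so the piping multiplier is 1.
user_prompt_chars = 1000
system_prompt_chars = 200

input_tokens = (user_prompt_chars + system_prompt_chars) // 4  # 300
output_tokens = input_tokens                                   # 300 (was ceil(0.75 * 300) = 225)

# Fallback prices as assigned in the diff when no price data is available.
output_price_per_token = 0.00000015
input_price_per_token = 0.00000060

cost = input_tokens * input_price_per_token + output_tokens * output_price_per_token
print(f"{cost:.6f}")  # 0.000225 USD
```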
@@ -294,17 +277,15 @@ class Jobs(Base):
         return {
             "input_tokens": input_tokens,
             "output_tokens": output_tokens,
-            "
+            "cost": cost,
         }
 
-    def estimate_job_cost_from_external_prices(
-        self, price_lookup: dict, iterations: int = 1
-    ) -> dict:
+    def estimate_job_cost_from_external_prices(self, price_lookup: dict) -> dict:
         """
         Estimates the cost of a job according to the following assumptions:
 
         - 1 token = 4 characters.
-        -
+        - Input tokens = output tokens.
 
         price_lookup is an external pricing dictionary.
         """
@@ -341,7 +322,7 @@ class Jobs(Base):
                 "system_prompt": system_prompt,
                 "estimated_input_tokens": prompt_cost["input_tokens"],
                 "estimated_output_tokens": prompt_cost["output_tokens"],
-                "
+                "estimated_cost": prompt_cost["cost"],
                 "inference_service": inference_service,
                 "model": model,
             }
@@ -353,21 +334,18 @@ class Jobs(Base):
             df.groupby(["inference_service", "model"])
             .agg(
                 {
-                    "
+                    "estimated_cost": "sum",
                     "estimated_input_tokens": "sum",
                     "estimated_output_tokens": "sum",
                 }
             )
             .reset_index()
         )
-        df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
-        df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
-        df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
 
         estimated_costs_by_model = df.to_dict("records")
 
         estimated_total_cost = sum(
-            model["
+            model["estimated_cost"] for model in estimated_costs_by_model
         )
         estimated_total_input_tokens = sum(
             model["estimated_input_tokens"] for model in estimated_costs_by_model
@@ -377,7 +355,7 @@ class Jobs(Base):
         )
 
         output = {
-            "
+            "estimated_total_cost": estimated_total_cost,
             "estimated_total_input_tokens": estimated_total_input_tokens,
             "estimated_total_output_tokens": estimated_total_output_tokens,
             "model_costs": estimated_costs_by_model,
@@ -385,12 +363,12 @@ class Jobs(Base):
 
         return output
 
-    def estimate_job_cost(self
+    def estimate_job_cost(self) -> dict:
         """
         Estimates the cost of a job according to the following assumptions:
 
         - 1 token = 4 characters.
-        -
+        - Input tokens = output tokens.
 
         Fetches prices from Coop.
         """
@@ -399,9 +377,7 @@ class Jobs(Base):
         c = Coop()
         price_lookup = c.fetch_prices()
 
-        return self.estimate_job_cost_from_external_prices(
-            price_lookup=price_lookup, iterations=iterations
-        )
+        return self.estimate_job_cost_from_external_prices(price_lookup=price_lookup)
 
     @staticmethod
     def compute_job_cost(job_results: "Results") -> float:
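With `iterations` gone from both methods, cost estimates always describe a single run of the job. A hedged usage sketch (the question and model here are illustrative; the output keys come from the diff above):

```python
from edsl import QuestionFreeText, Model

# Hypothetical one-question job; .by() attaches the model.
q = QuestionFreeText(
    question_name="color",
    question_text="What is your favorite color?",
)
job = q.by(Model("gpt-4o"))

estimate = job.estimate_job_cost()  # fetches a price_lookup from Coop internally
print(estimate["estimated_total_cost"])
print(estimate["estimated_total_input_tokens"])
print(estimate["estimated_total_output_tokens"])
```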
@@ -723,11 +699,7 @@ class Jobs(Base):
         return self._raise_validation_errors
 
     def create_remote_inference_job(
-        self,
-        iterations: int = 1,
-        remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
-        verbose=False,
+        self, iterations: int = 1, remote_inference_description: Optional[str] = None
     ):
         """ """
         from edsl.coop.coop import Coop
@@ -739,11 +711,9 @@ class Jobs(Base):
             description=remote_inference_description,
             status="queued",
             iterations=iterations,
-            initial_results_visibility=remote_inference_results_visibility,
         )
         job_uuid = remote_job_creation_data.get("uuid")
-
-        print(f"Job sent to server. (Job uuid={job_uuid}).")
+        print(f"Job sent to server. (Job uuid={job_uuid}).")
         return remote_job_creation_data
 
     @staticmethod
@@ -754,7 +724,7 @@ class Jobs(Base):
         return coop.remote_inference_get(job_uuid)
 
     def poll_remote_inference_job(
-        self, remote_job_creation_data: dict
+        self, remote_job_creation_data: dict
     ) -> Union[Results, None]:
         from edsl.coop.coop import Coop
         import time
@@ -771,46 +741,42 @@ class Jobs(Base):
             remote_job_data = coop.remote_inference_get(job_uuid)
             status = remote_job_data.get("status")
             if status == "cancelled":
-
-
-
-
-
-                )
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job cancelled by the user.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
                 return None
             elif status == "failed":
-
-
-
-
-
-                )
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job failed.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
                 return None
             elif status == "completed":
                 results_uuid = remote_job_data.get("results_uuid")
                 results = coop.get(results_uuid, expected_object_type="results")
-
-
-
-                print(f"Job completed and Results stored on Coop: {url}.")
+                print("\r" + " " * 80 + "\r", end="")
+                url = f"{expected_parrot_url}/content/{results_uuid}"
+                print(f"Job completed and Results stored on Coop: {url}.")
                 return results
             else:
-                duration =
+                duration = 5
                 time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
                 frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
                 start_time = time.time()
                 i = 0
                 while time.time() - start_time < duration:
-
-
-
-
-
-                    )
+                    print(
+                        f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                        end="",
+                        flush=True,
+                    )
                     time.sleep(0.1)
                     i += 1
 
-    def use_remote_inference(self, disable_remote_inference: bool)
+    def use_remote_inference(self, disable_remote_inference: bool):
         if disable_remote_inference:
             return False
         if not disable_remote_inference:
@@ -826,23 +792,20 @@ class Jobs(Base):
 
         return False
 
-    def use_remote_cache(self
-
-
-        if not disable_remote_cache:
-            try:
-                from edsl import Coop
+    def use_remote_cache(self):
+        try:
+            from edsl import Coop
 
-
-
-
-
-
-
+            user_edsl_settings = Coop().edsl_settings
+            return user_edsl_settings.get("remote_caching", False)
+        except requests.ConnectionError:
+            pass
+        except CoopServerResponseError as e:
+            pass
 
         return False
 
-    def check_api_keys(self)
+    def check_api_keys(self):
         from edsl import Model
 
         for model in self.models + [Model()]:
@@ -852,86 +815,6 @@ class Jobs(Base):
                 inference_service=model._inference_service_,
             )
 
-    def get_missing_api_keys(self) -> set:
-        """
-        Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
-        """
-
-        missing_api_keys = set()
-
-        from edsl import Model
-        from edsl.enums import service_to_api_keyname
-
-        for model in self.models + [Model()]:
-            if not model.has_valid_api_key():
-                key_name = service_to_api_keyname.get(
-                    model._inference_service_, "NOT FOUND"
-                )
-                missing_api_keys.add(key_name)
-
-        return missing_api_keys
-
-    def user_has_all_model_keys(self):
-        """
-        Returns True if the user has all model keys required to run their job.
-
-        Otherwise, returns False.
-        """
-
-        try:
-            self.check_api_keys()
-            return True
-        except MissingAPIKeyError:
-            return False
-        except Exception:
-            raise
-
-    def user_has_ep_api_key(self) -> bool:
-        """
-        Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
-
-        Otherwise, returns False.
-        """
-
-        import os
-
-        coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
-
-        if coop_api_key is not None:
-            return True
-        else:
-            return False
-
-    def needs_external_llms(self) -> bool:
-        """
-        Returns True if the job needs external LLMs to run.
-
-        Otherwise, returns False.
-        """
-        # These cases are necessary to skip the API key check during doctests
-
-        # Accounts for Results.example()
-        all_agents_answer_questions_directly = len(self.agents) > 0 and all(
-            [hasattr(a, "answer_question_directly") for a in self.agents]
-        )
-
-        # Accounts for InterviewExceptionEntry.example()
-        only_model_is_test = set([m.model for m in self.models]) == set(["test"])
-
-        # Accounts for Survey.__call__
-        all_questions_are_functional = set(
-            [q.question_type for q in self.survey.questions]
-        ) == set(["functional"])
-
-        if (
-            all_agents_answer_questions_directly
-            or only_model_is_test
-            or all_questions_are_functional
-        ):
-            return False
-        else:
-            return True
-
     def run(
         self,
         n: int = 1,
@@ -944,28 +827,22 @@ class Jobs(Base):
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[
-            Literal["private", "public", "unlisted"]
-        ] = "unlisted",
         skip_retry: bool = False,
         raise_validation_errors: bool = False,
-        disable_remote_cache: bool = False,
        disable_remote_inference: bool = False,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.
 
-        :param n:
-        :param progress_bar:
-        :param stop_on_exception:
-        :param cache:
-        :param check_api_keys:
-        :param
-        :param
-        :param
-        :param
-        :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
-        :param disable_remote_inference: If True, the job will not use remote inference
+        :param n: how many times to run each interview
+        :param progress_bar: shows a progress bar
+        :param stop_on_exception: stops the job if an exception is raised
+        :param cache: a cache object to store results
+        :param check_api_keys: check if the API keys are valid
+        :param batch_mode: run the job in batch mode i.e., no expecation of interaction with the user
+        :param verbose: prints messages
+        :param remote_cache_description: specifies a description for this group of entries in the remote cache
+        :param remote_inference_description: specifies a description for the remote inference job
         """
         from edsl.coop.coop import Coop
 
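Since `disable_remote_cache` and `remote_inference_results_visibility` are gone from `run()`, callers that passed them must drop those arguments. A hedged sketch of a local run against the slimmed signature, following the pattern of the package's own doctests:

```python
from edsl.language_models import LanguageModel
from edsl.questions import QuestionFreeText

# Test model and example question, as in the package's own doctests.
m = LanguageModel.example(test_model=True)
q = QuestionFreeText.example()

results = q.by(m).run(
    n=1,
    progress_bar=False,
    disable_remote_inference=True,  # force local execution
)
print(len(results))
```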
@@ -975,54 +852,9 @@ class Jobs(Base):
 
         self.verbose = verbose
 
-        if (
-            not self.user_has_all_model_keys()
-            and not self.user_has_ep_api_key()
-            and self.needs_external_llms()
-        ):
-            import secrets
-            from dotenv import load_dotenv
-            from edsl import CONFIG
-            from edsl.coop.coop import Coop
-            from edsl.utilities.utilities import write_api_key_to_env
-
-            missing_api_keys = self.get_missing_api_keys()
-
-            edsl_auth_token = secrets.token_urlsafe(16)
-
-            print("You're missing some of the API keys needed to run this job:")
-            for api_key in missing_api_keys:
-                print(f" 🔑 {api_key}")
-            print(
-                "\nYou can either add the missing keys to your .env file, or use remote inference."
-            )
-            print("Remote inference allows you to run jobs on our server.")
-            print("\n🚀 To use remote inference, sign up at the following link:")
-
-            coop = Coop()
-            coop._display_login_url(edsl_auth_token=edsl_auth_token)
-
-            print(
-                "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
-            )
-
-            api_key = coop._poll_for_api_key(edsl_auth_token)
-
-            if api_key is None:
-                print("\nTimed out waiting for login. Please try again.")
-                return
-
-            write_api_key_to_env(api_key)
-            print("✨ API key retrieved and written to .env file.\n")
-
-            # Retrieve API key so we can continue running the job
-            load_dotenv()
-
         if remote_inference := self.use_remote_inference(disable_remote_inference):
             remote_job_creation_data = self.create_remote_inference_job(
-                iterations=n,
-                remote_inference_description=remote_inference_description,
-                remote_inference_results_visibility=remote_inference_results_visibility,
+                iterations=n, remote_inference_description=remote_inference_description
             )
             results = self.poll_remote_inference_job(remote_job_creation_data)
             if results is None:
@@ -1042,7 +874,7 @@ class Jobs(Base):
 
         cache = Cache()
 
-        remote_cache = self.use_remote_cache(
+        remote_cache = self.use_remote_cache()
         with RemoteCacheSync(
             coop=Coop(),
             cache=cache,
@@ -1063,84 +895,17 @@ class Jobs(Base):
             results.cache = cache.new_entries_cache()
         return results
 
-    async def create_and_poll_remote_job(
-        self,
-        iterations: int = 1,
-        remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[
-            Literal["private", "public", "unlisted"]
-        ] = "unlisted",
-    ) -> Union[Results, None]:
-        """
-        Creates and polls a remote inference job asynchronously.
-        Reuses existing synchronous methods but runs them in an async context.
-
-        :param iterations: Number of times to run each interview
-        :param remote_inference_description: Optional description for the remote job
-        :param remote_inference_results_visibility: Visibility setting for results
-        :return: Results object if successful, None if job fails or is cancelled
-        """
-        import asyncio
-        from functools import partial
-
-        # Create job using existing method
-        loop = asyncio.get_event_loop()
-        remote_job_creation_data = await loop.run_in_executor(
-            None,
-            partial(
-                self.create_remote_inference_job,
-                iterations=iterations,
-                remote_inference_description=remote_inference_description,
-                remote_inference_results_visibility=remote_inference_results_visibility,
-            ),
-        )
-
-        # Poll using existing method but with async sleep
-        return await loop.run_in_executor(
-            None, partial(self.poll_remote_inference_job, remote_job_creation_data)
-        )
-
-    async def run_async(
-        self,
-        cache=None,
-        n=1,
-        disable_remote_inference: bool = False,
-        remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[
-            Literal["private", "public", "unlisted"]
-        ] = "unlisted",
-        **kwargs,
-    ):
-        """Run the job asynchronously, either locally or remotely.
-
-        :param cache: Cache object or boolean
-        :param n: Number of iterations
-        :param disable_remote_inference: If True, forces local execution
-        :param remote_inference_description: Description for remote jobs
-        :param remote_inference_results_visibility: Visibility setting for remote results
-        :param kwargs: Additional arguments passed to local execution
-        :return: Results object
-        """
-        # Check if we should use remote inference
-        if remote_inference := self.use_remote_inference(disable_remote_inference):
-            results = await self.create_and_poll_remote_job(
-                iterations=n,
-                remote_inference_description=remote_inference_description,
-                remote_inference_results_visibility=remote_inference_results_visibility,
-            )
-            if results is None:
-                self._output("Job failed.")
-            return results
-
-        # If not using remote inference, run locally with async
-        return await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
-
     def _run_local(self, *args, **kwargs):
         """Run the job locally."""
 
         results = JobsRunnerAsyncio(self).run(*args, **kwargs)
         return results
 
+    async def run_async(self, cache=None, n=1, **kwargs):
+        """Run asynchronously."""
+        results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
+        return results
+
     def all_question_parameters(self):
         """Return all the fields in the questions in the survey.
         >>> from edsl.jobs import Jobs
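`run_async` is now a thin local wrapper around `JobsRunnerAsyncio`; the remote create-and-poll path is gone entirely. A hedged sketch of the remaining async usage (the job construction is illustrative):

```python
import asyncio

from edsl.language_models import LanguageModel
from edsl.questions import QuestionFreeText


async def main():
    # In 0.1.37.dev1 this always runs locally via JobsRunnerAsyncio.
    m = LanguageModel.example(test_model=True)
    job = QuestionFreeText.example().by(m)
    results = await job.run_async(n=1)
    print(len(results))


asyncio.run(main())
```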
edsl/jobs/interviews/InterviewExceptionEntry.py
CHANGED
@@ -67,11 +67,7 @@ class InterviewExceptionEntry:
         m = LanguageModel.example(test_model=True)
         q = QuestionFreeText.example(exception_to_throw=ValueError)
         results = q.by(m).run(
-            skip_retry=True,
-            print_exceptions=False,
-            raise_validation_errors=True,
-            disable_remote_cache=True,
-            disable_remote_inference=True,
+            skip_retry=True, print_exceptions=False, raise_validation_errors=True
         )
         return results.task_history.exceptions[0]["how_are_you"][0]
 