edsl 0.1.36.dev5__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +92 -41
- edsl/agents/AgentList.py +15 -2
- edsl/agents/InvigilatorBase.py +15 -25
- edsl/agents/PromptConstructor.py +149 -108
- edsl/agents/descriptors.py +17 -4
- edsl/conjure/AgentConstructionMixin.py +11 -3
- edsl/conversation/Conversation.py +66 -14
- edsl/conversation/chips.py +95 -0
- edsl/coop/coop.py +148 -39
- edsl/data/Cache.py +1 -1
- edsl/data/RemoteCacheSync.py +25 -12
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +7 -3
- edsl/exceptions/agents.py +17 -19
- edsl/exceptions/results.py +11 -8
- edsl/exceptions/scenarios.py +22 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AwsBedrock.py +7 -2
- edsl/inference_services/InferenceServicesCollection.py +42 -13
- edsl/inference_services/models_available_cache.py +25 -1
- edsl/jobs/Jobs.py +306 -71
- edsl/jobs/interviews/Interview.py +24 -14
- edsl/jobs/interviews/InterviewExceptionCollection.py +1 -1
- edsl/jobs/interviews/InterviewExceptionEntry.py +17 -13
- edsl/jobs/interviews/ReportErrors.py +2 -2
- edsl/jobs/runners/JobsRunnerAsyncio.py +10 -9
- edsl/jobs/tasks/TaskHistory.py +1 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +47 -59
- edsl/language_models/__init__.py +1 -0
- edsl/prompts/Prompt.py +11 -12
- edsl/questions/QuestionBase.py +53 -13
- edsl/questions/QuestionBasePromptsMixin.py +1 -33
- edsl/questions/QuestionFreeText.py +1 -0
- edsl/questions/QuestionFunctional.py +2 -2
- edsl/questions/descriptors.py +23 -28
- edsl/results/DatasetExportMixin.py +25 -1
- edsl/results/Result.py +27 -10
- edsl/results/Results.py +34 -121
- edsl/results/ResultsDBMixin.py +1 -1
- edsl/results/Selector.py +18 -1
- edsl/scenarios/FileStore.py +20 -5
- edsl/scenarios/Scenario.py +52 -13
- edsl/scenarios/ScenarioHtmlMixin.py +7 -2
- edsl/scenarios/ScenarioList.py +12 -1
- edsl/scenarios/__init__.py +2 -0
- edsl/surveys/Rule.py +10 -4
- edsl/surveys/Survey.py +100 -77
- edsl/utilities/utilities.py +18 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/METADATA +1 -1
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/RECORD +55 -51
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/LICENSE +0 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations
 import warnings
 import requests
 from itertools import product
-from typing import Optional, Union, Sequence, Generator
+from typing import Literal, Optional, Union, Sequence, Generator
 
 from edsl.Base import Base
+
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
@@ -193,7 +194,7 @@ class Jobs(Base):
                 inference_service=invigilator.model._inference_service_,
                 model=invigilator.model.model,
             )
-            costs.append(prompt_cost["
+            costs.append(prompt_cost["cost_usd"])
 
         d = Dataset(
             [
@@ -209,14 +210,14 @@ class Jobs(Base):
         )
         return d
 
-    def show_prompts(self, all=False) -> None:
+    def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
         """Print the prompts."""
         if all:
-            self.prompts().to_scenario_list().print(format="rich")
+            self.prompts().to_scenario_list().print(format="rich", max_rows=max_rows)
         else:
             self.prompts().select(
                 "user_prompt", "system_prompt"
-            ).to_scenario_list().print(format="rich")
+            ).to_scenario_list().print(format="rich", max_rows=max_rows)
 
     @staticmethod
     def estimate_prompt_cost(
@@ -227,6 +228,7 @@ class Jobs(Base):
         model: str,
     ) -> dict:
         """Estimates the cost of a prompt. Takes piping into account."""
+        import math
 
         def get_piping_multiplier(prompt: str):
             """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
@@ -240,10 +242,25 @@ class Jobs(Base):
 
         try:
             relevant_prices = price_lookup[key]
-
-
+
+            service_input_token_price = float(
+                relevant_prices["input"]["service_stated_token_price"]
+            )
+            service_input_token_qty = float(
+                relevant_prices["input"]["service_stated_token_qty"]
+            )
+            input_price_per_token = service_input_token_price / service_input_token_qty
+
+            service_output_token_price = float(
+                relevant_prices["output"]["service_stated_token_price"]
+            )
+            service_output_token_qty = float(
+                relevant_prices["output"]["service_stated_token_qty"]
+            )
+            output_price_per_token = (
+                service_output_token_price / service_output_token_qty
             )
-
+
         except KeyError:
             # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
             # Use a sensible default
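The new code normalizes a service's stated price for a stated token quantity into a per-token rate. A hypothetical sketch of a `price_lookup` entry consistent with the keys read above; the key shape and the dollar figures are illustrative assumptions, not actual Coop data:

```python
# Hypothetical price_lookup entry; the nested keys mirror those read in the diff above.
price_lookup = {
    ("openai", "gpt-4o"): {
        "input": {"service_stated_token_price": 2.50, "service_stated_token_qty": 1_000_000},
        "output": {"service_stated_token_price": 10.00, "service_stated_token_qty": 1_000_000},
    }
}

entry = price_lookup[("openai", "gpt-4o")]
# Normalize "price per stated quantity" into a per-token rate, as the new code does.
input_price_per_token = (
    entry["input"]["service_stated_token_price"] / entry["input"]["service_stated_token_qty"]
)
output_price_per_token = (
    entry["output"]["service_stated_token_price"] / entry["output"]["service_stated_token_qty"]
)
```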
@@ -253,9 +270,8 @@ class Jobs(Base):
             warnings.warn(
                 "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
             )
-
-            output_price_per_token = 0.
-            input_price_per_token = 0.00000060 # $0.60 / 1M tokens
+            input_price_per_token = 0.00000015 # $0.15 / 1M tokens
+            output_price_per_token = 0.00000060 # $0.60 / 1M tokens
 
         # Compute the number of characters (double if the question involves piping)
         user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
@@ -267,7 +283,8 @@ class Jobs(Base):
 
         # Convert into tokens (1 token approx. equals 4 characters)
         input_tokens = (user_prompt_chars + system_prompt_chars) // 4
-
+
+        output_tokens = math.ceil(0.75 * input_tokens)
 
         cost = (
             input_tokens * input_price_per_token
@@ -277,15 +294,17 @@ class Jobs(Base):
         return {
             "input_tokens": input_tokens,
             "output_tokens": output_tokens,
-            "
+            "cost_usd": cost,
         }
 
-    def estimate_job_cost_from_external_prices(
+    def estimate_job_cost_from_external_prices(
+        self, price_lookup: dict, iterations: int = 1
+    ) -> dict:
         """
         Estimates the cost of a job according to the following assumptions:
 
         - 1 token = 4 characters.
-        -
+        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
 
         price_lookup is an external pricing dictionary.
         """
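Taken together, the hunks above implement a simple character-based estimator. A minimal standalone sketch of the same logic, assuming the stated heuristics (1 token ≈ 4 characters; output tokens = 0.75 × input tokens, rounded up) and the default fallback prices; the Jinja-brace check is an assumption, since the diff only shows the helper's docstring:

```python
import math

def estimate_prompt_cost_usd(user_prompt: str, system_prompt: str) -> dict:
    # Default fallback prices from the diff: $0.15 / 1M input tokens, $0.60 / 1M output tokens.
    input_price_per_token = 0.00000015
    output_price_per_token = 0.00000060

    def get_piping_multiplier(prompt: str) -> int:
        # Prompts containing Jinja braces may be filled in via piping, so double their length.
        return 2 if "{{" in prompt else 1

    user_chars = len(str(user_prompt)) * get_piping_multiplier(user_prompt)
    system_chars = len(str(system_prompt)) * get_piping_multiplier(system_prompt)

    input_tokens = (user_chars + system_chars) // 4   # 1 token ~= 4 characters
    output_tokens = math.ceil(0.75 * input_tokens)    # output assumed at 75% of input

    cost = input_tokens * input_price_per_token + output_tokens * output_price_per_token
    return {"input_tokens": input_tokens, "output_tokens": output_tokens, "cost_usd": cost}
```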
@@ -322,7 +341,7 @@ class Jobs(Base):
                 "system_prompt": system_prompt,
                 "estimated_input_tokens": prompt_cost["input_tokens"],
                 "estimated_output_tokens": prompt_cost["output_tokens"],
-                "
+                "estimated_cost_usd": prompt_cost["cost_usd"],
                 "inference_service": inference_service,
                 "model": model,
             }
@@ -334,18 +353,21 @@ class Jobs(Base):
             df.groupby(["inference_service", "model"])
             .agg(
                 {
-                    "
+                    "estimated_cost_usd": "sum",
                     "estimated_input_tokens": "sum",
                     "estimated_output_tokens": "sum",
                 }
             )
             .reset_index()
         )
+        df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
+        df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
+        df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
 
         estimated_costs_by_model = df.to_dict("records")
 
         estimated_total_cost = sum(
-            model["
+            model["estimated_cost_usd"] for model in estimated_costs_by_model
         )
         estimated_total_input_tokens = sum(
             model["estimated_input_tokens"] for model in estimated_costs_by_model
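The new `iterations` multiplier scales per-run totals linearly. For example, under the default fallback prices above, a job whose prompts sum to 1,000 estimated input tokens implies ceil(0.75 × 1,000) = 750 output tokens, so one run costs 1,000 × $0.00000015 + 750 × $0.00000060 = $0.00060, and `iterations=2` doubles that to $0.00120.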
@@ -355,7 +377,7 @@ class Jobs(Base):
         )
 
         output = {
-            "
+            "estimated_total_cost_usd": estimated_total_cost,
             "estimated_total_input_tokens": estimated_total_input_tokens,
             "estimated_total_output_tokens": estimated_total_output_tokens,
             "model_costs": estimated_costs_by_model,
@@ -363,12 +385,12 @@ class Jobs(Base):
 
         return output
 
-    def estimate_job_cost(self) -> dict:
+    def estimate_job_cost(self, iterations: int = 1) -> dict:
         """
         Estimates the cost of a job according to the following assumptions:
 
         - 1 token = 4 characters.
-        -
+        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
 
         Fetches prices from Coop.
         """
@@ -377,7 +399,9 @@ class Jobs(Base):
         c = Coop()
         price_lookup = c.fetch_prices()
 
-        return self.estimate_job_cost_from_external_prices(
+        return self.estimate_job_cost_from_external_prices(
+            price_lookup=price_lookup, iterations=iterations
+        )
 
     @staticmethod
     def compute_job_cost(job_results: "Results") -> float:
@@ -699,7 +723,11 @@ class Jobs(Base):
         return self._raise_validation_errors
 
     def create_remote_inference_job(
-        self,
+        self,
+        iterations: int = 1,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
+        verbose=False,
     ):
         """ """
         from edsl.coop.coop import Coop
@@ -711,9 +739,11 @@ class Jobs(Base):
             description=remote_inference_description,
             status="queued",
             iterations=iterations,
+            initial_results_visibility=remote_inference_results_visibility,
         )
         job_uuid = remote_job_creation_data.get("uuid")
-
+        if self.verbose:
+            print(f"Job sent to server. (Job uuid={job_uuid}).")
         return remote_job_creation_data
 
     @staticmethod
@@ -724,7 +754,7 @@ class Jobs(Base):
         return coop.remote_inference_get(job_uuid)
 
     def poll_remote_inference_job(
-        self, remote_job_creation_data: dict
+        self, remote_job_creation_data: dict, verbose=False, poll_interval=5
     ) -> Union[Results, None]:
         from edsl.coop.coop import Coop
         import time
@@ -741,42 +771,46 @@ class Jobs(Base):
             remote_job_data = coop.remote_inference_get(job_uuid)
             status = remote_job_data.get("status")
             if status == "cancelled":
-
-
-
-
-
+                if self.verbose:
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job cancelled by the user.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
                 return None
             elif status == "failed":
-
-
-
-
-
+                if self.verbose:
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job failed.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
                 return None
             elif status == "completed":
                 results_uuid = remote_job_data.get("results_uuid")
                 results = coop.get(results_uuid, expected_object_type="results")
-
-
-
+                if self.verbose:
+                    print("\r" + " " * 80 + "\r", end="")
+                    url = f"{expected_parrot_url}/content/{results_uuid}"
+                    print(f"Job completed and Results stored on Coop: {url}.")
                 return results
             else:
-                duration =
+                duration = poll_interval
                 time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
                 frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
                 start_time = time.time()
                 i = 0
                 while time.time() - start_time < duration:
-
-
-
-
-
+                    if self.verbose:
+                        print(
+                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                            end="",
+                            flush=True,
+                        )
                     time.sleep(0.1)
                     i += 1
 
-    def use_remote_inference(self, disable_remote_inference: bool):
+    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
        if disable_remote_inference:
            return False
        if not disable_remote_inference:
@@ -792,20 +826,23 @@ class Jobs(Base):
 
         return False
 
-    def use_remote_cache(self):
-
-
+    def use_remote_cache(self, disable_remote_cache: bool) -> bool:
+        if disable_remote_cache:
+            return False
+        if not disable_remote_cache:
+            try:
+                from edsl import Coop
 
-
-
-
-
-
-
+                user_edsl_settings = Coop().edsl_settings
+                return user_edsl_settings.get("remote_caching", False)
+            except requests.ConnectionError:
+                pass
+            except CoopServerResponseError as e:
+                pass
 
         return False
 
-    def check_api_keys(self):
+    def check_api_keys(self) -> None:
         from edsl import Model
 
         for model in self.models + [Model()]:
@@ -815,6 +852,86 @@ class Jobs(Base):
                 inference_service=model._inference_service_,
             )
 
+    def get_missing_api_keys(self) -> set:
+        """
+        Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
+        """
+
+        missing_api_keys = set()
+
+        from edsl import Model
+        from edsl.enums import service_to_api_keyname
+
+        for model in self.models + [Model()]:
+            if not model.has_valid_api_key():
+                key_name = service_to_api_keyname.get(
+                    model._inference_service_, "NOT FOUND"
+                )
+                missing_api_keys.add(key_name)
+
+        return missing_api_keys
+
+    def user_has_all_model_keys(self):
+        """
+        Returns True if the user has all model keys required to run their job.
+
+        Otherwise, returns False.
+        """
+
+        try:
+            self.check_api_keys()
+            return True
+        except MissingAPIKeyError:
+            return False
+        except Exception:
+            raise
+
+    def user_has_ep_api_key(self) -> bool:
+        """
+        Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
+
+        Otherwise, returns False.
+        """
+
+        import os
+
+        coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
+
+        if coop_api_key is not None:
+            return True
+        else:
+            return False
+
+    def needs_external_llms(self) -> bool:
+        """
+        Returns True if the job needs external LLMs to run.
+
+        Otherwise, returns False.
+        """
+        # These cases are necessary to skip the API key check during doctests
+
+        # Accounts for Results.example()
+        all_agents_answer_questions_directly = len(self.agents) > 0 and all(
+            [hasattr(a, "answer_question_directly") for a in self.agents]
+        )
+
+        # Accounts for InterviewExceptionEntry.example()
+        only_model_is_test = set([m.model for m in self.models]) == set(["test"])
+
+        # Accounts for Survey.__call__
+        all_questions_are_functional = set(
+            [q.question_type for q in self.survey.questions]
+        ) == set(["functional"])
+
+        if (
+            all_agents_answer_questions_directly
+            or only_model_is_test
+            or all_questions_are_functional
+        ):
+            return False
+        else:
+            return True
+
     def run(
         self,
         n: int = 1,
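These four helpers feed the gate added to `run()` below: the interactive Coop login flow is only offered when the user lacks provider keys, has no Expected Parrot key, and the job actually needs external LLMs. A hedged sketch of that check in isolation; `should_offer_remote_login` is a hypothetical name, not part of the package:

```python
# Hypothetical illustration of the gate run() applies before falling back to
# the interactive Coop login flow (see the hunk starting at new line 978 below).
def should_offer_remote_login(job) -> bool:
    return (
        not job.user_has_all_model_keys()   # some provider key is missing or invalid
        and not job.user_has_ep_api_key()   # no EXPECTED_PARROT_API_KEY in the env
        and job.needs_external_llms()       # no test/functional/direct-answer shortcut applies
    )
```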
@@ -827,22 +944,28 @@ class Jobs(Base):
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[
+            Literal["private", "public", "unlisted"]
+        ] = "unlisted",
         skip_retry: bool = False,
         raise_validation_errors: bool = False,
+        disable_remote_cache: bool = False,
         disable_remote_inference: bool = False,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.
 
-        :param n:
-        :param progress_bar:
-        :param stop_on_exception:
-        :param cache:
-        :param check_api_keys:
-        :param
-        :param
-        :param
-        :param
+        :param n: How many times to run each interview
+        :param progress_bar: Whether to show a progress bar
+        :param stop_on_exception: Stops the job if an exception is raised
+        :param cache: A Cache object to store results
+        :param check_api_keys: Raises an error if API keys are invalid
+        :param verbose: Prints extra messages
+        :param remote_cache_description: Specifies a description for this group of entries in the remote cache
+        :param remote_inference_description: Specifies a description for the remote inference job
+        :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
+        :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
+        :param disable_remote_inference: If True, the job will not use remote inference
         """
         from edsl.coop.coop import Coop
 
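A hedged usage sketch of `run()` with the new parameters; `job` stands for any `Jobs` instance, and the description text is illustrative:

```python
results = job.run(
    n=2,                                          # run each interview twice
    remote_inference_description="Pilot survey",  # shown on Coop for remote jobs
    remote_inference_results_visibility="private",
    disable_remote_cache=True,                    # local jobs only: skip the remote cache
)
```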
@@ -852,9 +975,54 @@ class Jobs(Base):
 
         self.verbose = verbose
 
+        if (
+            not self.user_has_all_model_keys()
+            and not self.user_has_ep_api_key()
+            and self.needs_external_llms()
+        ):
+            import secrets
+            from dotenv import load_dotenv
+            from edsl import CONFIG
+            from edsl.coop.coop import Coop
+            from edsl.utilities.utilities import write_api_key_to_env
+
+            missing_api_keys = self.get_missing_api_keys()
+
+            edsl_auth_token = secrets.token_urlsafe(16)
+
+            print("You're missing some of the API keys needed to run this job:")
+            for api_key in missing_api_keys:
+                print(f" 🔑 {api_key}")
+            print(
+                "\nYou can either add the missing keys to your .env file, or use remote inference."
+            )
+            print("Remote inference allows you to run jobs on our server.")
+            print("\n🚀 To use remote inference, sign up at the following link:")
+
+            coop = Coop()
+            coop._display_login_url(edsl_auth_token=edsl_auth_token)
+
+            print(
+                "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
+            )
+
+            api_key = coop._poll_for_api_key(edsl_auth_token)
+
+            if api_key is None:
+                print("\nTimed out waiting for login. Please try again.")
+                return
+
+            write_api_key_to_env(api_key)
+            print("✨ API key retrieved and written to .env file.\n")
+
+            # Retrieve API key so we can continue running the job
+            load_dotenv()
+
         if remote_inference := self.use_remote_inference(disable_remote_inference):
             remote_job_creation_data = self.create_remote_inference_job(
-                iterations=n,
+                iterations=n,
+                remote_inference_description=remote_inference_description,
+                remote_inference_results_visibility=remote_inference_results_visibility,
             )
             results = self.poll_remote_inference_job(remote_job_creation_data)
             if results is None:
@@ -874,7 +1042,7 @@ class Jobs(Base):
 
             cache = Cache()
 
-        remote_cache = self.use_remote_cache()
+        remote_cache = self.use_remote_cache(disable_remote_cache)
         with RemoteCacheSync(
             coop=Coop(),
             cache=cache,
@@ -895,17 +1063,84 @@ class Jobs(Base):
             results.cache = cache.new_entries_cache()
         return results
 
+    async def create_and_poll_remote_job(
+        self,
+        iterations: int = 1,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[
+            Literal["private", "public", "unlisted"]
+        ] = "unlisted",
+    ) -> Union[Results, None]:
+        """
+        Creates and polls a remote inference job asynchronously.
+        Reuses existing synchronous methods but runs them in an async context.
+
+        :param iterations: Number of times to run each interview
+        :param remote_inference_description: Optional description for the remote job
+        :param remote_inference_results_visibility: Visibility setting for results
+        :return: Results object if successful, None if job fails or is cancelled
+        """
+        import asyncio
+        from functools import partial
+
+        # Create job using existing method
+        loop = asyncio.get_event_loop()
+        remote_job_creation_data = await loop.run_in_executor(
+            None,
+            partial(
+                self.create_remote_inference_job,
+                iterations=iterations,
+                remote_inference_description=remote_inference_description,
+                remote_inference_results_visibility=remote_inference_results_visibility,
+            ),
+        )
+
+        # Poll using existing method but with async sleep
+        return await loop.run_in_executor(
+            None, partial(self.poll_remote_inference_job, remote_job_creation_data)
+        )
+
+    async def run_async(
+        self,
+        cache=None,
+        n=1,
+        disable_remote_inference: bool = False,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[
+            Literal["private", "public", "unlisted"]
+        ] = "unlisted",
+        **kwargs,
+    ):
+        """Run the job asynchronously, either locally or remotely.
+
+        :param cache: Cache object or boolean
+        :param n: Number of iterations
+        :param disable_remote_inference: If True, forces local execution
+        :param remote_inference_description: Description for remote jobs
+        :param remote_inference_results_visibility: Visibility setting for remote results
+        :param kwargs: Additional arguments passed to local execution
+        :return: Results object
+        """
+        # Check if we should use remote inference
+        if remote_inference := self.use_remote_inference(disable_remote_inference):
+            results = await self.create_and_poll_remote_job(
+                iterations=n,
+                remote_inference_description=remote_inference_description,
+                remote_inference_results_visibility=remote_inference_results_visibility,
+            )
+            if results is None:
+                self._output("Job failed.")
+            return results
+
+        # If not using remote inference, run locally with async
+        return await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
+
     def _run_local(self, *args, **kwargs):
         """Run the job locally."""
 
         results = JobsRunnerAsyncio(self).run(*args, **kwargs)
         return results
 
-    async def run_async(self, cache=None, n=1, **kwargs):
-        """Run asynchronously."""
-        results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
-        return results
-
     def all_question_parameters(self):
         """Return all the fields in the questions in the survey.
         >>> from edsl.jobs import Jobs
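The old fire-and-forget `run_async` is replaced by a version that mirrors `run()`'s remote/local branching. A hedged usage sketch; `Jobs.example()` is assumed to exist, as for other edsl `Base` subclasses:

```python
import asyncio

from edsl.jobs import Jobs

async def main():
    job = Jobs.example()  # assumed example constructor
    # Forcing local execution; drop the flag to allow remote inference.
    results = await job.run_async(n=1, disable_remote_inference=True)
    print(results)

asyncio.run(main())
```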
edsl/jobs/interviews/Interview.py
CHANGED
@@ -110,9 +110,9 @@ class Interview:
         self.debug = debug
         self.iteration = iteration
         self.cache = cache
-        self.answers: dict[
-
-        )  # will get filled in as interview progresses
+        self.answers: dict[
+            str, str
+        ] = Answers()  # will get filled in as interview progresses
         self.sidecar_model = sidecar_model
 
         # Trackers
@@ -143,9 +143,9 @@ class Interview:
         The keys are the question names; the values are the lists of status log changes for each task.
         """
         for task_creator in self.task_creators.values():
-            self._task_status_log_dict[
-                task_creator.
-
+            self._task_status_log_dict[
+                task_creator.question.question_name
+            ] = task_creator.status_log
         return self._task_status_log_dict
 
     @property
@@ -178,7 +178,7 @@ class Interview:
         if include_exceptions:
             d["exceptions"] = self.exceptions.to_dict()
         return d
-
+
     @classmethod
     def from_dict(cls, d: dict[str, Any]) -> "Interview":
         """Return an Interview instance from a dictionary."""
@@ -187,13 +187,23 @@ class Interview:
         scenario = Scenario.from_dict(d["scenario"])
         model = LanguageModel.from_dict(d["model"])
         iteration = d["iteration"]
-
+        interview = cls(
+            agent=agent,
+            survey=survey,
+            scenario=scenario,
+            model=model,
+            iteration=iteration,
+        )
+        if "exceptions" in d:
+            exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
+            interview.exceptions = exceptions
+        return interview
 
     def __hash__(self) -> int:
         from edsl.utilities.utilities import dict_hash
 
         return dict_hash(self._to_dict(include_exceptions=False))
-
+
     def __eq__(self, other: "Interview") -> bool:
         """
         >>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i._to_dict(); i2 = Interview.from_dict(d); i == i2
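With this change, recorded exceptions survive a serialization round trip. A short sketch based on the doctest shown just above; the explicit `include_exceptions=True` argument is an assumption drawn from the `_to_dict` signature visible in this same diff:

```python
from edsl.jobs.interviews.Interview import Interview

i = Interview.example()
d = i._to_dict(include_exceptions=True)  # include the "exceptions" key in the payload
i2 = Interview.from_dict(d)              # from_dict now restores interview.exceptions
assert i == i2                           # hashing excludes exceptions, so equality holds
```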
@@ -476,11 +486,11 @@ class Interview:
         """
         current_question_index: int = self.to_index[current_question.question_name]
 
-        next_question: Union[
-
-
-
-
+        next_question: Union[
+            int, EndOfSurvey
+        ] = self.survey.rule_collection.next_question(
+            q_now=current_question_index,
+            answers=self.answers | self.scenario | self.agent["traits"],
         )
 
         next_question_index = next_question.next_q
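The `answers=` argument merges three mappings with the dict union operator, where later operands win on key collisions. A tiny illustration of that precedence with hypothetical values:

```python
answers = {"q1": "yes"}
scenario = {"topic": "travel"}
traits = {"age": 30, "topic": "from traits"}  # collides with scenario["topic"]

# Rightmost operand wins, so traits override scenario on "topic".
merged = answers | scenario | traits
print(merged)  # {'q1': 'yes', 'topic': 'from traits', 'age': 30}
```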
edsl/jobs/interviews/InterviewExceptionCollection.py
CHANGED
@@ -33,7 +33,7 @@ class InterviewExceptionCollection(UserDict):
         """Return the collection of exceptions as a dictionary."""
         newdata = {k: [e.to_dict() for e in v] for k, v in self.data.items()}
         return newdata
-
+
     @classmethod
     def from_dict(cls, data: dict) -> "InterviewExceptionCollection":
         """Create an InterviewExceptionCollection from a dictionary."""
|