edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +60 -31
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +18 -9
- edsl/agents/AgentList.py +59 -8
- edsl/agents/Invigilator.py +18 -7
- edsl/agents/InvigilatorBase.py +0 -19
- edsl/agents/PromptConstructor.py +5 -4
- edsl/config.py +8 -0
- edsl/coop/coop.py +74 -7
- edsl/data/Cache.py +27 -2
- edsl/data/CacheEntry.py +8 -3
- edsl/data/RemoteCacheSync.py +0 -19
- edsl/enums.py +2 -0
- edsl/inference_services/GoogleService.py +7 -15
- edsl/inference_services/PerplexityService.py +163 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +88 -548
- edsl/jobs/JobsChecks.py +147 -0
- edsl/jobs/JobsPrompts.py +268 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
- edsl/jobs/interviews/Interview.py +11 -11
- edsl/jobs/runners/JobsRunnerAsyncio.py +140 -35
- edsl/jobs/runners/JobsRunnerStatus.py +0 -2
- edsl/jobs/tasks/TaskHistory.py +15 -16
- edsl/language_models/LanguageModel.py +44 -84
- edsl/language_models/ModelList.py +47 -1
- edsl/language_models/registry.py +57 -4
- edsl/prompts/Prompt.py +8 -3
- edsl/questions/QuestionBase.py +20 -16
- edsl/questions/QuestionExtract.py +3 -4
- edsl/questions/question_registry.py +36 -6
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +146 -15
- edsl/results/DatasetExportMixin.py +231 -217
- edsl/results/DatasetTree.py +134 -4
- edsl/results/Result.py +18 -9
- edsl/results/Results.py +145 -51
- edsl/results/TableDisplay.py +198 -0
- edsl/results/table_display.css +78 -0
- edsl/scenarios/FileStore.py +187 -13
- edsl/scenarios/Scenario.py +61 -4
- edsl/scenarios/ScenarioJoin.py +127 -0
- edsl/scenarios/ScenarioList.py +237 -62
- edsl/surveys/Survey.py +16 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
- edsl/surveys/instructions/Instruction.py +12 -0
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +18 -9
- edsl/utilities/utilities.py +15 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/METADATA +2 -1
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/RECORD +53 -45
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/WHEEL +0 -0
edsl/jobs/JobsChecks.py
ADDED
@@ -0,0 +1,147 @@
+import os
+from edsl.exceptions import MissingAPIKeyError
+
+
+class JobsChecks:
+    def __init__(self, jobs):
+        """ """
+        self.jobs = jobs
+
+    def check_api_keys(self) -> None:
+        from edsl import Model
+
+        for model in self.jobs.models + [Model()]:
+            if not model.has_valid_api_key():
+                raise MissingAPIKeyError(
+                    model_name=str(model.model),
+                    inference_service=model._inference_service_,
+                )
+
+    def get_missing_api_keys(self) -> set:
+        """
+        Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
+        """
+        missing_api_keys = set()
+
+        from edsl import Model
+        from edsl.enums import service_to_api_keyname
+
+        for model in self.jobs.models + [Model()]:
+            if not model.has_valid_api_key():
+                key_name = service_to_api_keyname.get(
+                    model._inference_service_, "NOT FOUND"
+                )
+                missing_api_keys.add(key_name)
+
+        return missing_api_keys
+
+    def user_has_ep_api_key(self) -> bool:
+        """
+        Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
+
+        Otherwise, returns False.
+        """
+
+        coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
+
+        if coop_api_key is not None:
+            return True
+        else:
+            return False
+
+    def user_has_all_model_keys(self):
+        """
+        Returns True if the user has all model keys required to run their job.
+
+        Otherwise, returns False.
+        """
+
+        try:
+            self.check_api_keys()
+            return True
+        except MissingAPIKeyError:
+            return False
+        except Exception:
+            raise
+
+    def needs_external_llms(self) -> bool:
+        """
+        Returns True if the job needs external LLMs to run.
+
+        Otherwise, returns False.
+        """
+        # These cases are necessary to skip the API key check during doctests
+
+        # Accounts for Results.example()
+        all_agents_answer_questions_directly = len(self.jobs.agents) > 0 and all(
+            [hasattr(a, "answer_question_directly") for a in self.jobs.agents]
+        )
+
+        # Accounts for InterviewExceptionEntry.example()
+        only_model_is_test = set([m.model for m in self.jobs.models]) == set(["test"])
+
+        # Accounts for Survey.__call__
+        all_questions_are_functional = set(
+            [q.question_type for q in self.jobs.survey.questions]
+        ) == set(["functional"])
+
+        if (
+            all_agents_answer_questions_directly
+            or only_model_is_test
+            or all_questions_are_functional
+        ):
+            return False
+        else:
+            return True
+
+    def needs_key_process(self):
+        return (
+            not self.user_has_all_model_keys()
+            and not self.user_has_ep_api_key()
+            and self.needs_external_llms()
+        )
+
+    def key_process(self):
+        import secrets
+        from dotenv import load_dotenv
+        from edsl import CONFIG
+        from edsl.coop.coop import Coop
+        from edsl.utilities.utilities import write_api_key_to_env
+
+        missing_api_keys = self.get_missing_api_keys()
+
+        edsl_auth_token = secrets.token_urlsafe(16)
+
+        print("You're missing some of the API keys needed to run this job:")
+        for api_key in missing_api_keys:
+            print(f" 🔑 {api_key}")
+        print(
+            "\nYou can either add the missing keys to your .env file, or use remote inference."
+        )
+        print("Remote inference allows you to run jobs on our server.")
+        print("\n🚀 To use remote inference, sign up at the following link:")
+
+        coop = Coop()
+        coop._display_login_url(edsl_auth_token=edsl_auth_token)
+
+        print(
+            "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
+        )
+
+        api_key = coop._poll_for_api_key(edsl_auth_token)
+
+        if api_key is None:
+            print("\nTimed out waiting for login. Please try again.")
+            return
+
+        write_api_key_to_env(api_key)
+        print("✨ API key retrieved and written to .env file.\n")
+
+        # Retrieve API key so we can continue running the job
+        load_dotenv()
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()
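The checks above gate a job run: purely local setups (test models, agents that answer directly, functional questions) skip key validation entirely, and the interactive key_process flow only starts when provider keys are missing and no EXPECTED_PARROT_API_KEY is set. A minimal sketch of that gating, assuming an edsl Jobs instance; the ensure_keys helper is hypothetical and not part of this diff:

# Hypothetical wiring of the checks above; illustrative only.
from edsl import Jobs
from edsl.jobs.JobsChecks import JobsChecks


def ensure_keys(jobs: Jobs) -> None:
    checks = JobsChecks(jobs)
    if not checks.needs_external_llms():
        return  # test models / direct-answering agents need no API keys
    if checks.user_has_all_model_keys():
        return  # every required provider key is already in the environment
    if checks.needs_key_process():
        checks.key_process()  # interactive remote-inference sign-up flow


ensure_keys(Jobs.example())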
edsl/jobs/JobsPrompts.py
ADDED
@@ -0,0 +1,268 @@
+from typing import List, TYPE_CHECKING
+
+from edsl.results.Dataset import Dataset
+
+if TYPE_CHECKING:
+    from edsl.jobs import Jobs
+
+# from edsl.jobs.interviews.Interview import Interview
+# from edsl.results.Dataset import Dataset
+# from edsl.agents.AgentList import AgentList
+# from edsl.scenarios.ScenarioList import ScenarioList
+# from edsl.surveys.Survey import Survey
+
+
+class JobsPrompts:
+    def __init__(self, jobs: "Jobs"):
+        self.interviews = jobs.interviews()
+        self.agents = jobs.agents
+        self.scenarios = jobs.scenarios
+        self.survey = jobs.survey
+        self._price_lookup = None
+
+    @property
+    def price_lookup(self):
+        if self._price_lookup is None:
+            from edsl import Coop
+
+            c = Coop()
+            self._price_lookup = c.fetch_prices()
+        return self._price_lookup
+
+    def prompts(self) -> "Dataset":
+        """Return a Dataset of prompts that will be used.
+
+        >>> from edsl.jobs import Jobs
+        >>> Jobs.example().prompts()
+        Dataset(...)
+        """
+        interviews = self.interviews
+        interview_indices = []
+        question_names = []
+        user_prompts = []
+        system_prompts = []
+        scenario_indices = []
+        agent_indices = []
+        models = []
+        costs = []
+
+        for interview_index, interview in enumerate(interviews):
+            invigilators = [
+                interview._get_invigilator(question)
+                for question in self.survey.questions
+            ]
+            for _, invigilator in enumerate(invigilators):
+                prompts = invigilator.get_prompts()
+                user_prompt = prompts["user_prompt"]
+                system_prompt = prompts["system_prompt"]
+                user_prompts.append(user_prompt)
+                system_prompts.append(system_prompt)
+                agent_index = self.agents.index(invigilator.agent)
+                agent_indices.append(agent_index)
+                interview_indices.append(interview_index)
+                scenario_index = self.scenarios.index(invigilator.scenario)
+                scenario_indices.append(scenario_index)
+                models.append(invigilator.model.model)
+                question_names.append(invigilator.question.question_name)
+
+                prompt_cost = self.estimate_prompt_cost(
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    price_lookup=self.price_lookup,
+                    inference_service=invigilator.model._inference_service_,
+                    model=invigilator.model.model,
+                )
+                costs.append(prompt_cost["cost_usd"])
+
+        d = Dataset(
+            [
+                {"user_prompt": user_prompts},
+                {"system_prompt": system_prompts},
+                {"interview_index": interview_indices},
+                {"question_name": question_names},
+                {"scenario_index": scenario_indices},
+                {"agent_index": agent_indices},
+                {"model": models},
+                {"estimated_cost": costs},
+            ]
+        )
+        return d
+
+    @staticmethod
+    def estimate_prompt_cost(
+        system_prompt: str,
+        user_prompt: str,
+        price_lookup: dict,
+        inference_service: str,
+        model: str,
+    ) -> dict:
+        """Estimates the cost of a prompt. Takes piping into account."""
+        import math
+
+        def get_piping_multiplier(prompt: str):
+            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
+
+            if "{{" in prompt and "}}" in prompt:
+                return 2
+            return 1
+
+        # Look up prices per token
+        key = (inference_service, model)
+
+        try:
+            relevant_prices = price_lookup[key]
+
+            service_input_token_price = float(
+                relevant_prices["input"]["service_stated_token_price"]
+            )
+            service_input_token_qty = float(
+                relevant_prices["input"]["service_stated_token_qty"]
+            )
+            input_price_per_token = service_input_token_price / service_input_token_qty
+
+            service_output_token_price = float(
+                relevant_prices["output"]["service_stated_token_price"]
+            )
+            service_output_token_qty = float(
+                relevant_prices["output"]["service_stated_token_qty"]
+            )
+            output_price_per_token = (
+                service_output_token_price / service_output_token_qty
+            )
+
+        except KeyError:
+            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
+            # Use a sensible default
+
+            import warnings
+
+            warnings.warn(
+                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
+            )
+            input_price_per_token = 0.00000015  # $0.15 / 1M tokens
+            output_price_per_token = 0.00000060  # $0.60 / 1M tokens
+
+        # Compute the number of characters (double if the question involves piping)
+        user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
+            str(user_prompt)
+        )
+        system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
+            str(system_prompt)
+        )
+
+        # Convert into tokens (1 token approx. equals 4 characters)
+        input_tokens = (user_prompt_chars + system_prompt_chars) // 4
+
+        output_tokens = math.ceil(0.75 * input_tokens)
+
+        cost = (
+            input_tokens * input_price_per_token
+            + output_tokens * output_price_per_token
+        )
+
+        return {
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
+            "cost_usd": cost,
+        }
+
+    def estimate_job_cost_from_external_prices(
+        self, price_lookup: dict, iterations: int = 1
+    ) -> dict:
+        """
+        Estimates the cost of a job according to the following assumptions:
+
+        - 1 token = 4 characters.
+        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
+
+        price_lookup is an external pricing dictionary.
+        """
+
+        import pandas as pd
+
+        interviews = self.interviews
+        data = []
+        for interview in interviews:
+            invigilators = [
+                interview._get_invigilator(question)
+                for question in self.survey.questions
+            ]
+            for invigilator in invigilators:
+                prompts = invigilator.get_prompts()
+
+                # By this point, agent and scenario data has already been added to the prompts
+                user_prompt = prompts["user_prompt"]
+                system_prompt = prompts["system_prompt"]
+                inference_service = invigilator.model._inference_service_
+                model = invigilator.model.model
+
+                prompt_cost = self.estimate_prompt_cost(
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    price_lookup=price_lookup,
+                    inference_service=inference_service,
+                    model=model,
+                )
+
+                data.append(
+                    {
+                        "user_prompt": user_prompt,
+                        "system_prompt": system_prompt,
+                        "estimated_input_tokens": prompt_cost["input_tokens"],
+                        "estimated_output_tokens": prompt_cost["output_tokens"],
+                        "estimated_cost_usd": prompt_cost["cost_usd"],
+                        "inference_service": inference_service,
+                        "model": model,
+                    }
+                )
+
+        df = pd.DataFrame.from_records(data)
+
+        df = (
+            df.groupby(["inference_service", "model"])
+            .agg(
+                {
+                    "estimated_cost_usd": "sum",
+                    "estimated_input_tokens": "sum",
+                    "estimated_output_tokens": "sum",
+                }
+            )
+            .reset_index()
+        )
+        df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
+        df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
+        df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
+
+        estimated_costs_by_model = df.to_dict("records")
+
+        estimated_total_cost = sum(
+            model["estimated_cost_usd"] for model in estimated_costs_by_model
+        )
+        estimated_total_input_tokens = sum(
+            model["estimated_input_tokens"] for model in estimated_costs_by_model
+        )
+        estimated_total_output_tokens = sum(
+            model["estimated_output_tokens"] for model in estimated_costs_by_model
+        )
+
+        output = {
+            "estimated_total_cost_usd": estimated_total_cost,
+            "estimated_total_input_tokens": estimated_total_input_tokens,
+            "estimated_total_output_tokens": estimated_total_output_tokens,
+            "model_costs": estimated_costs_by_model,
+        }
+
+        return output
+
+    def estimate_job_cost(self, iterations: int = 1) -> dict:
+        """
+        Estimates the cost of a job according to the following assumptions:
+
+        - 1 token = 4 characters.
+        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
+
+        Fetches prices from Coop.
+        """
+        return self.estimate_job_cost_from_external_prices(
+            price_lookup=self.price_lookup, iterations=iterations
+        )
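To make the estimation assumptions concrete: a 360-character user prompt plus a 40-character system prompt, with no Jinja piping, is counted as (360 + 40) // 4 = 100 input tokens and ceil(0.75 * 100) = 75 output tokens; at the fallback prices of $0.15 and $0.60 per million tokens that is 100 * 0.00000015 + 75 * 0.00000060 = $0.00006. A small sketch of the same call, passing an empty price_lookup so the fallback path is exercised; the service and model names are placeholders, not part of the diff:

# Illustrative call to the static estimator above; the empty price_lookup
# triggers the documented fallback prices ($0.15 / $0.60 per 1M tokens).
from edsl.jobs.JobsPrompts import JobsPrompts

estimate = JobsPrompts.estimate_prompt_cost(
    system_prompt="You are a survey respondent.",
    user_prompt="How satisfied are you with your commute, on a scale from 1 to 10?",
    price_lookup={},               # no price data -> warning + default prices
    inference_service="openai",    # placeholder service/model pair
    model="gpt-4o-mini",
)
print(estimate["input_tokens"], estimate["output_tokens"], estimate["cost_usd"])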
edsl/jobs/JobsRemoteInferenceHandler.py
ADDED
@@ -0,0 +1,239 @@
+from typing import Optional, Union, Literal
+import requests
+import sys
+from edsl.exceptions.coop import CoopServerResponseError
+
+# from edsl.enums import VisibilityType
+from edsl.results import Results
+
+
+class JobsRemoteInferenceHandler:
+    def __init__(self, jobs, verbose=False, poll_interval=3):
+        """
+        >>> from edsl.jobs import Jobs
+        >>> jh = JobsRemoteInferenceHandler(Jobs.example(), verbose=True)
+        >>> jh.use_remote_inference(True)
+        False
+        >>> jh._poll_remote_inference_job({'uuid':1234}, testing_simulated_response={"status": "failed"}) # doctest: +NORMALIZE_WHITESPACE
+        Job failed.
+        ...
+        >>> jh._poll_remote_inference_job({'uuid':1234}, testing_simulated_response={"status": "completed"}) # doctest: +NORMALIZE_WHITESPACE
+        Job completed and Results stored on Coop: None.
+        Results(...)
+        """
+        self.jobs = jobs
+        self.verbose = verbose
+        self.poll_interval = poll_interval
+
+        self._remote_job_creation_data = None
+        self._job_uuid = None
+
+    @property
+    def remote_job_creation_data(self):
+        return self._remote_job_creation_data
+
+    @property
+    def job_uuid(self):
+        return self._job_uuid
+
+    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
+        if disable_remote_inference:
+            return False
+        if not disable_remote_inference:
+            try:
+                from edsl import Coop
+
+                user_edsl_settings = Coop().edsl_settings
+                return user_edsl_settings.get("remote_inference", False)
+            except requests.ConnectionError:
+                pass
+            except CoopServerResponseError as e:
+                pass
+
+        return False
+
+    def create_remote_inference_job(
+        self,
+        iterations: int = 1,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional["VisibilityType"] = "unlisted",
+        verbose=False,
+    ):
+        """ """
+        from edsl.config import CONFIG
+        from edsl.coop.coop import Coop
+        from rich import print as rich_print
+
+        coop = Coop()
+        print("Remote inference activated. Sending job to server...")
+        remote_job_creation_data = coop.remote_inference_create(
+            self.jobs,
+            description=remote_inference_description,
+            status="queued",
+            iterations=iterations,
+            initial_results_visibility=remote_inference_results_visibility,
+        )
+        job_uuid = remote_job_creation_data.get("uuid")
+        print(f"Job sent to server. (Job uuid={job_uuid}).")
+
+        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
+        progress_bar_url = f"{expected_parrot_url}/home/remote-job-progress/{job_uuid}"
+
+        rich_print(
+            f"View job progress here: [#38bdf8][link={progress_bar_url}]{progress_bar_url}[/link][/#38bdf8]"
+        )
+
+        self._remote_job_creation_data = remote_job_creation_data
+        self._job_uuid = job_uuid
+        # return remote_job_creation_data
+
+    @staticmethod
+    def check_status(job_uuid):
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        return coop.remote_inference_get(job_uuid)
+
+    def poll_remote_inference_job(self):
+        return self._poll_remote_inference_job(
+            self.remote_job_creation_data, verbose=self.verbose
+        )
+
+    def _poll_remote_inference_job(
+        self,
+        remote_job_creation_data: dict,
+        verbose=False,
+        poll_interval: Optional[float] = None,
+        testing_simulated_response: Optional[dict] = None,
+    ) -> Union[Results, None]:
+        import time
+        from datetime import datetime
+        from edsl.config import CONFIG
+        from edsl.coop.coop import Coop
+
+        if poll_interval is None:
+            poll_interval = self.poll_interval
+
+        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
+
+        job_uuid = remote_job_creation_data.get("uuid")
+        coop = Coop()
+
+        if testing_simulated_response is not None:
+            remote_job_data_fetcher = lambda job_uuid: testing_simulated_response
+            object_fetcher = (
+                lambda results_uuid, expected_object_type: Results.example()
+            )
+        else:
+            remote_job_data_fetcher = coop.remote_inference_get
+            object_fetcher = coop.get
+
+        job_in_queue = True
+        while job_in_queue:
+            remote_job_data = remote_job_data_fetcher(job_uuid)
+            status = remote_job_data.get("status")
+            if status == "cancelled":
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job cancelled by the user.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
+                return None
+            elif status == "failed":
+                print("\r" + " " * 80 + "\r", end="")
+                # write to stderr
+                latest_error_report_url = remote_job_data.get("latest_error_report_url")
+                if latest_error_report_url:
+                    print("Job failed.")
+                    print(
+                        f"Your job generated exceptions. Details on these exceptions can be found in the following report: {latest_error_report_url}"
+                    )
+                    print(
+                        f"Need support? Post a message at the Expected Parrot Discord channel (https://discord.com/invite/mxAYkjfy9m) or send an email to info@expectedparrot.com."
+                    )
+                else:
+                    print("Job failed.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
+                return None
+            elif status == "completed":
+                results_uuid = remote_job_data.get("results_uuid")
+                results_url = remote_job_data.get("results_url")
+                results = object_fetcher(results_uuid, expected_object_type="results")
+                print("\r" + " " * 80 + "\r", end="")
+                print(f"Job completed and Results stored on Coop: {results_url}.")
+                return results
+            else:
+                duration = poll_interval
+                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
+                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
+                start_time = time.time()
+                i = 0
+                while time.time() - start_time < duration:
+                    print(
+                        f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                        end="",
+                        flush=True,
+                    )
+                    time.sleep(0.1)
+                    i += 1
+
+    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
+        if disable_remote_inference:
+            return False
+        if not disable_remote_inference:
+            try:
+                from edsl import Coop
+
+                user_edsl_settings = Coop().edsl_settings
+                return user_edsl_settings.get("remote_inference", False)
+            except requests.ConnectionError:
+                pass
+            except CoopServerResponseError as e:
+                pass
+
+        return False
+
+    async def create_and_poll_remote_job(
+        self,
+        iterations: int = 1,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[
+            Literal["private", "public", "unlisted"]
+        ] = "unlisted",
+    ) -> Union[Results, None]:
+        """
+        Creates and polls a remote inference job asynchronously.
+        Reuses existing synchronous methods but runs them in an async context.
+
+        :param iterations: Number of times to run each interview
+        :param remote_inference_description: Optional description for the remote job
+        :param remote_inference_results_visibility: Visibility setting for results
+        :return: Results object if successful, None if job fails or is cancelled
+        """
+        import asyncio
+        from functools import partial
+
+        # Create job using existing method
+        loop = asyncio.get_event_loop()
+        remote_job_creation_data = await loop.run_in_executor(
+            None,
+            partial(
+                self.create_remote_inference_job,
+                iterations=iterations,
+                remote_inference_description=remote_inference_description,
+                remote_inference_results_visibility=remote_inference_results_visibility,
+            ),
+        )
+
+        # Poll using existing method but with async sleep
+        return await loop.run_in_executor(
+            None, partial(self.poll_remote_inference_job, remote_job_creation_data)
+        )
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
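The handler wraps the create-then-poll lifecycle for remote jobs. A rough usage sketch of that flow, not part of the diff itself; it assumes remote inference is enabled on the account and valid Coop credentials, and uses Jobs.example() as a stand-in for a real job:

# Rough usage sketch for the handler above; illustrative only.
from edsl.jobs import Jobs
from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

handler = JobsRemoteInferenceHandler(Jobs.example(), verbose=True, poll_interval=3)
if handler.use_remote_inference(disable_remote_inference=False):
    handler.create_remote_inference_job(
        iterations=1,
        remote_inference_description="example run",   # placeholder description
        remote_inference_results_visibility="unlisted",
    )
    results = handler.poll_remote_inference_job()     # Results on success, else None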
edsl/jobs/interviews/Interview.py
@@ -110,9 +110,9 @@ class Interview:
         self.debug = debug
         self.iteration = iteration
         self.cache = cache
-        self.answers: dict[
-
-        ) # will get filled in as interview progresses
+        self.answers: dict[
+            str, str
+        ] = Answers()  # will get filled in as interview progresses
         self.sidecar_model = sidecar_model

         # Trackers
@@ -143,9 +143,9 @@ class Interview:
         The keys are the question names; the values are the lists of status log changes for each task.
         """
         for task_creator in self.task_creators.values():
-            self._task_status_log_dict[
-                task_creator.
-
+            self._task_status_log_dict[
+                task_creator.question.question_name
+            ] = task_creator.status_log
         return self._task_status_log_dict

     @property
@@ -486,11 +486,11 @@ class Interview:
         """
         current_question_index: int = self.to_index[current_question.question_name]

-        next_question: Union[
-
-
-
-
+        next_question: Union[
+            int, EndOfSurvey
+        ] = self.survey.rule_collection.next_question(
+            q_now=current_question_index,
+            answers=self.answers | self.scenario | self.agent["traits"],
         )

         next_question_index = next_question.next_q