edsl 0.1.37.dev5__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- edsl/Base.py +63 -34
- edsl/BaseDiff.py +7 -7
- edsl/__init__.py +2 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +23 -11
- edsl/agents/AgentList.py +86 -23
- edsl/agents/Invigilator.py +18 -7
- edsl/agents/InvigilatorBase.py +0 -19
- edsl/agents/PromptConstructor.py +5 -4
- edsl/auto/SurveyCreatorPipeline.py +1 -1
- edsl/auto/utilities.py +1 -1
- edsl/base/Base.py +3 -13
- edsl/config.py +8 -0
- edsl/coop/coop.py +89 -19
- edsl/data/Cache.py +45 -17
- edsl/data/CacheEntry.py +8 -3
- edsl/data/RemoteCacheSync.py +0 -19
- edsl/enums.py +2 -0
- edsl/exceptions/agents.py +4 -0
- edsl/exceptions/cache.py +5 -0
- edsl/inference_services/GoogleService.py +7 -15
- edsl/inference_services/PerplexityService.py +163 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +110 -559
- edsl/jobs/JobsChecks.py +147 -0
- edsl/jobs/JobsPrompts.py +268 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
- edsl/jobs/buckets/TokenBucket.py +3 -0
- edsl/jobs/interviews/Interview.py +7 -7
- edsl/jobs/runners/JobsRunnerAsyncio.py +156 -28
- edsl/jobs/runners/JobsRunnerStatus.py +194 -196
- edsl/jobs/tasks/TaskHistory.py +27 -19
- edsl/language_models/LanguageModel.py +52 -90
- edsl/language_models/ModelList.py +67 -14
- edsl/language_models/registry.py +57 -4
- edsl/notebooks/Notebook.py +7 -8
- edsl/prompts/Prompt.py +8 -3
- edsl/questions/QuestionBase.py +38 -30
- edsl/questions/QuestionBaseGenMixin.py +1 -1
- edsl/questions/QuestionBasePromptsMixin.py +0 -17
- edsl/questions/QuestionExtract.py +3 -4
- edsl/questions/QuestionFunctional.py +10 -3
- edsl/questions/derived/QuestionTopK.py +2 -0
- edsl/questions/question_registry.py +36 -6
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +146 -15
- edsl/results/DatasetExportMixin.py +231 -217
- edsl/results/DatasetTree.py +134 -4
- edsl/results/Result.py +31 -16
- edsl/results/Results.py +159 -65
- edsl/results/TableDisplay.py +198 -0
- edsl/results/table_display.css +78 -0
- edsl/scenarios/FileStore.py +187 -13
- edsl/scenarios/Scenario.py +73 -18
- edsl/scenarios/ScenarioJoin.py +127 -0
- edsl/scenarios/ScenarioList.py +251 -76
- edsl/surveys/MemoryPlan.py +1 -1
- edsl/surveys/Rule.py +1 -5
- edsl/surveys/RuleCollection.py +1 -1
- edsl/surveys/Survey.py +25 -19
- edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
- edsl/surveys/instructions/ChangeInstruction.py +9 -7
- edsl/surveys/instructions/Instruction.py +21 -7
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +18 -9
- edsl/{conjure → utilities}/naming_utilities.py +1 -1
- edsl/utilities/utilities.py +15 -0
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/METADATA +2 -1
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/RECORD +71 -77
- edsl/conjure/AgentConstructionMixin.py +0 -160
- edsl/conjure/Conjure.py +0 -62
- edsl/conjure/InputData.py +0 -659
- edsl/conjure/InputDataCSV.py +0 -48
- edsl/conjure/InputDataMixinQuestionStats.py +0 -182
- edsl/conjure/InputDataPyRead.py +0 -91
- edsl/conjure/InputDataSPSS.py +0 -8
- edsl/conjure/InputDataStata.py +0 -8
- edsl/conjure/QuestionOptionMixin.py +0 -76
- edsl/conjure/QuestionTypeMixin.py +0 -23
- edsl/conjure/RawQuestion.py +0 -65
- edsl/conjure/SurveyResponses.py +0 -7
- edsl/conjure/__init__.py +0 -9
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/utilities.py +0 -201
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/LICENSE +0 -0
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/WHEEL +0 -0
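The headline change below is the decomposition of edsl/jobs/Jobs.py (+110/−559) into the three new modules listed above: edsl/jobs/JobsPrompts.py (prompt rendering and cost estimation), edsl/jobs/JobsChecks.py (API-key checks and the login flow), and edsl/jobs/JobsRemoteInferenceHandler.py (remote job creation and polling); the legacy edsl/conjure package is removed outright. A minimal sketch of the new delegation pattern, reconstructed only from the Jobs.py hunks shown below (the new modules themselves are not shown in this diff, so exact signatures may differ):

    # Sketch only: names taken from the Jobs.py hunks below, not from the new modules.
    from edsl.jobs.Jobs import Jobs
    from edsl.jobs.JobsPrompts import JobsPrompts
    from edsl.jobs.JobsChecks import JobsChecks

    job = Jobs.example()
    prompts = JobsPrompts(job).prompts()   # what job.prompts() now returns
    checks = JobsChecks(job)
    if checks.needs_key_process():         # run() uses this to trigger the login flow
        checks.key_process()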
edsl/jobs/Jobs.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import warnings
 import requests
 from itertools import product
-from typing import Literal, Optional, Union, Sequence, Generator
+from typing import Literal, Optional, Union, Sequence, Generator, TYPE_CHECKING
 
 from edsl.Base import Base
 
@@ -11,11 +11,20 @@ from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
-from edsl.utilities.decorators import
+from edsl.utilities.decorators import remove_edsl_version
 
 from edsl.data.RemoteCacheSync import RemoteCacheSync
 from edsl.exceptions.coop import CoopServerResponseError
 
+if TYPE_CHECKING:
+    from edsl.agents.Agent import Agent
+    from edsl.agents.AgentList import AgentList
+    from edsl.language_models.LanguageModel import LanguageModel
+    from edsl.scenarios.Scenario import Scenario
+    from edsl.surveys.Survey import Survey
+    from edsl.results.Results import Results
+    from edsl.results.Dataset import Dataset
+
 
 class Jobs(Base):
     """
@@ -24,6 +33,8 @@ class Jobs(Base):
     The `JobsRunner` is chosen by the user, and is stored in the `jobs_runner_name` attribute.
     """
 
+    __documentation__ = "https://docs.expectedparrot.com/en/latest/jobs.html"
+
     def __init__(
         self,
         survey: "Survey",
@@ -86,8 +97,14 @@ class Jobs(Base):
     @scenarios.setter
     def scenarios(self, value):
         from edsl import ScenarioList
+        from edsl.results.Dataset import Dataset
 
         if value:
+            if isinstance(
+                value, Dataset
+            ):  # if the user passes in a Dataset, convert it to a ScenarioList
+                value = value.to_scenario_list()
+
             if not isinstance(value, ScenarioList):
                 self._scenarios = ScenarioList(value)
             else:
@@ -127,6 +144,13 @@ class Jobs(Base):
         - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
         - models: new models overwrite old models.
         """
+        from edsl.results.Dataset import Dataset
+
+        if isinstance(
+            args[0], Dataset
+        ):  # let the user user a Dataset as if it were a ScenarioList
+            args = args[0].to_scenario_list()
+
         passed_objects = self._turn_args_to_list(
             args
         )  # objects can also be passed comma-separated
@@ -151,73 +175,19 @@ class Jobs(Base):
         >>> Jobs.example().prompts()
         Dataset(...)
         """
-        from edsl import
-
-        c = Coop()
-        price_lookup = c.fetch_prices()
-
-        interviews = self.interviews()
-        # data = []
-        interview_indices = []
-        question_names = []
-        user_prompts = []
-        system_prompts = []
-        scenario_indices = []
-        agent_indices = []
-        models = []
-        costs = []
-        from edsl.results.Dataset import Dataset
+        from edsl.jobs.JobsPrompts import JobsPrompts
 
-
-
-                interview._get_invigilator(question)
-                for question in self.survey.questions
-            ]
-            for _, invigilator in enumerate(invigilators):
-                prompts = invigilator.get_prompts()
-                user_prompt = prompts["user_prompt"]
-                system_prompt = prompts["system_prompt"]
-                user_prompts.append(user_prompt)
-                system_prompts.append(system_prompt)
-                agent_index = self.agents.index(invigilator.agent)
-                agent_indices.append(agent_index)
-                interview_indices.append(interview_index)
-                scenario_index = self.scenarios.index(invigilator.scenario)
-                scenario_indices.append(scenario_index)
-                models.append(invigilator.model.model)
-                question_names.append(invigilator.question.question_name)
-
-                prompt_cost = self.estimate_prompt_cost(
-                    system_prompt=system_prompt,
-                    user_prompt=user_prompt,
-                    price_lookup=price_lookup,
-                    inference_service=invigilator.model._inference_service_,
-                    model=invigilator.model.model,
-                )
-                costs.append(prompt_cost["cost_usd"])
+        j = JobsPrompts(self)
+        return j.prompts()
 
-
-            [
-                {"user_prompt": user_prompts},
-                {"system_prompt": system_prompts},
-                {"interview_index": interview_indices},
-                {"question_name": question_names},
-                {"scenario_index": scenario_indices},
-                {"agent_index": agent_indices},
-                {"model": models},
-                {"estimated_cost": costs},
-            ]
-        )
-        return d
-
-    def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
+    def show_prompts(self, all=False) -> None:
         """Print the prompts."""
         if all:
-            self.prompts().to_scenario_list().
+            return self.prompts().to_scenario_list().table()
         else:
-
-                "user_prompt", "system_prompt"
-            )
+            return (
+                self.prompts().to_scenario_list().table("user_prompt", "system_prompt")
+            )
 
     @staticmethod
     def estimate_prompt_cost(
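Note the behavioral change in show_prompts() above: the max_rows parameter is gone, and the method now returns the rendered table rather than printing it (the "Print the prompts." docstring and the -> None annotation were kept, so the signature is now misleading). For example:

    job = Jobs.example()
    t = job.show_prompts()  # returns prompts().to_scenario_list().table("user_prompt", "system_prompt")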
@@ -226,201 +196,42 @@ class Jobs(Base):
         price_lookup: dict,
         inference_service: str,
         model: str,
-    ) -> dict:
-        """Estimates the cost of a prompt. Takes piping into account."""
-        import math
-
-        def get_piping_multiplier(prompt: str):
-            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
-
-            if "{{" in prompt and "}}" in prompt:
-                return 2
-            return 1
-
-        # Look up prices per token
-        key = (inference_service, model)
-
-        try:
-            relevant_prices = price_lookup[key]
-
-            service_input_token_price = float(
-                relevant_prices["input"]["service_stated_token_price"]
-            )
-            service_input_token_qty = float(
-                relevant_prices["input"]["service_stated_token_qty"]
-            )
-            input_price_per_token = service_input_token_price / service_input_token_qty
-
-            service_output_token_price = float(
-                relevant_prices["output"]["service_stated_token_price"]
-            )
-            service_output_token_qty = float(
-                relevant_prices["output"]["service_stated_token_qty"]
-            )
-            output_price_per_token = (
-                service_output_token_price / service_output_token_qty
-            )
-
-        except KeyError:
-            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
-            # Use a sensible default
-
-            import warnings
-
-            warnings.warn(
-                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
-            )
-            input_price_per_token = 0.00000015  # $0.15 / 1M tokens
-            output_price_per_token = 0.00000060  # $0.60 / 1M tokens
-
-        # Compute the number of characters (double if the question involves piping)
-        user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
-            str(user_prompt)
-        )
-        system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
-            str(system_prompt)
-        )
-
-        # Convert into tokens (1 token approx. equals 4 characters)
-        input_tokens = (user_prompt_chars + system_prompt_chars) // 4
-
-        output_tokens = math.ceil(0.75 * input_tokens)
-
-        cost = (
-            input_tokens * input_price_per_token
-            + output_tokens * output_price_per_token
-        )
-
-        return {
-            "input_tokens": input_tokens,
-            "output_tokens": output_tokens,
-            "cost_usd": cost,
-        }
-
-    def estimate_job_cost_from_external_prices(
-        self, price_lookup: dict, iterations: int = 1
     ) -> dict:
         """
-
-
-        - 1 token = 4 characters.
-        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
-
-        price_lookup is an external pricing dictionary.
+        Estimate the cost of running the prompts.
+        :param iterations: the number of iterations to run
         """
+        from edsl.jobs.JobsPrompts import JobsPrompts
 
-
-
-        interviews = self.interviews()
-        data = []
-        for interview in interviews:
-            invigilators = [
-                interview._get_invigilator(question)
-                for question in self.survey.questions
-            ]
-            for invigilator in invigilators:
-                prompts = invigilator.get_prompts()
-
-                # By this point, agent and scenario data has already been added to the prompts
-                user_prompt = prompts["user_prompt"]
-                system_prompt = prompts["system_prompt"]
-                inference_service = invigilator.model._inference_service_
-                model = invigilator.model.model
-
-                prompt_cost = self.estimate_prompt_cost(
-                    system_prompt=system_prompt,
-                    user_prompt=user_prompt,
-                    price_lookup=price_lookup,
-                    inference_service=inference_service,
-                    model=model,
-                )
-
-                data.append(
-                    {
-                        "user_prompt": user_prompt,
-                        "system_prompt": system_prompt,
-                        "estimated_input_tokens": prompt_cost["input_tokens"],
-                        "estimated_output_tokens": prompt_cost["output_tokens"],
-                        "estimated_cost_usd": prompt_cost["cost_usd"],
-                        "inference_service": inference_service,
-                        "model": model,
-                    }
-                )
-
-        df = pd.DataFrame.from_records(data)
-
-        df = (
-            df.groupby(["inference_service", "model"])
-            .agg(
-                {
-                    "estimated_cost_usd": "sum",
-                    "estimated_input_tokens": "sum",
-                    "estimated_output_tokens": "sum",
-                }
-            )
-            .reset_index()
+        return JobsPrompts.estimate_prompt_cost(
+            system_prompt, user_prompt, price_lookup, inference_service, model
         )
-        df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
-        df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
-        df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
-
-        estimated_costs_by_model = df.to_dict("records")
-
-        estimated_total_cost = sum(
-            model["estimated_cost_usd"] for model in estimated_costs_by_model
-        )
-        estimated_total_input_tokens = sum(
-            model["estimated_input_tokens"] for model in estimated_costs_by_model
-        )
-        estimated_total_output_tokens = sum(
-            model["estimated_output_tokens"] for model in estimated_costs_by_model
-        )
-
-        output = {
-            "estimated_total_cost_usd": estimated_total_cost,
-            "estimated_total_input_tokens": estimated_total_input_tokens,
-            "estimated_total_output_tokens": estimated_total_output_tokens,
-            "model_costs": estimated_costs_by_model,
-        }
-
-        return output
 
     def estimate_job_cost(self, iterations: int = 1) -> dict:
         """
-
-
-        - 1 token = 4 characters.
-        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
+        Estimate the cost of running the job.
 
-
+        :param iterations: the number of iterations to run
         """
-        from edsl import
+        from edsl.jobs.JobsPrompts import JobsPrompts
 
-
-
+        j = JobsPrompts(self)
+        return j.estimate_job_cost(iterations)
 
-
-
-
+    def estimate_job_cost_from_external_prices(
+        self, price_lookup: dict, iterations: int = 1
+    ) -> dict:
+        from edsl.jobs.JobsPrompts import JobsPrompts
+
+        j = JobsPrompts(self)
+        return j.estimate_job_cost_from_external_prices(price_lookup, iterations)
 
     @staticmethod
-    def compute_job_cost(job_results:
+    def compute_job_cost(job_results: Results) -> float:
         """
         Computes the cost of a completed job in USD.
         """
-
-        for result in job_results:
-            for key in result.raw_model_response:
-                if key.endswith("_cost"):
-                    result_cost = result.raw_model_response[key]
-
-                    question_name = key.removesuffix("_cost")
-                    cache_used = result.cache_used_dict[question_name]
-
-                    if isinstance(result_cost, (int, float)) and not cache_used:
-                        total_cost += result_cost
-
-        return total_cost
+        return job_results.compute_job_cost()
 
     @staticmethod
     def _get_container_class(object):
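The removed estimate_prompt_cost body above documents the estimation heuristic that now lives in JobsPrompts: Jinja braces ({{ ... }}) double a prompt's character count to account for piping, one token is approximated as four characters, output tokens are assumed to be 0.75x input tokens (rounded up), and missing price data falls back to $0.15/1M input and $0.60/1M output tokens. A self-contained sketch of that arithmetic (not the library's code, just the heuristic):

    import math

    def estimate_prompt_cost_sketch(system_prompt: str, user_prompt: str,
                                    input_price_per_token: float = 0.15 / 1_000_000,
                                    output_price_per_token: float = 0.60 / 1_000_000) -> dict:
        # Piping: Jinja braces double the character count.
        def piping_multiplier(prompt: str) -> int:
            return 2 if "{{" in prompt and "}}" in prompt else 1

        chars = (len(user_prompt) * piping_multiplier(user_prompt)
                 + len(system_prompt) * piping_multiplier(system_prompt))
        input_tokens = chars // 4                       # ~4 characters per token
        output_tokens = math.ceil(0.75 * input_tokens)  # assumed output/input ratio
        cost = input_tokens * input_price_per_token + output_tokens * output_price_per_token
        return {"input_tokens": input_tokens, "output_tokens": output_tokens, "cost_usd": cost}

    # A 400-char user prompt with piping plus a 100-char system prompt:
    # (2 * 400 + 100) // 4 = 225 input tokens, ceil(0.75 * 225) = 169 output tokens.
    print(estimate_prompt_cost_sketch("s" * 100, "{{ topic }}" + "u" * 389))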
@@ -504,17 +315,12 @@ class Jobs(Base):
 
     @staticmethod
     def _get_empty_container_object(object):
-        from edsl import AgentList
-        from edsl import
-        from edsl import Scenario
-        from edsl import ScenarioList
+        from edsl.agents.AgentList import AgentList
+        from edsl.scenarios.ScenarioList import ScenarioList
 
-
-
-
-            return ScenarioList([])
-        else:
-            return []
+        return {"Agent": AgentList([]), "Scenario": ScenarioList([])}.get(
+            object.__class__.__name__, []
+        )
 
     @staticmethod
     def _merge_objects(passed_objects, current_objects) -> list:
@@ -641,7 +447,7 @@ class Jobs(Base):
         """
        from edsl.utilities.utilities import dict_hash
 
-        return dict_hash(self.
+        return dict_hash(self.to_dict(add_edsl_version=False))
 
     def _output(self, message) -> None:
         """Check if a Job is verbose. If so, print the message."""
@@ -722,110 +528,6 @@ class Jobs(Base):
            return False
        return self._raise_validation_errors
 
-    def create_remote_inference_job(
-        self,
-        iterations: int = 1,
-        remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
-        verbose=False,
-    ):
-        """ """
-        from edsl.coop.coop import Coop
-
-        coop = Coop()
-        self._output("Remote inference activated. Sending job to server...")
-        remote_job_creation_data = coop.remote_inference_create(
-            self,
-            description=remote_inference_description,
-            status="queued",
-            iterations=iterations,
-            initial_results_visibility=remote_inference_results_visibility,
-        )
-        job_uuid = remote_job_creation_data.get("uuid")
-        if self.verbose:
-            print(f"Job sent to server. (Job uuid={job_uuid}).")
-        return remote_job_creation_data
-
-    @staticmethod
-    def check_status(job_uuid):
-        from edsl.coop.coop import Coop
-
-        coop = Coop()
-        return coop.remote_inference_get(job_uuid)
-
-    def poll_remote_inference_job(
-        self, remote_job_creation_data: dict, verbose=False, poll_interval=5
-    ) -> Union[Results, None]:
-        from edsl.coop.coop import Coop
-        import time
-        from datetime import datetime
-        from edsl.config import CONFIG
-
-        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
-
-        job_uuid = remote_job_creation_data.get("uuid")
-
-        coop = Coop()
-        job_in_queue = True
-        while job_in_queue:
-            remote_job_data = coop.remote_inference_get(job_uuid)
-            status = remote_job_data.get("status")
-            if status == "cancelled":
-                if self.verbose:
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job cancelled by the user.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                return None
-            elif status == "failed":
-                if self.verbose:
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job failed.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                return None
-            elif status == "completed":
-                results_uuid = remote_job_data.get("results_uuid")
-                results = coop.get(results_uuid, expected_object_type="results")
-                if self.verbose:
-                    print("\r" + " " * 80 + "\r", end="")
-                    url = f"{expected_parrot_url}/content/{results_uuid}"
-                    print(f"Job completed and Results stored on Coop: {url}.")
-                return results
-            else:
-                duration = poll_interval
-                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
-                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
-                start_time = time.time()
-                i = 0
-                while time.time() - start_time < duration:
-                    if self.verbose:
-                        print(
-                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
-                            end="",
-                            flush=True,
-                        )
-                    time.sleep(0.1)
-                    i += 1
-
-    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
-        if disable_remote_inference:
-            return False
-        if not disable_remote_inference:
-            try:
-                from edsl import Coop
-
-                user_edsl_settings = Coop().edsl_settings
-                return user_edsl_settings.get("remote_inference", False)
-            except requests.ConnectionError:
-                pass
-            except CoopServerResponseError as e:
-                pass
-
-        return False
-
     def use_remote_cache(self, disable_remote_cache: bool) -> bool:
         if disable_remote_cache:
             return False
@@ -842,96 +544,6 @@ class Jobs(Base):
 
         return False
 
-    def check_api_keys(self) -> None:
-        from edsl import Model
-
-        for model in self.models + [Model()]:
-            if not model.has_valid_api_key():
-                raise MissingAPIKeyError(
-                    model_name=str(model.model),
-                    inference_service=model._inference_service_,
-                )
-
-    def get_missing_api_keys(self) -> set:
-        """
-        Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
-        """
-
-        missing_api_keys = set()
-
-        from edsl import Model
-        from edsl.enums import service_to_api_keyname
-
-        for model in self.models + [Model()]:
-            if not model.has_valid_api_key():
-                key_name = service_to_api_keyname.get(
-                    model._inference_service_, "NOT FOUND"
-                )
-                missing_api_keys.add(key_name)
-
-        return missing_api_keys
-
-    def user_has_all_model_keys(self):
-        """
-        Returns True if the user has all model keys required to run their job.
-
-        Otherwise, returns False.
-        """
-
-        try:
-            self.check_api_keys()
-            return True
-        except MissingAPIKeyError:
-            return False
-        except Exception:
-            raise
-
-    def user_has_ep_api_key(self) -> bool:
-        """
-        Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
-
-        Otherwise, returns False.
-        """
-
-        import os
-
-        coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
-
-        if coop_api_key is not None:
-            return True
-        else:
-            return False
-
-    def needs_external_llms(self) -> bool:
-        """
-        Returns True if the job needs external LLMs to run.
-
-        Otherwise, returns False.
-        """
-        # These cases are necessary to skip the API key check during doctests
-
-        # Accounts for Results.example()
-        all_agents_answer_questions_directly = len(self.agents) > 0 and all(
-            [hasattr(a, "answer_question_directly") for a in self.agents]
-        )
-
-        # Accounts for InterviewExceptionEntry.example()
-        only_model_is_test = set([m.model for m in self.models]) == set(["test"])
-
-        # Accounts for Survey.__call__
-        all_questions_are_functional = set(
-            [q.question_type for q in self.survey.questions]
-        ) == set(["functional"])
-
-        if (
-            all_agents_answer_questions_directly
-            or only_model_is_test
-            or all_questions_are_functional
-        ):
-            return False
-        else:
-            return True
-
     def run(
         self,
         n: int = 1,
@@ -940,7 +552,7 @@ class Jobs(Base):
         cache: Union[Cache, bool] = None,
         check_api_keys: bool = False,
         sidecar_model: Optional[LanguageModel] = None,
-        verbose: bool =
+        verbose: bool = True,
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
@@ -975,62 +587,28 @@ class Jobs(Base):
 
         self.verbose = verbose
 
-
-            not self.user_has_all_model_keys()
-            and not self.user_has_ep_api_key()
-            and self.needs_external_llms()
-        ):
-            import secrets
-            from dotenv import load_dotenv
-            from edsl import CONFIG
-            from edsl.coop.coop import Coop
-            from edsl.utilities.utilities import write_api_key_to_env
-
-            missing_api_keys = self.get_missing_api_keys()
+        from edsl.jobs.JobsChecks import JobsChecks
 
-
+        jc = JobsChecks(self)
 
-
-
-
-            print(
-                "\nYou can either add the missing keys to your .env file, or use remote inference."
-            )
-            print("Remote inference allows you to run jobs on our server.")
-            print("\n🚀 To use remote inference, sign up at the following link:")
-
-            coop = Coop()
-            coop._display_login_url(edsl_auth_token=edsl_auth_token)
-
-            print(
-                "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
-            )
-
-            api_key = coop._poll_for_api_key(edsl_auth_token)
-
-            if api_key is None:
-                print("\nTimed out waiting for login. Please try again.")
-                return
+        # check if the user has all the keys they need
+        if jc.needs_key_process():
+            jc.key_process()
 
-
-            print("✨ API key retrieved and written to .env file.\n")
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
 
-
-
-
-        if remote_inference := self.use_remote_inference(disable_remote_inference):
-            remote_job_creation_data = self.create_remote_inference_job(
+        jh = JobsRemoteInferenceHandler(self, verbose=verbose)
+        if jh.use_remote_inference(disable_remote_inference):
+            jh.create_remote_inference_job(
                 iterations=n,
                 remote_inference_description=remote_inference_description,
                 remote_inference_results_visibility=remote_inference_results_visibility,
             )
-            results =
-            if results is None:
-                self._output("Job failed.")
+            results = jh.poll_remote_inference_job()
             return results
 
         if check_api_keys:
-
+            jc.check_api_keys()
 
         # handle cache
         if cache is None or cache is True:
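Putting the run() hunk together with the earlier removals: the inline key-checking, login, and remote-dispatch logic is replaced by two collaborators. A rough reconstruction of the new control flow, based only on the added lines above (not on the new modules, which are not shown in this diff):

    # Reconstruction of the new run() preamble; not verbatim source.
    from edsl.jobs.JobsChecks import JobsChecks
    from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

    def run_preamble(job, disable_remote_inference=False, n=1):
        jc = JobsChecks(job)
        if jc.needs_key_process():   # user lacks model keys and an Expected Parrot key
            jc.key_process()         # interactive login / key retrieval

        jh = JobsRemoteInferenceHandler(job, verbose=True)
        if jh.use_remote_inference(disable_remote_inference):
            jh.create_remote_inference_job(iterations=n)
            return jh.poll_remote_inference_job()  # Results, or None on failure/cancel
        return None  # caller falls through to local execution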
@@ -1060,46 +638,9 @@ class Jobs(Base):
             raise_validation_errors=raise_validation_errors,
         )
 
-        results.cache = cache.new_entries_cache()
+        # results.cache = cache.new_entries_cache()
         return results
 
-    async def create_and_poll_remote_job(
-        self,
-        iterations: int = 1,
-        remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[
-            Literal["private", "public", "unlisted"]
-        ] = "unlisted",
-    ) -> Union[Results, None]:
-        """
-        Creates and polls a remote inference job asynchronously.
-        Reuses existing synchronous methods but runs them in an async context.
-
-        :param iterations: Number of times to run each interview
-        :param remote_inference_description: Optional description for the remote job
-        :param remote_inference_results_visibility: Visibility setting for results
-        :return: Results object if successful, None if job fails or is cancelled
-        """
-        import asyncio
-        from functools import partial
-
-        # Create job using existing method
-        loop = asyncio.get_event_loop()
-        remote_job_creation_data = await loop.run_in_executor(
-            None,
-            partial(
-                self.create_remote_inference_job,
-                iterations=iterations,
-                remote_inference_description=remote_inference_description,
-                remote_inference_results_visibility=remote_inference_results_visibility,
-            ),
-        )
-
-        # Poll using existing method but with async sleep
-        return await loop.run_in_executor(
-            None, partial(self.poll_remote_inference_job, remote_job_creation_data)
-        )
-
     async def run_async(
         self,
         cache=None,
@@ -1122,14 +663,15 @@ class Jobs(Base):
         :return: Results object
         """
         # Check if we should use remote inference
-
-
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+
+        jh = JobsRemoteInferenceHandler(self, verbose=False)
+        if jh.use_remote_inference(disable_remote_inference):
+            results = await jh.create_and_poll_remote_job(
                 iterations=n,
                 remote_inference_description=remote_inference_description,
                 remote_inference_results_visibility=remote_inference_results_visibility,
             )
-            if results is None:
-                self._output("Job failed.")
             return results
 
         # If not using remote inference, run locally with async
@@ -1149,24 +691,22 @@ class Jobs(Base):
         """
         return set.union(*[question.parameters for question in self.survey.questions])
 
-    #######################
-    # Dunder methods
-    #######################
-    def print(self):
-        from rich import print_json
-        import json
-
-        print_json(json.dumps(self.to_dict()))
-
     def __repr__(self) -> str:
         """Return an eval-able string representation of the Jobs instance."""
         return f"Jobs(survey={repr(self.survey)}, agents={repr(self.agents)}, models={repr(self.models)}, scenarios={repr(self.scenarios)})"
 
-    def
-
-
+    def _summary(self):
+        return {
+            "EDSL Class": "Jobs",
+            "Number of questions": len(self.survey),
+            "Number of agents": len(self.agents),
+            "Number of models": len(self.models),
+            "Number of scenarios": len(self.scenarios),
+        }
 
-
+    def _repr_html_(self) -> str:
+        footer = f"<a href={self.__documentation__}>(docs)</a>"
+        return str(self.summary(format="html")) + footer
 
     def __len__(self) -> int:
         """Return the maximum number of questions that will be asked while running this job.
@@ -1188,18 +728,29 @@ class Jobs(Base):
     # Serialization methods
     #######################
 
-    def
-
-        "survey": self.survey.
-        "agents": [
-
-
+    def to_dict(self, add_edsl_version=True):
+        d = {
+            "survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
+            "agents": [
+                agent.to_dict(add_edsl_version=add_edsl_version)
+                for agent in self.agents
+            ],
+            "models": [
+                model.to_dict(add_edsl_version=add_edsl_version)
+                for model in self.models
+            ],
+            "scenarios": [
+                scenario.to_dict(add_edsl_version=add_edsl_version)
+                for scenario in self.scenarios
+            ],
         }
+        if add_edsl_version:
+            from edsl import __version__
 
-
-
-
-        return
+            d["edsl_version"] = __version__
+            d["edsl_class_name"] = "Jobs"
+
+        return d
 
     @classmethod
     @remove_edsl_version
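The serialization hunks fit together: to_dict(add_edsl_version=False) produces a version-independent payload, the hashing method shown earlier (dict_hash(self.to_dict(add_edsl_version=False))) hashes it, and __eq__ in the next hunk compares those hashes. A sketch of the resulting semantics (assuming Jobs.example() is deterministic):

    from edsl.jobs.Jobs import Jobs

    a, b = Jobs.example(), Jobs.example()
    d = a.to_dict(add_edsl_version=False)
    assert "edsl_version" not in d    # version keys only added when requested
    assert hash(a) == hash(b)         # dict_hash over the version-free payload
    assert a == b                     # __eq__ delegates to those hashes (next hunk)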
@@ -1225,7 +776,7 @@ class Jobs(Base):
         True
 
         """
-        return self
+        return hash(self) == hash(other)
 
     #######################
     # Example methods