edsl 0.1.42__py3-none-any.whl → 0.1.43__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +1 -1
- edsl/agents/PromptConstructor.py +92 -21
- edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
- edsl/agents/prompt_helpers.py +2 -2
- edsl/coop/coop.py +65 -19
- edsl/enums.py +1 -2
- edsl/exceptions/coop.py +4 -0
- edsl/inference_services/AvailableModelFetcher.py +4 -1
- edsl/jobs/Jobs.py +54 -35
- edsl/jobs/JobsPrompts.py +54 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +41 -25
- edsl/jobs/buckets/BucketCollection.py +30 -0
- edsl/jobs/data_structures.py +1 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
- edsl/language_models/key_management/models.py +10 -4
- edsl/prompts/Prompt.py +124 -61
- edsl/questions/descriptors.py +32 -18
- edsl/questions/question_base_gen_mixin.py +1 -0
- edsl/results/DatasetExportMixin.py +35 -6
- edsl/results/Results.py +179 -1
- edsl/results/ResultsGGMixin.py +117 -60
- edsl/scenarios/Scenario.py +33 -0
- edsl/scenarios/ScenarioList.py +22 -3
- edsl/scenarios/ScenarioListPdfMixin.py +9 -3
- {edsl-0.1.42.dist-info → edsl-0.1.43.dist-info}/METADATA +3 -4
- {edsl-0.1.42.dist-info → edsl-0.1.43.dist-info}/RECORD +29 -29
- {edsl-0.1.42.dist-info → edsl-0.1.43.dist-info}/LICENSE +0 -0
- {edsl-0.1.42.dist-info → edsl-0.1.43.dist-info}/WHEEL +0 -0
edsl/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.42"
+__version__ = "0.1.43"
edsl/agents/Invigilator.py
CHANGED
@@ -156,7 +156,7 @@ class InvigilatorAI(InvigilatorBase):
             self.question.question_options = new_question_options

             question_with_validators = self.question.render(
-                self.scenario | prior_answers_dict
+                self.scenario | prior_answers_dict | {'agent':self.agent.traits}
             )
             question_with_validators.use_code = self.question.use_code
         else:
edsl/agents/PromptConstructor.py
CHANGED
@@ -1,6 +1,10 @@
 from __future__ import annotations
-from typing import Dict, Any, Optional, Set, Union, TYPE_CHECKING
+from typing import Dict, Any, Optional, Set, Union, TYPE_CHECKING, Literal
 from functools import cached_property
+from multiprocessing import Pool, freeze_support, get_context
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
+import time
+import logging

 from edsl.prompts.Prompt import Prompt

@@ -22,6 +26,7 @@ if TYPE_CHECKING:
     from edsl.questions.QuestionBase import QuestionBase
     from edsl.scenarios.Scenario import Scenario

+logger = logging.getLogger(__name__)

 class BasePlaceholder:
     """Base class for placeholder values when a question is not yet answered."""
@@ -242,31 +247,97 @@ class PromptConstructor:
             question_name, self.current_answers
         )

-    def get_prompts(self) -> Dict[str,
-        """Get
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
+    def get_prompts(self, parallel: Literal["thread", "process", None] = None) -> Dict[str, Any]:
+        """Get the prompts for the question."""
+        start = time.time()
+
+        # Build all the components
+        instr_start = time.time()
+        agent_instructions = self.agent_instructions_prompt
+        instr_end = time.time()
+        logger.debug(f"Time taken for agent instructions: {instr_end - instr_start:.4f}s")
+
+        persona_start = time.time()
+        agent_persona = self.agent_persona_prompt
+        persona_end = time.time()
+        logger.debug(f"Time taken for agent persona: {persona_end - persona_start:.4f}s")
+
+        q_instr_start = time.time()
+        question_instructions = self.question_instructions_prompt
+        q_instr_end = time.time()
+        logger.debug(f"Time taken for question instructions: {q_instr_end - q_instr_start:.4f}s")
+
+        memory_start = time.time()
+        prior_question_memory = self.prior_question_memory_prompt
+        memory_end = time.time()
+        logger.debug(f"Time taken for prior question memory: {memory_end - memory_start:.4f}s")
+
+        # Get components dict
+        components = {
+            "agent_instructions": agent_instructions.text,
+            "agent_persona": agent_persona.text,
+            "question_instructions": question_instructions.text,
+            "prior_question_memory": prior_question_memory.text,
+        }
+
+        # Use PromptPlan's get_prompts method
+        plan_start = time.time()
+
+        # Get arranged components first
+        arranged = self.prompt_plan.arrange_components(**components)
+
+        if parallel == "process":
+            ctx = get_context('fork')
+            with ctx.Pool() as pool:
+                results = pool.map(_process_prompt, [
+                    (arranged["user_prompt"], {}),
+                    (arranged["system_prompt"], {})
+                ])
+            prompts = {
+                "user_prompt": results[0],
+                "system_prompt": results[1]
+            }
+
+        elif parallel == "thread":
+            with ThreadPoolExecutor() as executor:
+                user_prompt_list = arranged["user_prompt"]
+                system_prompt_list = arranged["system_prompt"]
+
+                # Process both prompt lists in parallel
+                rendered_user = executor.submit(_process_prompt, (user_prompt_list, {}))
+                rendered_system = executor.submit(_process_prompt, (system_prompt_list, {}))
+
+                prompts = {
+                    "user_prompt": rendered_user.result(),
+                    "system_prompt": rendered_system.result()
+                }
+
+        else:  # sequential processing
+            prompts = self.prompt_plan.get_prompts(**components)
+
+        plan_end = time.time()
+        logger.debug(f"Time taken for prompt processing: {plan_end - plan_start:.4f}s")
+
+        # Handle file keys if present
+        if hasattr(self, 'question_file_keys') and self.question_file_keys:
+            files_start = time.time()
             files_list = []
             for key in self.question_file_keys:
                 files_list.append(self.scenario[key])
             prompts["files_list"] = files_list
+            files_end = time.time()
+            logger.debug(f"Time taken for file key processing: {files_end - files_start:.4f}s")
+
+        end = time.time()
+        logger.debug(f"Total time in get_prompts: {end - start:.4f}s")
         return prompts


-
-
+def _process_prompt(args):
+    """Helper function to process a single prompt list with its replacements."""
+    prompt_list, replacements = args
+    return prompt_list.reduce()
+

-
+if __name__ == '__main__':
+    freeze_support()
edsl/agents/QuestionInstructionPromptBuilder.py
CHANGED
@@ -1,5 +1,6 @@
 from typing import Dict, List, Set
 from warnings import warn
+import logging
 from edsl.prompts.Prompt import Prompt

 from edsl.agents.QuestionTemplateReplacementsBuilder import (
@@ -23,12 +24,44 @@ class QuestionInstructionPromptBuilder:
         Returns:
             Prompt: The fully rendered question instructions
         """
+        import time
+
+        start = time.time()
+
+        # Create base prompt
+        base_start = time.time()
         base_prompt = self._create_base_prompt()
+        base_end = time.time()
+        logging.debug(f"Time for base prompt: {base_end - base_start}")
+
+        # Enrich with options
+        enrich_start = time.time()
         enriched_prompt = self._enrich_with_question_options(base_prompt)
+        enrich_end = time.time()
+        logging.debug(f"Time for enriching with options: {enrich_end - enrich_start}")
+
+        # Render prompt
+        render_start = time.time()
         rendered_prompt = self._render_prompt(enriched_prompt)
+        render_end = time.time()
+        logging.debug(f"Time for rendering prompt: {render_end - render_start}")
+
+        # Validate template variables
+        validate_start = time.time()
         self._validate_template_variables(rendered_prompt)
-
-
+        validate_end = time.time()
+        logging.debug(f"Time for template validation: {validate_end - validate_start}")
+
+        # Append survey instructions
+        append_start = time.time()
+        final_prompt = self._append_survey_instructions(rendered_prompt)
+        append_end = time.time()
+        logging.debug(f"Time for appending survey instructions: {append_end - append_start}")
+
+        end = time.time()
+        logging.debug(f"Total time in build_question_instructions: {end - start}")
+
+        return final_prompt

     def _create_base_prompt(self) -> Dict:
         """Creates the initial prompt with basic question data.
@@ -50,14 +83,25 @@ class QuestionInstructionPromptBuilder:
         Returns:
             Dict: Enriched prompt data
         """
+        import time
+
+        start = time.time()
+
         if "question_options" in prompt_data["data"]:
             from edsl.agents.question_option_processor import QuestionOptionProcessor
-
+
+            processor_start = time.time()
             question_options = QuestionOptionProcessor(
                 self.prompt_constructor
             ).get_question_options(question_data=prompt_data["data"])
-
+            processor_end = time.time()
+            logging.debug(f"Time to process question options: {processor_end - processor_start}")
+
             prompt_data["data"]["question_options"] = question_options
+
+        end = time.time()
+        logging.debug(f"Total time in _enrich_with_question_options: {end - start}")
+
         return prompt_data

     def _render_prompt(self, prompt_data: Dict) -> Prompt:
@@ -69,11 +113,28 @@ class QuestionInstructionPromptBuilder:
         Returns:
             Prompt: Rendered instructions
         """
-
+        import time
+
+        start = time.time()
+
+        # Build replacement dict
+        dict_start = time.time()
         replacement_dict = QTRB(self.prompt_constructor).build_replacement_dict(
             prompt_data["data"]
         )
-
+        dict_end = time.time()
+        logging.debug(f"Time to build replacement dict: {dict_end - dict_start}")
+
+        # Render with dict
+        render_start = time.time()
+        result = prompt_data["prompt"].render(replacement_dict)
+        render_end = time.time()
+        logging.debug(f"Time to render with dict: {render_end - render_start}")
+
+        end = time.time()
+        logging.debug(f"Total time in _render_prompt: {end - start}")
+
+        return result

     def _validate_template_variables(self, rendered_prompt: Prompt) -> None:
         """Validates that all template variables have been properly replaced.
@@ -101,9 +162,7 @@ class QuestionInstructionPromptBuilder:
         """
         for question_name in self.survey.question_names:
             if question_name in undefined_vars:
-
-                    f"Question name found in undefined_template_variables: {question_name}"
-                )
+                logging.warning(f"Question name found in undefined_template_variables: {question_name}")

     def _append_survey_instructions(self, rendered_prompt: Prompt) -> Prompt:
         """Appends any relevant survey instructions to the rendered prompt.
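The timing added in this file logs at DEBUG level through the standard logging module, so it is silent by default. A minimal sketch of surfacing it from calling code (plain standard-library configuration, not an edsl API):

    import logging

    # Route DEBUG records, including the new timing messages, to the console.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
    )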
edsl/agents/prompt_helpers.py
CHANGED
@@ -124,6 +124,6 @@ class PromptPlan:
         """Get both prompts for the LLM call."""
         prompts = self.arrange_components(**kwargs)
         return {
-            "user_prompt": prompts["user_prompt"]
-            "system_prompt": prompts["system_prompt"]
+            "user_prompt": Prompt("".join(str(p) for p in prompts["user_prompt"])),
+            "system_prompt": Prompt("".join(str(p) for p in prompts["system_prompt"])),
         }
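With this change, each entry returned by PromptPlan.get_prompts is a single Prompt built by joining the arranged component pieces. A toy illustration of the joining expression, assuming the pieces stringify to their text as the str(p) call implies (the pieces list below is made up):

    from edsl.prompts.Prompt import Prompt

    pieces = [Prompt("You are an agent."), Prompt(" Answer concisely.")]
    joined = Prompt("".join(str(p) for p in pieces))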
edsl/coop/coop.py
CHANGED
@@ -14,7 +14,11 @@ from edsl.data.CacheEntry import CacheEntry
 from edsl.jobs.Jobs import Jobs
 from edsl.surveys.Survey import Survey

-from edsl.exceptions.coop import
+from edsl.exceptions.coop import (
+    CoopInvalidURLError,
+    CoopNoUUIDError,
+    CoopServerResponseError,
+)
 from edsl.coop.utils import (
     EDSLObject,
     ObjectRegistry,
@@ -285,17 +289,46 @@ class Coop(CoopFunctionsMixin):
         if value is None:
             return "null"

-    def
+    def _resolve_uuid_or_alias(
         self, uuid: Union[str, UUID] = None, url: str = None
-    ) ->
+    ) -> tuple[Optional[str], Optional[str], Optional[str]]:
         """
-        Resolve the uuid from a uuid or a url.
+        Resolve the uuid or alias information from a uuid or a url.
+        Returns a tuple of (uuid, owner_username, alias)
+        - For content/<uuid> URLs: returns (uuid, None, None)
+        - For content/<username>/<alias> URLs: returns (None, username, alias)
         """
         if not url and not uuid:
             raise CoopNoUUIDError("No uuid or url provided for the object.")
+
         if not uuid and url:
-
-
-
+            parts = (
+                url.replace("http://", "")
+                .replace("https://", "")
+                .rstrip("/")
+                .split("/")
+            )
+
+            # Remove domain
+            parts = parts[1:]
+
+            if len(parts) < 2 or parts[0] != "content":
+                raise CoopInvalidURLError(
+                    f"Invalid URL format. The URL must end with /content/<uuid> or /content/<username>/<alias>: {url}"
+                )
+
+            if len(parts) == 2:
+                obj_uuid = parts[1]
+                return obj_uuid, None, None
+            elif len(parts) == 3:
+                username, alias = parts[1], parts[2]
+                return None, username, alias
+            else:
+                raise CoopInvalidURLError(
+                    f"Invalid URL format. The URL must end with /content/<uuid> or /content/<username>/<alias>: {url}"
+                )
+
+        return str(uuid), None, None

     @property
     def edsl_settings(self) -> dict:
@@ -361,22 +394,31 @@ class Coop(CoopFunctionsMixin):
         expected_object_type: Optional[ObjectType] = None,
     ) -> EDSLObject:
         """
-        Retrieve an EDSL object by its uuid or
+        Retrieve an EDSL object by its uuid/url or by owner username and alias.
         - If the object's visibility is private, the user must be the owner.
         - Optionally, check if the retrieved object is of a certain type.

         :param uuid: the uuid of the object either in str or UUID format.
-        :param url: the url of the object.
+        :param url: the url of the object (can be content/uuid or content/username/alias format).
         :param expected_object_type: the expected type of the object.

         :return: the object instance.
         """
-
-
-
-
-
-
+        obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(uuid, url)
+
+        if obj_uuid:
+            response = self._send_server_request(
+                uri=f"api/v0/object",
+                method="GET",
+                params={"uuid": obj_uuid},
+            )
+        else:
+            response = self._send_server_request(
+                uri=f"api/v0/object/alias",
+                method="GET",
+                params={"owner_username": owner_username, "alias": alias},
+            )
+
         self._resolve_server_response(response)
         json_string = response.json().get("json_string")
         object_type = response.json().get("object_type")
@@ -414,12 +456,13 @@ class Coop(CoopFunctionsMixin):
         """
         Delete an object from the server.
         """
-
+        obj_uuid, _, _ = self._resolve_uuid_or_alias(uuid, url)
         response = self._send_server_request(
             uri=f"api/v0/object",
             method="DELETE",
-            params={"uuid":
+            params={"uuid": obj_uuid},
         )
+
         self._resolve_server_response(response)
         return response.json()

@@ -438,11 +481,11 @@ class Coop(CoopFunctionsMixin):
         """
         if description is None and visibility is None and value is None:
             raise Exception("Nothing to patch.")
-
+        obj_uuid, _, _ = self._resolve_uuid_or_alias(uuid, url)
         response = self._send_server_request(
             uri=f"api/v0/object",
             method="PATCH",
-            params={"uuid":
+            params={"uuid": obj_uuid},
             payload={
                 "description": description,
                 "alias": alias,
@@ -549,6 +592,7 @@ class Coop(CoopFunctionsMixin):
     def remote_cache_get(
         self,
         exclude_keys: Optional[list[str]] = None,
+        select_keys: Optional[list[str]] = None,
     ) -> list[CacheEntry]:
         """
         Get all remote cache entries.
@@ -560,10 +604,12 @@ class Coop(CoopFunctionsMixin):
         """
         if exclude_keys is None:
             exclude_keys = []
+        if select_keys is None:
+            select_keys = []
         response = self._send_server_request(
             uri="api/v0/remote-cache/get-many",
             method="POST",
-            payload={"keys": exclude_keys},
+            payload={"keys": exclude_keys, "selected_keys": select_keys},
             timeout=40,
         )
         self._resolve_server_response(response)
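A standalone sketch of the URL handling that _resolve_uuid_or_alias now performs, mirroring the parsing in the diff above (the example URLs, username, and alias are made up):

    def resolve(url: str):
        # Mirrors the parsing added in Coop._resolve_uuid_or_alias.
        parts = url.replace("http://", "").replace("https://", "").rstrip("/").split("/")
        parts = parts[1:]  # drop the domain
        if len(parts) < 2 or parts[0] != "content":
            raise ValueError(f"Invalid URL format: {url}")
        if len(parts) == 2:
            return parts[1], None, None        # (uuid, owner_username, alias)
        if len(parts) == 3:
            return None, parts[1], parts[2]    # alias-style URL
        raise ValueError(f"Invalid URL format: {url}")

    print(resolve("https://example.com/content/1234-abcd"))            # ('1234-abcd', None, None)
    print(resolve("https://example.com/content/some_user/my-survey"))  # (None, 'some_user', 'my-survey')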
edsl/enums.py
CHANGED
@@ -97,7 +97,6 @@ available_models_urls = {
 
 
 service_to_api_keyname = {
-    InferenceServiceType.BEDROCK.value: "TBD",
     InferenceServiceType.DEEP_INFRA.value: "DEEP_INFRA_API_KEY",
     InferenceServiceType.REPLICATE.value: "TBD",
     InferenceServiceType.OPENAI.value: "OPENAI_API_KEY",
@@ -109,7 +108,7 @@ service_to_api_keyname = {
     InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
     InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
     InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
-    InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY"
+    InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY"
 }
 
 
edsl/inference_services/AvailableModelFetcher.py
CHANGED
@@ -136,7 +136,10 @@ class AvailableModelFetcher:
         if not service_models:
             import warnings

-            warnings.
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")  # Ignores the warning
+                warnings.warn(f"No models found for service {service_name}")
+
             return [], service_name

         models_list = AvailableModels(
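As written, the warnings.warn call sits inside a catch_warnings block whose filter is set to ignore, so the message is suppressed rather than shown. A minimal standard-library illustration of that behavior:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("ignore")
        warnings.warn("No models found for service example")

    print(len(caught))  # 0 -- the "ignore" filter drops the warning before it is recorded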
edsl/jobs/Jobs.py
CHANGED
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
     from edsl.language_models.ModelList import ModelList
     from edsl.data.Cache import Cache
     from edsl.language_models.key_management.KeyLookup import KeyLookup
+    from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

 VisibilityType = Literal["private", "public", "unlisted"]

@@ -407,7 +408,13 @@ class Jobs(Base):
         >>> bc
         BucketCollection(...)
         """
-
+        bc = BucketCollection.from_models(self.models)
+
+        if self.run_config.environment.key_lookup is not None:
+            bc.update_from_key_lookup(
+                self.run_config.environment.key_lookup
+            )
+        return bc

     def html(self):
         """Return the HTML representations for each scenario"""
@@ -465,22 +472,47 @@ class Jobs(Base):
 
         return False

+    def _start_remote_inference_job(
+        self, job_handler: Optional[JobsRemoteInferenceHandler] = None
+    ) -> Union["Results", None]:
+
+        if job_handler is None:
+            job_handler = self._create_remote_inference_handler()
+
+        job_info = job_handler.create_remote_inference_job(
+            iterations=self.run_config.parameters.n,
+            remote_inference_description=self.run_config.parameters.remote_inference_description,
+            remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
+        )
+        return job_info
+
+    def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:
+
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+
+        return JobsRemoteInferenceHandler(
+            self, verbose=self.run_config.parameters.verbose
+        )
+
     def _remote_results(
         self,
+        config: RunConfig,
     ) -> Union["Results", None]:
         from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+        from edsl.jobs.JobsRemoteInferenceHandler import RemoteJobInfo

-
-
-        )
+        background = config.parameters.background
+
+        jh = self._create_remote_inference_handler()
         if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
-            job_info =
-
-
-
-
-
-
+            job_info: RemoteJobInfo = self._start_remote_inference_job(jh)
+            if background:
+                from edsl.results.Results import Results
+                results = Results.from_job_info(job_info)
+                return results
+            else:
+                results = jh.poll_remote_inference_job(job_info)
+                return results
         else:
             return None

@@ -507,13 +539,6 @@ class Jobs(Base):
 
         assert isinstance(self.run_config.environment.cache, Cache)

-        # with RemoteCacheSync(
-        #     coop=Coop(),
-        #     cache=self.run_config.environment.cache,
-        #     output_func=self._output,
-        #     remote_cache=use_remote_cache,
-        #     remote_cache_description=self.run_config.parameters.remote_cache_description,
-        # ):
         runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
         if run_job_async:
             results = await runner.run_async(self.run_config.parameters)
@@ -521,19 +546,6 @@ class Jobs(Base):
             results = runner.run(self.run_config.parameters)
         return results

-    # def _setup_and_check(self) -> Tuple[RunConfig, Optional[Results]]:
-    #     self._prepare_to_run()
-    #     self._check_if_remote_keys_ok()
-
-    #     # first try to run the job remotely
-    #     results = self._remote_results()
-    #     #breakpoint()
-    #     if results is not None:
-    #         return results
-
-    #     self._check_if_local_keys_ok()
-    #     return None
-
     @property
     def num_interviews(self):
         if self.run_config.parameters.n is None:
@@ -563,7 +575,6 @@ class Jobs(Base):
 
         self.replace_missing_objects()

-        # try to run remotely first
         self._prepare_to_run()
         self._check_if_remote_keys_ok()

@@ -581,9 +592,9 @@ class Jobs(Base):
             self.run_config.environment.cache = Cache(immediate_write=False)

         # first try to run the job remotely
-        if results := self._remote_results():
+        if (results := self._remote_results(config)) is not None:
             return results
-
+
         self._check_if_local_keys_ok()

         if config.environment.bucket_collection is None:
@@ -591,6 +602,14 @@ class Jobs(Base):
                 self.create_bucket_collection()
             )

+        if (
+            self.run_config.environment.key_lookup is not None
+            and self.run_config.environment.bucket_collection is not None
+        ):
+            self.run_config.environment.bucket_collection.update_from_key_lookup(
+                self.run_config.environment.key_lookup
+            )
+
         return None

     @with_config
@@ -613,7 +632,7 @@ class Jobs(Base):
         :param key_lookup: A KeyLookup object to manage API keys
         """
         potentially_completed_results = self._run(config)
-
+
         if potentially_completed_results is not None:
             return potentially_completed_results

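A hedged sketch of the new background path from the caller's side, assuming background is exposed as a run() parameter that lands in config.parameters.background (the survey itself is illustrative):

    from edsl import QuestionFreeText, Survey

    survey = Survey([
        QuestionFreeText(question_name="color", question_text="What is your favorite color?")
    ])

    # Assumption: background=True is forwarded into config.parameters.background, so
    # _remote_results returns Results.from_job_info(...) immediately instead of polling.
    results = survey.run(background=True)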