edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +197 -116
- edsl/__init__.py +15 -7
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +351 -147
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +101 -50
- edsl/agents/InvigilatorBase.py +62 -70
- edsl/agents/PromptConstructor.py +143 -225
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +125 -47
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +45 -27
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +31 -12
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +27 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -19
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +7 -12
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +120 -79
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +47 -35
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -10
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +134 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +356 -431
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +6 -4
- edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +143 -408
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
- edsl/jobs/runners/JobsRunnerStatus.py +133 -165
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +38 -18
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +194 -236
- edsl/language_models/ModelList.py +28 -19
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +1 -2
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +68 -214
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +2 -3
- edsl/questions/QuestionList.py +10 -18
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +67 -23
- edsl/questions/QuestionNumerical.py +2 -4
- edsl/questions/QuestionRank.py +7 -17
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +2 -1
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
- edsl/questions/data_structures.py +20 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
- edsl/questions/question_registry.py +1 -1
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +168 -305
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +298 -206
- edsl/results/Results.py +149 -131
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/{Selector.py → results_selector.py} +23 -13
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +150 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioList.py +415 -244
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +270 -791
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
- edsl-0.1.39.dist-info/RECORD +358 -0
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/registry.py +0 -190
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.38.dev4.dist-info/RECORD +0 -277
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -1,36 +1,106 @@
|
|
1
1
|
# """The Jobs class is a collection of agents, scenarios and models and one survey."""
|
2
2
|
from __future__ import annotations
|
3
|
-
import
|
4
|
-
import
|
5
|
-
from
|
6
|
-
|
3
|
+
import asyncio
|
4
|
+
from inspect import signature
|
5
|
+
from typing import (
|
6
|
+
Literal,
|
7
|
+
Optional,
|
8
|
+
Union,
|
9
|
+
Sequence,
|
10
|
+
Generator,
|
11
|
+
TYPE_CHECKING,
|
12
|
+
Callable,
|
13
|
+
Tuple,
|
14
|
+
)
|
7
15
|
|
8
16
|
from edsl.Base import Base
|
9
17
|
|
10
|
-
from edsl.exceptions import MissingAPIKeyError
|
11
18
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
19
|
+
from edsl.jobs.JobsPrompts import JobsPrompts
|
12
20
|
from edsl.jobs.interviews.Interview import Interview
|
21
|
+
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
13
22
|
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
14
|
-
from edsl.utilities.decorators import remove_edsl_version
|
15
|
-
|
16
23
|
from edsl.data.RemoteCacheSync import RemoteCacheSync
|
17
24
|
from edsl.exceptions.coop import CoopServerResponseError
|
18
25
|
|
26
|
+
from edsl.jobs.JobsChecks import JobsChecks
|
27
|
+
from edsl.jobs.data_structures import RunEnvironment, RunParameters, RunConfig
|
28
|
+
|
19
29
|
if TYPE_CHECKING:
|
20
30
|
from edsl.agents.Agent import Agent
|
21
31
|
from edsl.agents.AgentList import AgentList
|
22
32
|
from edsl.language_models.LanguageModel import LanguageModel
|
23
33
|
from edsl.scenarios.Scenario import Scenario
|
34
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
24
35
|
from edsl.surveys.Survey import Survey
|
25
36
|
from edsl.results.Results import Results
|
26
37
|
from edsl.results.Dataset import Dataset
|
38
|
+
from edsl.language_models.ModelList import ModelList
|
39
|
+
from edsl.data.Cache import Cache
|
40
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
41
|
+
|
42
|
+
VisibilityType = Literal["private", "public", "unlisted"]
|
43
|
+
|
44
|
+
from dataclasses import dataclass
|
45
|
+
from typing import Optional, Union, TypeVar, Callable, cast
|
46
|
+
from functools import wraps
|
47
|
+
|
48
|
+
try:
|
49
|
+
from typing import ParamSpec
|
50
|
+
except ImportError:
|
51
|
+
from typing_extensions import ParamSpec
|
52
|
+
|
53
|
+
|
54
|
+
P = ParamSpec("P")
|
55
|
+
T = TypeVar("T")
|
56
|
+
|
57
|
+
|
58
|
+
from edsl.jobs.check_survey_scenario_compatibility import (
|
59
|
+
CheckSurveyScenarioCompatibility,
|
60
|
+
)
|
61
|
+
|
62
|
+
|
63
|
+
def with_config(f: Callable[P, T]) -> Callable[P, T]:
|
64
|
+
"This decorator make it so that the run function parameters match the RunConfig dataclass."
|
65
|
+
parameter_fields = {
|
66
|
+
name: field.default
|
67
|
+
for name, field in RunParameters.__dataclass_fields__.items()
|
68
|
+
}
|
69
|
+
environment_fields = {
|
70
|
+
name: field.default
|
71
|
+
for name, field in RunEnvironment.__dataclass_fields__.items()
|
72
|
+
}
|
73
|
+
combined = {**parameter_fields, **environment_fields}
|
74
|
+
|
75
|
+
@wraps(f)
|
76
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
77
|
+
environment = RunEnvironment(
|
78
|
+
**{k: v for k, v in kwargs.items() if k in environment_fields}
|
79
|
+
)
|
80
|
+
parameters = RunParameters(
|
81
|
+
**{k: v for k, v in kwargs.items() if k in parameter_fields}
|
82
|
+
)
|
83
|
+
config = RunConfig(environment=environment, parameters=parameters)
|
84
|
+
return f(*args, config=config)
|
85
|
+
|
86
|
+
# Update the wrapper's signature to include all RunConfig parameters
|
87
|
+
# old_sig = signature(f)
|
88
|
+
# wrapper.__signature__ = old_sig.replace(
|
89
|
+
# parameters=list(old_sig.parameters.values())[:-1]
|
90
|
+
# + [
|
91
|
+
# old_sig.parameters["config"].replace(
|
92
|
+
# default=parameter_fields[name], name=name
|
93
|
+
# )
|
94
|
+
# for name in combined
|
95
|
+
# ]
|
96
|
+
# )
|
97
|
+
|
98
|
+
return cast(Callable[P, T], wrapper)
|
27
99
|
|
28
100
|
|
29
101
|
class Jobs(Base):
|
30
102
|
"""
|
31
|
-
A collection of agents, scenarios and models and one survey
|
32
|
-
The actual running of a job is done by a `JobsRunner`, which is a subclass of `JobsRunner`.
|
33
|
-
The `JobsRunner` is chosen by the user, and is stored in the `jobs_runner_name` attribute.
|
103
|
+
A collection of agents, scenarios and models and one survey that creates 'interviews'
|
34
104
|
"""
|
35
105
|
|
36
106
|
__documentation__ = "https://docs.expectedparrot.com/en/latest/jobs.html"
|
@@ -38,9 +108,9 @@ class Jobs(Base):
|
|
38
108
|
def __init__(
|
39
109
|
self,
|
40
110
|
survey: "Survey",
|
41
|
-
agents: Optional[list[
|
42
|
-
models: Optional[list[
|
43
|
-
scenarios: Optional[list[
|
111
|
+
agents: Optional[Union[list[Agent], AgentList]] = None,
|
112
|
+
models: Optional[Union[ModelList, list[LanguageModel]]] = None,
|
113
|
+
scenarios: Optional[Union[ScenarioList, list[Scenario]]] = None,
|
44
114
|
):
|
45
115
|
"""Initialize a Jobs instance.
|
46
116
|
|
@@ -49,14 +119,62 @@ class Jobs(Base):
|
|
49
119
|
:param models: a list of models
|
50
120
|
:param scenarios: a list of scenarios
|
51
121
|
"""
|
122
|
+
self.run_config = RunConfig(
|
123
|
+
environment=RunEnvironment(), parameters=RunParameters()
|
124
|
+
)
|
125
|
+
|
52
126
|
self.survey = survey
|
53
|
-
self.agents:
|
54
|
-
self.scenarios:
|
55
|
-
self.models = models
|
127
|
+
self.agents: AgentList = agents
|
128
|
+
self.scenarios: ScenarioList = scenarios
|
129
|
+
self.models: ModelList = models
|
130
|
+
|
131
|
+
def add_running_env(self, running_env: RunEnvironment):
|
132
|
+
self.run_config.add_environment(running_env)
|
133
|
+
return self
|
134
|
+
|
135
|
+
def using_cache(self, cache: "Cache") -> Jobs:
|
136
|
+
"""
|
137
|
+
Add a Cache to the job.
|
138
|
+
|
139
|
+
:param cache: the cache to add
|
140
|
+
"""
|
141
|
+
self.run_config.add_cache(cache)
|
142
|
+
return self
|
143
|
+
|
144
|
+
def using_bucket_collection(self, bucket_collection: BucketCollection) -> Jobs:
|
145
|
+
"""
|
146
|
+
Add a BucketCollection to the job.
|
147
|
+
|
148
|
+
:param bucket_collection: the bucket collection to add
|
149
|
+
"""
|
150
|
+
self.run_config.add_bucket_collection(bucket_collection)
|
151
|
+
return self
|
56
152
|
|
57
|
-
|
153
|
+
def using_key_lookup(self, key_lookup: KeyLookup) -> Jobs:
|
154
|
+
"""
|
155
|
+
Add a KeyLookup to the job.
|
156
|
+
|
157
|
+
:param key_lookup: the key lookup to add
|
158
|
+
"""
|
159
|
+
self.run_config.add_key_lookup(key_lookup)
|
160
|
+
return self
|
58
161
|
|
59
|
-
|
162
|
+
def using(self, obj: Union[Cache, BucketCollection, KeyLookup]) -> Jobs:
|
163
|
+
"""
|
164
|
+
Add a Cache, BucketCollection, or KeyLookup to the job.
|
165
|
+
|
166
|
+
:param obj: the object to add
|
167
|
+
"""
|
168
|
+
from edsl.data.Cache import Cache
|
169
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
170
|
+
|
171
|
+
if isinstance(obj, Cache):
|
172
|
+
self.using_cache(obj)
|
173
|
+
elif isinstance(obj, BucketCollection):
|
174
|
+
self.using_bucket_collection(obj)
|
175
|
+
elif isinstance(obj, KeyLookup):
|
176
|
+
self.using_key_lookup(obj)
|
177
|
+
return self
|
60
178
|
|
61
179
|
@property
|
62
180
|
def models(self):
|
@@ -64,7 +182,7 @@ class Jobs(Base):
|
|
64
182
|
|
65
183
|
@models.setter
|
66
184
|
def models(self, value):
|
67
|
-
from edsl import ModelList
|
185
|
+
from edsl.language_models.ModelList import ModelList
|
68
186
|
|
69
187
|
if value:
|
70
188
|
if not isinstance(value, ModelList):
|
@@ -74,13 +192,19 @@ class Jobs(Base):
|
|
74
192
|
else:
|
75
193
|
self._models = ModelList([])
|
76
194
|
|
195
|
+
# update the bucket collection if it exists
|
196
|
+
if self.run_config.environment.bucket_collection is None:
|
197
|
+
self.run_config.environment.bucket_collection = (
|
198
|
+
self.create_bucket_collection()
|
199
|
+
)
|
200
|
+
|
77
201
|
@property
|
78
202
|
def agents(self):
|
79
203
|
return self._agents
|
80
204
|
|
81
205
|
@agents.setter
|
82
206
|
def agents(self, value):
|
83
|
-
from edsl import AgentList
|
207
|
+
from edsl.agents.AgentList import AgentList
|
84
208
|
|
85
209
|
if value:
|
86
210
|
if not isinstance(value, AgentList):
|
@@ -96,7 +220,7 @@ class Jobs(Base):
|
|
96
220
|
|
97
221
|
@scenarios.setter
|
98
222
|
def scenarios(self, value):
|
99
|
-
from edsl import ScenarioList
|
223
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
100
224
|
from edsl.results.Dataset import Dataset
|
101
225
|
|
102
226
|
if value:
|
@@ -115,28 +239,32 @@ class Jobs(Base):
|
|
115
239
|
def by(
|
116
240
|
self,
|
117
241
|
*args: Union[
|
118
|
-
|
119
|
-
|
120
|
-
|
242
|
+
Agent,
|
243
|
+
Scenario,
|
244
|
+
LanguageModel,
|
121
245
|
Sequence[Union["Agent", "Scenario", "LanguageModel"]],
|
122
246
|
],
|
123
247
|
) -> Jobs:
|
124
248
|
"""
|
125
|
-
Add Agents, Scenarios and LanguageModels to a job.
|
249
|
+
Add Agents, Scenarios and LanguageModels to a job.
|
250
|
+
|
251
|
+
:param args: objects or a sequence (list, tuple, ...) of objects of the same type
|
252
|
+
|
253
|
+
If no objects of this type exist in the Jobs instance, it stores the new objects as a list in the corresponding attribute.
|
254
|
+
Otherwise, it combines the new objects with existing objects using the object's `__add__` method.
|
126
255
|
|
127
256
|
This 'by' is intended to create a fluent interface.
|
128
257
|
|
129
|
-
>>> from edsl import Survey
|
130
|
-
>>> from edsl import QuestionFreeText
|
258
|
+
>>> from edsl.surveys.Survey import Survey
|
259
|
+
>>> from edsl.questions.QuestionFreeText import QuestionFreeText
|
131
260
|
>>> q = QuestionFreeText(question_name="name", question_text="What is your name?")
|
132
261
|
>>> j = Jobs(survey = Survey(questions=[q]))
|
133
262
|
>>> j
|
134
263
|
Jobs(survey=Survey(...), agents=AgentList([]), models=ModelList([]), scenarios=ScenarioList([]))
|
135
|
-
>>> from edsl import Agent; a = Agent(traits = {"status": "Sad"})
|
264
|
+
>>> from edsl.agents.Agent import Agent; a = Agent(traits = {"status": "Sad"})
|
136
265
|
>>> j.by(a).agents
|
137
266
|
AgentList([Agent(traits = {'status': 'Sad'})])
|
138
267
|
|
139
|
-
:param args: objects or a sequence (list, tuple, ...) of objects of the same type
|
140
268
|
|
141
269
|
Notes:
|
142
270
|
- all objects must implement the 'get_value', 'set_value', and `__add__` methods
|
@@ -144,28 +272,9 @@ class Jobs(Base):
|
|
144
272
|
- scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
|
145
273
|
- models: new models overwrite old models.
|
146
274
|
"""
|
147
|
-
from edsl.
|
148
|
-
|
149
|
-
if isinstance(
|
150
|
-
args[0], Dataset
|
151
|
-
): # let the user user a Dataset as if it were a ScenarioList
|
152
|
-
args = args[0].to_scenario_list()
|
153
|
-
|
154
|
-
passed_objects = self._turn_args_to_list(
|
155
|
-
args
|
156
|
-
) # objects can also be passed comma-separated
|
275
|
+
from edsl.jobs.JobsComponentConstructor import JobsComponentConstructor
|
157
276
|
|
158
|
-
|
159
|
-
passed_objects[0]
|
160
|
-
)
|
161
|
-
|
162
|
-
if not current_objects:
|
163
|
-
new_objects = passed_objects
|
164
|
-
else:
|
165
|
-
new_objects = self._merge_objects(passed_objects, current_objects)
|
166
|
-
|
167
|
-
setattr(self, objects_key, new_objects) # update the job
|
168
|
-
return self
|
277
|
+
return JobsComponentConstructor(self).by(*args)
|
169
278
|
|
170
279
|
def prompts(self) -> "Dataset":
|
171
280
|
"""Return a Dataset of prompts that will be used.
|
@@ -175,12 +284,9 @@ class Jobs(Base):
|
|
175
284
|
>>> Jobs.example().prompts()
|
176
285
|
Dataset(...)
|
177
286
|
"""
|
178
|
-
|
179
|
-
|
180
|
-
j = JobsPrompts(self)
|
181
|
-
return j.prompts()
|
287
|
+
return JobsPrompts(self).prompts()
|
182
288
|
|
183
|
-
def show_prompts(self, all=False) -> None:
|
289
|
+
def show_prompts(self, all: bool = False) -> None:
|
184
290
|
"""Print the prompts."""
|
185
291
|
if all:
|
186
292
|
return self.prompts().to_scenario_list().table()
|
@@ -200,9 +306,12 @@ class Jobs(Base):
|
|
200
306
|
"""
|
201
307
|
Estimate the cost of running the prompts.
|
202
308
|
:param iterations: the number of iterations to run
|
309
|
+
:param system_prompt: the system prompt
|
310
|
+
:param user_prompt: the user prompt
|
311
|
+
:param price_lookup: the price lookup
|
312
|
+
:param inference_service: the inference service
|
313
|
+
:param model: the model name
|
203
314
|
"""
|
204
|
-
from edsl.jobs.JobsPrompts import JobsPrompts
|
205
|
-
|
206
315
|
return JobsPrompts.estimate_prompt_cost(
|
207
316
|
system_prompt, user_prompt, price_lookup, inference_service, model
|
208
317
|
)
|
@@ -213,18 +322,14 @@ class Jobs(Base):
|
|
213
322
|
|
214
323
|
:param iterations: the number of iterations to run
|
215
324
|
"""
|
216
|
-
|
217
|
-
|
218
|
-
j = JobsPrompts(self)
|
219
|
-
return j.estimate_job_cost(iterations)
|
325
|
+
return JobsPrompts(self).estimate_job_cost(iterations)
|
220
326
|
|
221
327
|
def estimate_job_cost_from_external_prices(
|
222
328
|
self, price_lookup: dict, iterations: int = 1
|
223
329
|
) -> dict:
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
return j.estimate_job_cost_from_external_prices(price_lookup, iterations)
|
330
|
+
return JobsPrompts(self).estimate_job_cost_from_external_prices(
|
331
|
+
price_lookup, iterations
|
332
|
+
)
|
228
333
|
|
229
334
|
@staticmethod
|
230
335
|
def compute_job_cost(job_results: Results) -> float:
|
@@ -233,111 +338,30 @@ class Jobs(Base):
|
|
233
338
|
"""
|
234
339
|
return job_results.compute_job_cost()
|
235
340
|
|
236
|
-
|
237
|
-
def _get_container_class(object):
|
238
|
-
from edsl.agents.AgentList import AgentList
|
239
|
-
from edsl.agents.Agent import Agent
|
240
|
-
from edsl.scenarios.Scenario import Scenario
|
241
|
-
from edsl.scenarios.ScenarioList import ScenarioList
|
242
|
-
from edsl.language_models.ModelList import ModelList
|
243
|
-
|
244
|
-
if isinstance(object, Agent):
|
245
|
-
return AgentList
|
246
|
-
elif isinstance(object, Scenario):
|
247
|
-
return ScenarioList
|
248
|
-
elif isinstance(object, ModelList):
|
249
|
-
return ModelList
|
250
|
-
else:
|
251
|
-
return list
|
252
|
-
|
253
|
-
@staticmethod
|
254
|
-
def _turn_args_to_list(args):
|
255
|
-
"""Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments.
|
256
|
-
|
257
|
-
Example:
|
258
|
-
|
259
|
-
>>> Jobs._turn_args_to_list([1,2,3])
|
260
|
-
[1, 2, 3]
|
261
|
-
|
262
|
-
"""
|
263
|
-
|
264
|
-
def did_user_pass_a_sequence(args):
|
265
|
-
"""Return True if the user passed a sequence, False otherwise.
|
266
|
-
|
267
|
-
Example:
|
268
|
-
|
269
|
-
>>> did_user_pass_a_sequence([1,2,3])
|
270
|
-
True
|
271
|
-
|
272
|
-
>>> did_user_pass_a_sequence(1)
|
273
|
-
False
|
274
|
-
"""
|
275
|
-
return len(args) == 1 and isinstance(args[0], Sequence)
|
276
|
-
|
277
|
-
if did_user_pass_a_sequence(args):
|
278
|
-
container_class = Jobs._get_container_class(args[0][0])
|
279
|
-
return container_class(args[0])
|
280
|
-
else:
|
281
|
-
container_class = Jobs._get_container_class(args[0])
|
282
|
-
return container_class(args)
|
283
|
-
|
284
|
-
def _get_current_objects_of_this_type(
|
285
|
-
self, object: Union["Agent", "Scenario", "LanguageModel"]
|
286
|
-
) -> tuple[list, str]:
|
341
|
+
def replace_missing_objects(self) -> None:
|
287
342
|
from edsl.agents.Agent import Agent
|
343
|
+
from edsl.language_models.model import Model
|
288
344
|
from edsl.scenarios.Scenario import Scenario
|
289
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
290
345
|
|
291
|
-
|
346
|
+
self.agents = self.agents or [Agent()]
|
347
|
+
self.models = self.models or [Model()]
|
348
|
+
self.scenarios = self.scenarios or [Scenario()]
|
292
349
|
|
293
|
-
|
294
|
-
>>> j = Jobs.example()
|
295
|
-
>>> j._get_current_objects_of_this_type(j.agents[0])
|
296
|
-
(AgentList([Agent(traits = {'status': 'Joyful'}), Agent(traits = {'status': 'Sad'})]), 'agents')
|
350
|
+
def generate_interviews(self) -> Generator[Interview, None, None]:
|
297
351
|
"""
|
298
|
-
|
299
|
-
Agent: "agents",
|
300
|
-
Scenario: "scenarios",
|
301
|
-
LanguageModel: "models",
|
302
|
-
}
|
303
|
-
for class_type in class_to_key:
|
304
|
-
if isinstance(object, class_type) or issubclass(
|
305
|
-
object.__class__, class_type
|
306
|
-
):
|
307
|
-
key = class_to_key[class_type]
|
308
|
-
break
|
309
|
-
else:
|
310
|
-
raise ValueError(
|
311
|
-
f"First argument must be an Agent, Scenario, or LanguageModel, not {object}"
|
312
|
-
)
|
313
|
-
current_objects = getattr(self, key, None)
|
314
|
-
return current_objects, key
|
315
|
-
|
316
|
-
@staticmethod
|
317
|
-
def _get_empty_container_object(object):
|
318
|
-
from edsl.agents.AgentList import AgentList
|
319
|
-
from edsl.scenarios.ScenarioList import ScenarioList
|
352
|
+
Generate interviews.
|
320
353
|
|
321
|
-
|
322
|
-
|
323
|
-
|
354
|
+
Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
|
355
|
+
This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
|
356
|
+
with us filling in defaults.
|
324
357
|
|
325
|
-
@staticmethod
|
326
|
-
def _merge_objects(passed_objects, current_objects) -> list:
|
327
358
|
"""
|
328
|
-
|
359
|
+
from edsl.jobs.InterviewsConstructor import InterviewsConstructor
|
329
360
|
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
[5, 6, 7, 6, 7, 8, 7, 8, 9]
|
335
|
-
"""
|
336
|
-
new_objects = Jobs._get_empty_container_object(passed_objects[0])
|
337
|
-
for current_object in current_objects:
|
338
|
-
for new_object in passed_objects:
|
339
|
-
new_objects.append(current_object + new_object)
|
340
|
-
return new_objects
|
361
|
+
self.replace_missing_objects()
|
362
|
+
yield from InterviewsConstructor(
|
363
|
+
self, cache=self.run_config.environment.cache
|
364
|
+
).create_interviews()
|
341
365
|
|
342
366
|
def interviews(self) -> list[Interview]:
|
343
367
|
"""
|
@@ -353,13 +377,10 @@ class Jobs(Base):
|
|
353
377
|
>>> j.interviews()[0]
|
354
378
|
Interview(agent = Agent(traits = {'status': 'Joyful'}), survey = Survey(...), scenario = Scenario({'period': 'morning'}), model = Model(...))
|
355
379
|
"""
|
356
|
-
|
357
|
-
return self._interviews
|
358
|
-
else:
|
359
|
-
return list(self._create_interviews())
|
380
|
+
return list(self.generate_interviews())
|
360
381
|
|
361
382
|
@classmethod
|
362
|
-
def from_interviews(cls, interview_list):
|
383
|
+
def from_interviews(cls, interview_list) -> "Jobs":
|
363
384
|
"""Return a Jobs instance from a list of interviews.
|
364
385
|
|
365
386
|
This is useful when you have, say, a list of failed interviews and you want to create
|
@@ -373,34 +394,6 @@ class Jobs(Base):
|
|
373
394
|
jobs._interviews = interview_list
|
374
395
|
return jobs
|
375
396
|
|
376
|
-
def _create_interviews(self) -> Generator[Interview, None, None]:
|
377
|
-
"""
|
378
|
-
Generate interviews.
|
379
|
-
|
380
|
-
Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
|
381
|
-
This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
|
382
|
-
with us filling in defaults.
|
383
|
-
|
384
|
-
|
385
|
-
"""
|
386
|
-
# if no agents, models, or scenarios are set, set them to defaults
|
387
|
-
from edsl.agents.Agent import Agent
|
388
|
-
from edsl.language_models.registry import Model
|
389
|
-
from edsl.scenarios.Scenario import Scenario
|
390
|
-
|
391
|
-
self.agents = self.agents or [Agent()]
|
392
|
-
self.models = self.models or [Model()]
|
393
|
-
self.scenarios = self.scenarios or [Scenario()]
|
394
|
-
for agent, scenario, model in product(self.agents, self.scenarios, self.models):
|
395
|
-
yield Interview(
|
396
|
-
survey=self.survey,
|
397
|
-
agent=agent,
|
398
|
-
scenario=scenario,
|
399
|
-
model=model,
|
400
|
-
skip_retry=self.skip_retry,
|
401
|
-
raise_validation_errors=self.raise_validation_errors,
|
402
|
-
)
|
403
|
-
|
404
397
|
def create_bucket_collection(self) -> BucketCollection:
|
405
398
|
"""
|
406
399
|
Create a collection of buckets for each model.
|
@@ -414,17 +407,7 @@ class Jobs(Base):
|
|
414
407
|
>>> bc
|
415
408
|
BucketCollection(...)
|
416
409
|
"""
|
417
|
-
|
418
|
-
for model in self.models:
|
419
|
-
bucket_collection.add_model(model)
|
420
|
-
return bucket_collection
|
421
|
-
|
422
|
-
@property
|
423
|
-
def bucket_collection(self) -> BucketCollection:
|
424
|
-
"""Return the bucket collection. If it does not exist, create it."""
|
425
|
-
if self.__bucket_collection is None:
|
426
|
-
self.__bucket_collection = self.create_bucket_collection()
|
427
|
-
return self.__bucket_collection
|
410
|
+
return BucketCollection.from_models(self.models)
|
428
411
|
|
429
412
|
def html(self):
|
430
413
|
"""Return the HTML representations for each scenario"""
|
@@ -451,89 +434,27 @@ class Jobs(Base):
|
|
451
434
|
|
452
435
|
def _output(self, message) -> None:
|
453
436
|
"""Check if a Job is verbose. If so, print the message."""
|
454
|
-
if
|
437
|
+
if self.run_config.parameters.verbose:
|
455
438
|
print(message)
|
439
|
+
# if hasattr(self, "verbose") and self.verbose:
|
440
|
+
# print(message)
|
456
441
|
|
457
|
-
def
|
458
|
-
"""
|
459
|
-
|
460
|
-
>>>
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
>>> j = Jobs(survey = Survey(questions=[q]))
|
465
|
-
>>> with warnings.catch_warnings(record=True) as w:
|
466
|
-
... j._check_parameters(warn = True)
|
467
|
-
... assert len(w) == 1
|
468
|
-
... assert issubclass(w[-1].category, UserWarning)
|
469
|
-
... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
|
470
|
-
|
471
|
-
>>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
|
472
|
-
>>> s = Scenario({'plop': "A", 'poo': "B"})
|
473
|
-
>>> j = Jobs(survey = Survey(questions=[q])).by(s)
|
474
|
-
>>> j._check_parameters(strict = True)
|
475
|
-
Traceback (most recent call last):
|
476
|
-
...
|
477
|
-
ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
|
478
|
-
|
479
|
-
>>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
|
480
|
-
>>> s = Scenario({'ugly_question': "B"})
|
481
|
-
>>> j = Jobs(survey = Survey(questions=[q])).by(s)
|
482
|
-
>>> j._check_parameters()
|
483
|
-
Traceback (most recent call last):
|
484
|
-
...
|
485
|
-
ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
|
486
|
-
"""
|
487
|
-
survey_parameters: set = self.survey.parameters
|
488
|
-
scenario_parameters: set = self.scenarios.parameters
|
489
|
-
|
490
|
-
msg0, msg1, msg2 = None, None, None
|
491
|
-
|
492
|
-
# look for key issues
|
493
|
-
if intersection := set(self.scenarios.parameters) & set(
|
494
|
-
self.survey.question_names
|
495
|
-
):
|
496
|
-
msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
|
497
|
-
|
498
|
-
raise ValueError(msg0)
|
499
|
-
|
500
|
-
if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
|
501
|
-
msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
|
502
|
-
if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
|
503
|
-
msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"
|
504
|
-
|
505
|
-
if msg1 or msg2:
|
506
|
-
message = "\n".join(filter(None, [msg1, msg2]))
|
507
|
-
if strict:
|
508
|
-
raise ValueError(message)
|
509
|
-
else:
|
510
|
-
if warn:
|
511
|
-
warnings.warn(message)
|
512
|
-
|
513
|
-
if self.scenarios.has_jinja_braces:
|
514
|
-
warnings.warn(
|
515
|
-
"The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
|
516
|
-
)
|
517
|
-
self.scenarios = self.scenarios.convert_jinja_braces()
|
518
|
-
|
519
|
-
@property
|
520
|
-
def skip_retry(self):
|
521
|
-
if not hasattr(self, "_skip_retry"):
|
522
|
-
return False
|
523
|
-
return self._skip_retry
|
442
|
+
def all_question_parameters(self) -> set:
|
443
|
+
"""Return all the fields in the questions in the survey.
|
444
|
+
>>> from edsl.jobs import Jobs
|
445
|
+
>>> Jobs.example().all_question_parameters()
|
446
|
+
{'period'}
|
447
|
+
"""
|
448
|
+
return set.union(*[question.parameters for question in self.survey.questions])
|
524
449
|
|
525
|
-
|
526
|
-
|
527
|
-
if not hasattr(self, "_raise_validation_errors"):
|
528
|
-
return False
|
529
|
-
return self._raise_validation_errors
|
450
|
+
def use_remote_cache(self) -> bool:
|
451
|
+
import requests
|
530
452
|
|
531
|
-
|
532
|
-
if disable_remote_cache:
|
453
|
+
if self.run_config.parameters.disable_remote_cache:
|
533
454
|
return False
|
534
|
-
if not disable_remote_cache:
|
455
|
+
if not self.run_config.parameters.disable_remote_cache:
|
535
456
|
try:
|
536
|
-
from edsl import Coop
|
457
|
+
from edsl.coop.coop import Coop
|
537
458
|
|
538
459
|
user_edsl_settings = Coop().edsl_settings
|
539
460
|
return user_edsl_settings.get("remote_caching", False)
|
@@ -544,152 +465,173 @@ class Jobs(Base):
|
|
544
465
|
|
545
466
|
return False
|
546
467
|
|
547
|
-
def
|
468
|
+
def _remote_results(
|
548
469
|
self,
|
549
|
-
|
550
|
-
|
551
|
-
stop_on_exception: bool = False,
|
552
|
-
cache: Union[Cache, bool] = None,
|
553
|
-
check_api_keys: bool = False,
|
554
|
-
sidecar_model: Optional[LanguageModel] = None,
|
555
|
-
verbose: bool = True,
|
556
|
-
print_exceptions=True,
|
557
|
-
remote_cache_description: Optional[str] = None,
|
558
|
-
remote_inference_description: Optional[str] = None,
|
559
|
-
remote_inference_results_visibility: Optional[
|
560
|
-
Literal["private", "public", "unlisted"]
|
561
|
-
] = "unlisted",
|
562
|
-
skip_retry: bool = False,
|
563
|
-
raise_validation_errors: bool = False,
|
564
|
-
disable_remote_cache: bool = False,
|
565
|
-
disable_remote_inference: bool = False,
|
566
|
-
) -> Results:
|
567
|
-
"""
|
568
|
-
Runs the Job: conducts Interviews and returns their results.
|
569
|
-
|
570
|
-
:param n: How many times to run each interview
|
571
|
-
:param progress_bar: Whether to show a progress bar
|
572
|
-
:param stop_on_exception: Stops the job if an exception is raised
|
573
|
-
:param cache: A Cache object to store results
|
574
|
-
:param check_api_keys: Raises an error if API keys are invalid
|
575
|
-
:param verbose: Prints extra messages
|
576
|
-
:param remote_cache_description: Specifies a description for this group of entries in the remote cache
|
577
|
-
:param remote_inference_description: Specifies a description for the remote inference job
|
578
|
-
:param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
|
579
|
-
:param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
|
580
|
-
:param disable_remote_inference: If True, the job will not use remote inference
|
581
|
-
"""
|
582
|
-
from edsl.coop.coop import Coop
|
583
|
-
|
584
|
-
self._check_parameters()
|
585
|
-
self._skip_retry = skip_retry
|
586
|
-
self._raise_validation_errors = raise_validation_errors
|
470
|
+
) -> Union["Results", None]:
|
471
|
+
from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
|
587
472
|
|
588
|
-
|
473
|
+
jh = JobsRemoteInferenceHandler(
|
474
|
+
self, verbose=self.run_config.parameters.verbose
|
475
|
+
)
|
476
|
+
if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
|
477
|
+
job_info = jh.create_remote_inference_job(
|
478
|
+
iterations=self.run_config.parameters.n,
|
479
|
+
remote_inference_description=self.run_config.parameters.remote_inference_description,
|
480
|
+
remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
|
481
|
+
)
|
482
|
+
results = jh.poll_remote_inference_job(job_info)
|
483
|
+
return results
|
484
|
+
else:
|
485
|
+
return None
|
589
486
|
|
590
|
-
|
487
|
+
def _prepare_to_run(self) -> None:
|
488
|
+
"This makes sure that the job is ready to run and that keys are in place for a remote job."
|
489
|
+
CheckSurveyScenarioCompatibility(self.survey, self.scenarios).check()
|
591
490
|
|
491
|
+
def _check_if_remote_keys_ok(self):
|
592
492
|
jc = JobsChecks(self)
|
593
|
-
|
594
|
-
# check if the user has all the keys they need
|
595
493
|
if jc.needs_key_process():
|
596
494
|
jc.key_process()
|
597
495
|
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
if jh.use_remote_inference(disable_remote_inference):
|
602
|
-
jh.create_remote_inference_job(
|
603
|
-
iterations=n,
|
604
|
-
remote_inference_description=remote_inference_description,
|
605
|
-
remote_inference_results_visibility=remote_inference_results_visibility,
|
606
|
-
)
|
607
|
-
results = jh.poll_remote_inference_job()
|
608
|
-
return results
|
609
|
-
|
610
|
-
if check_api_keys:
|
496
|
+
def _check_if_local_keys_ok(self):
|
497
|
+
jc = JobsChecks(self)
|
498
|
+
if self.run_config.parameters.check_api_keys:
|
611
499
|
jc.check_api_keys()
|
612
500
|
|
613
|
-
|
614
|
-
if cache is None or cache is True:
|
615
|
-
from edsl.data.CacheHandler import CacheHandler
|
501
|
+
async def _execute_with_remote_cache(self, run_job_async: bool) -> Results:
|
616
502
|
|
617
|
-
|
618
|
-
|
619
|
-
|
503
|
+
use_remote_cache = self.use_remote_cache()
|
504
|
+
|
505
|
+
from edsl.coop.coop import Coop
|
506
|
+
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
507
|
+
from edsl.data.Cache import Cache
|
620
508
|
|
621
|
-
|
509
|
+
assert isinstance(self.run_config.environment.cache, Cache)
|
622
510
|
|
623
|
-
remote_cache = self.use_remote_cache(disable_remote_cache)
|
624
511
|
with RemoteCacheSync(
|
625
512
|
coop=Coop(),
|
626
|
-
cache=cache,
|
513
|
+
cache=self.run_config.environment.cache,
|
627
514
|
output_func=self._output,
|
628
|
-
remote_cache=
|
629
|
-
remote_cache_description=remote_cache_description,
|
630
|
-
)
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
sidecar_model=sidecar_model,
|
637
|
-
print_exceptions=print_exceptions,
|
638
|
-
raise_validation_errors=raise_validation_errors,
|
639
|
-
)
|
640
|
-
|
641
|
-
# results.cache = cache.new_entries_cache()
|
515
|
+
remote_cache=use_remote_cache,
|
516
|
+
remote_cache_description=self.run_config.parameters.remote_cache_description,
|
517
|
+
):
|
518
|
+
runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
|
519
|
+
if run_job_async:
|
520
|
+
results = await runner.run_async(self.run_config.parameters)
|
521
|
+
else:
|
522
|
+
results = runner.run(self.run_config.parameters)
|
642
523
|
return results
|
643
524
|
|
644
|
-
|
645
|
-
self,
|
646
|
-
cache=None,
|
647
|
-
n=1,
|
648
|
-
disable_remote_inference: bool = False,
|
649
|
-
remote_inference_description: Optional[str] = None,
|
650
|
-
remote_inference_results_visibility: Optional[
|
651
|
-
Literal["private", "public", "unlisted"]
|
652
|
-
] = "unlisted",
|
653
|
-
**kwargs,
|
654
|
-
):
|
655
|
-
"""Run the job asynchronously, either locally or remotely.
|
525
|
+
def _setup_and_check(self) -> Tuple[RunConfig, Optional[Results]]:
|
656
526
|
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
:return: Results object
|
664
|
-
"""
|
665
|
-
# Check if we should use remote inference
|
666
|
-
from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
|
527
|
+
self._prepare_to_run()
|
528
|
+
self._check_if_remote_keys_ok()
|
529
|
+
|
530
|
+
# first try to run the job remotely
|
531
|
+
if results := self._remote_results():
|
532
|
+
return results
|
667
533
|
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
534
|
+
self._check_if_local_keys_ok()
|
535
|
+
return None
|
536
|
+
|
537
|
+
@property
|
538
|
+
def num_interviews(self):
|
539
|
+
if self.run_config.parameters.n is None:
|
540
|
+
return len(self)
|
541
|
+
else:
|
542
|
+
len(self) * self.run_config.parameters.n
|
543
|
+
|
544
|
+
def _run(self, config: RunConfig):
|
545
|
+
"Shared code for run and run_async"
|
546
|
+
if config.environment.cache is not None:
|
547
|
+
self.run_config.environment.cache = config.environment.cache
|
548
|
+
|
549
|
+
if config.environment.bucket_collection is not None:
|
550
|
+
self.run_config.environment.bucket_collection = (
|
551
|
+
config.environment.bucket_collection
|
674
552
|
)
|
553
|
+
|
554
|
+
if config.environment.key_lookup is not None:
|
555
|
+
self.run_config.environment.key_lookup = config.environment.key_lookup
|
556
|
+
|
557
|
+
# replace the parameters with the ones from the config
|
558
|
+
self.run_config.parameters = config.parameters
|
559
|
+
|
560
|
+
self.replace_missing_objects()
|
561
|
+
|
562
|
+
# try to run remotely first
|
563
|
+
self._prepare_to_run()
|
564
|
+
self._check_if_remote_keys_ok()
|
565
|
+
|
566
|
+
if (
|
567
|
+
self.run_config.environment.cache is None
|
568
|
+
or self.run_config.environment.cache is True
|
569
|
+
):
|
570
|
+
from edsl.data.CacheHandler import CacheHandler
|
571
|
+
|
572
|
+
self.run_config.environment.cache = CacheHandler().get_cache()
|
573
|
+
|
574
|
+
if self.run_config.environment.cache is False:
|
575
|
+
from edsl.data.Cache import Cache
|
576
|
+
|
577
|
+
self.run_config.environment.cache = Cache(immediate_write=False)
|
578
|
+
|
579
|
+
# first try to run the job remotely
|
580
|
+
if results := self._remote_results():
|
675
581
|
return results
|
676
582
|
|
677
|
-
|
678
|
-
return await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
|
583
|
+
self._check_if_local_keys_ok()
|
679
584
|
|
680
|
-
|
681
|
-
|
585
|
+
if config.environment.bucket_collection is None:
|
586
|
+
self.run_config.environment.bucket_collection = (
|
587
|
+
self.create_bucket_collection()
|
588
|
+
)
|
682
589
|
|
683
|
-
|
684
|
-
|
590
|
+
@with_config
|
591
|
+
def run(self, *, config: RunConfig) -> "Results":
|
592
|
+
"""
|
593
|
+
Runs the Job: conducts Interviews and returns their results.
|
685
594
|
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
595
|
+
:param n: How many times to run each interview
|
596
|
+
:param progress_bar: Whether to show a progress bar
|
597
|
+
:param stop_on_exception: Stops the job if an exception is raised
|
598
|
+
:param check_api_keys: Raises an error if API keys are invalid
|
599
|
+
:param verbose: Prints extra messages
|
600
|
+
:param remote_cache_description: Specifies a description for this group of entries in the remote cache
|
601
|
+
:param remote_inference_description: Specifies a description for the remote inference job
|
602
|
+
:param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
|
603
|
+
:param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
|
604
|
+
:param disable_remote_inference: If True, the job will not use remote inference
|
605
|
+
:param cache: A Cache object to store results
|
606
|
+
:param bucket_collection: A BucketCollection object to track API calls
|
607
|
+
:param key_lookup: A KeyLookup object to manage API keys
|
691
608
|
"""
|
692
|
-
|
609
|
+
self._run(config)
|
610
|
+
|
611
|
+
return asyncio.run(self._execute_with_remote_cache(run_job_async=False))
|
612
|
+
|
613
|
+
@with_config
|
614
|
+
async def run_async(self, *, config: RunConfig) -> "Results":
|
615
|
+
"""
|
616
|
+
Runs the Job: conducts Interviews and returns their results.
|
617
|
+
|
618
|
+
:param n: How many times to run each interview
|
619
|
+
:param progress_bar: Whether to show a progress bar
|
620
|
+
:param stop_on_exception: Stops the job if an exception is raised
|
621
|
+
:param check_api_keys: Raises an error if API keys are invalid
|
622
|
+
:param verbose: Prints extra messages
|
623
|
+
:param remote_cache_description: Specifies a description for this group of entries in the remote cache
|
624
|
+
:param remote_inference_description: Specifies a description for the remote inference job
|
625
|
+
:param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
|
626
|
+
:param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
|
627
|
+
:param disable_remote_inference: If True, the job will not use remote inference
|
628
|
+
:param cache: A Cache object to store results
|
629
|
+
:param bucket_collection: A BucketCollection object to track API calls
|
630
|
+
:param key_lookup: A KeyLookup object to manage API keys
|
631
|
+
"""
|
632
|
+
self._run(config)
|
633
|
+
|
634
|
+
return await self._execute_with_remote_cache(run_job_async=True)
|
693
635
|
|
694
636
|
def __repr__(self) -> str:
|
695
637
|
"""Return an eval-able string representation of the Jobs instance."""
|
@@ -697,17 +639,12 @@ class Jobs(Base):
|
|
697
639
|
|
698
640
|
def _summary(self):
|
699
641
|
return {
|
700
|
-
"
|
701
|
-
"
|
702
|
-
"
|
703
|
-
"
|
704
|
-
"Number of scenarios": len(self.scenarios),
|
642
|
+
"questions": len(self.survey),
|
643
|
+
"agents": len(self.agents or [1]),
|
644
|
+
"models": len(self.models or [1]),
|
645
|
+
"scenarios": len(self.scenarios or [1]),
|
705
646
|
}
|
706
647
|
|
707
|
-
def _repr_html_(self) -> str:
|
708
|
-
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
709
|
-
return str(self.summary(format="html")) + footer
|
710
|
-
|
711
648
|
def __len__(self) -> int:
|
712
649
|
"""Return the maximum number of questions that will be asked while running this job.
|
713
650
|
Note that this is the maximum number of questions, not the actual number of questions that will be asked, as some questions may be skipped.
|
@@ -724,10 +661,6 @@ class Jobs(Base):
|
|
724
661
|
)
|
725
662
|
return number_of_questions
|
726
663
|
|
727
|
-
#######################
|
728
|
-
# Serialization methods
|
729
|
-
#######################
|
730
|
-
|
731
664
|
def to_dict(self, add_edsl_version=True):
|
732
665
|
d = {
|
733
666
|
"survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
|
@@ -752,11 +685,14 @@ class Jobs(Base):
|
|
752
685
|
|
753
686
|
return d
|
754
687
|
|
688
|
+
def table(self):
|
689
|
+
return self.prompts().to_scenario_list().table()
|
690
|
+
|
755
691
|
@classmethod
|
756
692
|
@remove_edsl_version
|
757
693
|
def from_dict(cls, data: dict) -> Jobs:
|
758
694
|
"""Creates a Jobs instance from a dictionary."""
|
759
|
-
from edsl import Survey
|
695
|
+
from edsl.surveys.Survey import Survey
|
760
696
|
from edsl.agents.Agent import Agent
|
761
697
|
from edsl.language_models.LanguageModel import LanguageModel
|
762
698
|
from edsl.scenarios.Scenario import Scenario
|
@@ -778,9 +714,6 @@ class Jobs(Base):
|
|
778
714
|
"""
|
779
715
|
return hash(self) == hash(other)
|
780
716
|
|
781
|
-
#######################
|
782
|
-
# Example methods
|
783
|
-
#######################
|
784
717
|
@classmethod
|
785
718
|
def example(
|
786
719
|
cls,
|
@@ -800,14 +733,14 @@ class Jobs(Base):
|
|
800
733
|
"""
|
801
734
|
import random
|
802
735
|
from uuid import uuid4
|
803
|
-
from edsl.questions import QuestionMultipleChoice
|
736
|
+
from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
|
804
737
|
from edsl.agents.Agent import Agent
|
805
738
|
from edsl.scenarios.Scenario import Scenario
|
806
739
|
|
807
740
|
addition = "" if not randomize else str(uuid4())
|
808
741
|
|
809
742
|
if test_model:
|
810
|
-
from edsl.language_models import LanguageModel
|
743
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
811
744
|
|
812
745
|
m = LanguageModel.example(test_model=True)
|
813
746
|
|
@@ -848,7 +781,8 @@ class Jobs(Base):
|
|
848
781
|
question_options=["Good", "Great", "OK", "Terrible"],
|
849
782
|
question_name="how_feeling_yesterday",
|
850
783
|
)
|
851
|
-
from edsl import Survey
|
784
|
+
from edsl.surveys.Survey import Survey
|
785
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
852
786
|
|
853
787
|
base_survey = Survey(questions=[q1, q2])
|
854
788
|
|
@@ -865,15 +799,6 @@ class Jobs(Base):
|
|
865
799
|
|
866
800
|
return job
|
867
801
|
|
868
|
-
def rich_print(self):
|
869
|
-
"""Print a rich representation of the Jobs instance."""
|
870
|
-
from rich.table import Table
|
871
|
-
|
872
|
-
table = Table(title="Jobs")
|
873
|
-
table.add_column("Jobs")
|
874
|
-
table.add_row(self.survey.rich_print())
|
875
|
-
return table
|
876
|
-
|
877
802
|
def code(self):
|
878
803
|
"""Return the code to create this instance."""
|
879
804
|
raise NotImplementedError
|
@@ -881,7 +806,7 @@ class Jobs(Base):
|
|
881
806
|
|
882
807
|
def main():
|
883
808
|
"""Run the module's doctests."""
|
884
|
-
from edsl.jobs import Jobs
|
809
|
+
from edsl.jobs.Jobs import Jobs
|
885
810
|
from edsl.data.Cache import Cache
|
886
811
|
|
887
812
|
job = Jobs.example()
|