edsl 0.1.37.dev3__py3-none-any.whl → 0.1.37.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +48 -48
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +855 -804
- edsl/agents/AgentList.py +350 -345
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +284 -305
- edsl/agents/PromptConstructor.py +353 -312
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +99 -86
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +289 -289
- edsl/config.py +149 -149
- edsl/conjure/AgentConstructionMixin.py +160 -152
- edsl/conjure/Conjure.py +62 -62
- edsl/conjure/InputData.py +659 -659
- edsl/conjure/InputDataCSV.py +48 -48
- edsl/conjure/InputDataMixinQuestionStats.py +182 -182
- edsl/conjure/InputDataPyRead.py +91 -91
- edsl/conjure/InputDataSPSS.py +8 -8
- edsl/conjure/InputDataStata.py +8 -8
- edsl/conjure/QuestionOptionMixin.py +76 -76
- edsl/conjure/QuestionTypeMixin.py +23 -23
- edsl/conjure/RawQuestion.py +65 -65
- edsl/conjure/SurveyResponses.py +7 -7
- edsl/conjure/__init__.py +9 -9
- edsl/conjure/naming_utilities.py +263 -263
- edsl/conjure/utilities.py +201 -201
- edsl/conversation/Conversation.py +290 -238
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +958 -824
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +527 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +97 -97
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +54 -50
- edsl/exceptions/agents.py +38 -40
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +29 -26
- edsl/exceptions/scenarios.py +22 -0
- edsl/exceptions/surveys.py +37 -34
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +120 -115
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +97 -74
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1347 -1121
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +248 -248
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -182
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +338 -338
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +442 -441
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +706 -718
- edsl/language_models/ModelList.py +102 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +3 -2
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +259 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +357 -353
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +656 -616
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +234 -266
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +114 -114
- edsl/questions/QuestionFunctional.py +159 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +91 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +413 -418
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +717 -693
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +450 -435
- edsl/results/Results.py +1071 -1160
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +135 -118
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -458
- edsl/scenarios/Scenario.py +546 -510
- edsl/scenarios/ScenarioHtmlMixin.py +64 -59
- edsl/scenarios/ScenarioList.py +1112 -1101
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +330 -324
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1795 -1772
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +47 -47
- edsl/surveys/instructions/Instruction.py +51 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +409 -391
- {edsl-0.1.37.dev3.dist-info → edsl-0.1.37.dev5.dist-info}/LICENSE +21 -21
- {edsl-0.1.37.dev3.dist-info → edsl-0.1.37.dev5.dist-info}/METADATA +1 -1
- edsl-0.1.37.dev5.dist-info/RECORD +283 -0
- edsl-0.1.37.dev3.dist-info/RECORD +0 -279
- {edsl-0.1.37.dev3.dist-info → edsl-0.1.37.dev5.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -1,1121 +1,1347 @@
|
|
1
|
-
# """The Jobs class is a collection of agents, scenarios and models and one survey."""
|
2
|
-
from __future__ import annotations
|
3
|
-
import warnings
|
4
|
-
import requests
|
5
|
-
from itertools import product
|
6
|
-
from typing import Optional, Union, Sequence, Generator
|
7
|
-
|
8
|
-
from edsl.Base import Base
|
9
|
-
|
10
|
-
from edsl.
|
11
|
-
from edsl.jobs.
|
12
|
-
from edsl.jobs.
|
13
|
-
from edsl.
|
14
|
-
|
15
|
-
|
16
|
-
from edsl.
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
The
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
:param
|
37
|
-
:param
|
38
|
-
:param
|
39
|
-
|
40
|
-
|
41
|
-
self.
|
42
|
-
self.
|
43
|
-
self.
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
"
|
102
|
-
"
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
>>> from edsl import
|
113
|
-
>>>
|
114
|
-
>>>
|
115
|
-
>>> j
|
116
|
-
|
117
|
-
|
118
|
-
>>>
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
-
|
126
|
-
-
|
127
|
-
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
>>> Jobs
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
{"
|
202
|
-
{"
|
203
|
-
{"
|
204
|
-
{"
|
205
|
-
{"
|
206
|
-
{"
|
207
|
-
{"
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
"
|
256
|
-
)
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
)
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
df
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
)
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
>>>
|
659
|
-
>>>
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
...
|
664
|
-
|
665
|
-
"""
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
)
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
self.
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
self
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
)
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
return
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
def
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
|
969
|
-
|
970
|
-
|
971
|
-
|
972
|
-
|
973
|
-
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
|
978
|
-
|
979
|
-
|
980
|
-
|
981
|
-
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
|
988
|
-
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
1006
|
-
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
|
1032
|
-
if
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1045
|
-
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1119
|
-
|
1120
|
-
|
1121
|
-
|
1
|
+
# """The Jobs class is a collection of agents, scenarios and models and one survey."""
|
2
|
+
from __future__ import annotations
|
3
|
+
import warnings
|
4
|
+
import requests
|
5
|
+
from itertools import product
|
6
|
+
from typing import Literal, Optional, Union, Sequence, Generator
|
7
|
+
|
8
|
+
from edsl.Base import Base
|
9
|
+
|
10
|
+
from edsl.exceptions import MissingAPIKeyError
|
11
|
+
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
12
|
+
from edsl.jobs.interviews.Interview import Interview
|
13
|
+
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
14
|
+
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
15
|
+
|
16
|
+
from edsl.data.RemoteCacheSync import RemoteCacheSync
|
17
|
+
from edsl.exceptions.coop import CoopServerResponseError
|
18
|
+
|
19
|
+
|
20
|
+
class Jobs(Base):
|
21
|
+
"""
|
22
|
+
A collection of agents, scenarios and models and one survey.
|
23
|
+
The actual running of a job is done by a `JobsRunner`, which is a subclass of `JobsRunner`.
|
24
|
+
The `JobsRunner` is chosen by the user, and is stored in the `jobs_runner_name` attribute.
|
25
|
+
"""
|
26
|
+
|
27
|
+
def __init__(
|
28
|
+
self,
|
29
|
+
survey: "Survey",
|
30
|
+
agents: Optional[list["Agent"]] = None,
|
31
|
+
models: Optional[list["LanguageModel"]] = None,
|
32
|
+
scenarios: Optional[list["Scenario"]] = None,
|
33
|
+
):
|
34
|
+
"""Initialize a Jobs instance.
|
35
|
+
|
36
|
+
:param survey: the survey to be used in the job
|
37
|
+
:param agents: a list of agents
|
38
|
+
:param models: a list of models
|
39
|
+
:param scenarios: a list of scenarios
|
40
|
+
"""
|
41
|
+
self.survey = survey
|
42
|
+
self.agents: "AgentList" = agents
|
43
|
+
self.scenarios: "ScenarioList" = scenarios
|
44
|
+
self.models = models
|
45
|
+
|
46
|
+
self.__bucket_collection = None
|
47
|
+
|
48
|
+
# these setters and getters are used to ensure that the agents, models, and scenarios are stored as AgentList, ModelList, and ScenarioList objects
|
49
|
+
|
50
|
+
@property
|
51
|
+
def models(self):
|
52
|
+
return self._models
|
53
|
+
|
54
|
+
@models.setter
|
55
|
+
def models(self, value):
|
56
|
+
from edsl import ModelList
|
57
|
+
|
58
|
+
if value:
|
59
|
+
if not isinstance(value, ModelList):
|
60
|
+
self._models = ModelList(value)
|
61
|
+
else:
|
62
|
+
self._models = value
|
63
|
+
else:
|
64
|
+
self._models = ModelList([])
|
65
|
+
|
66
|
+
@property
|
67
|
+
def agents(self):
|
68
|
+
return self._agents
|
69
|
+
|
70
|
+
@agents.setter
|
71
|
+
def agents(self, value):
|
72
|
+
from edsl import AgentList
|
73
|
+
|
74
|
+
if value:
|
75
|
+
if not isinstance(value, AgentList):
|
76
|
+
self._agents = AgentList(value)
|
77
|
+
else:
|
78
|
+
self._agents = value
|
79
|
+
else:
|
80
|
+
self._agents = AgentList([])
|
81
|
+
|
82
|
+
@property
|
83
|
+
def scenarios(self):
|
84
|
+
return self._scenarios
|
85
|
+
|
86
|
+
@scenarios.setter
|
87
|
+
def scenarios(self, value):
|
88
|
+
from edsl import ScenarioList
|
89
|
+
|
90
|
+
if value:
|
91
|
+
if not isinstance(value, ScenarioList):
|
92
|
+
self._scenarios = ScenarioList(value)
|
93
|
+
else:
|
94
|
+
self._scenarios = value
|
95
|
+
else:
|
96
|
+
self._scenarios = ScenarioList([])
|
97
|
+
|
98
|
+
def by(
|
99
|
+
self,
|
100
|
+
*args: Union[
|
101
|
+
"Agent",
|
102
|
+
"Scenario",
|
103
|
+
"LanguageModel",
|
104
|
+
Sequence[Union["Agent", "Scenario", "LanguageModel"]],
|
105
|
+
],
|
106
|
+
) -> Jobs:
|
107
|
+
"""
|
108
|
+
Add Agents, Scenarios and LanguageModels to a job. If no objects of this type exist in the Jobs instance, it stores the new objects as a list in the corresponding attribute. Otherwise, it combines the new objects with existing objects using the object's `__add__` method.
|
109
|
+
|
110
|
+
This 'by' is intended to create a fluent interface.
|
111
|
+
|
112
|
+
>>> from edsl import Survey
|
113
|
+
>>> from edsl import QuestionFreeText
|
114
|
+
>>> q = QuestionFreeText(question_name="name", question_text="What is your name?")
|
115
|
+
>>> j = Jobs(survey = Survey(questions=[q]))
|
116
|
+
>>> j
|
117
|
+
Jobs(survey=Survey(...), agents=AgentList([]), models=ModelList([]), scenarios=ScenarioList([]))
|
118
|
+
>>> from edsl import Agent; a = Agent(traits = {"status": "Sad"})
|
119
|
+
>>> j.by(a).agents
|
120
|
+
AgentList([Agent(traits = {'status': 'Sad'})])
|
121
|
+
|
122
|
+
:param args: objects or a sequence (list, tuple, ...) of objects of the same type
|
123
|
+
|
124
|
+
Notes:
|
125
|
+
- all objects must implement the 'get_value', 'set_value', and `__add__` methods
|
126
|
+
- agents: traits of new agents are combined with traits of existing agents. New and existing agents should not have overlapping traits, and do not increase the # agents in the instance
|
127
|
+
- scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
|
128
|
+
- models: new models overwrite old models.
|
129
|
+
"""
|
130
|
+
passed_objects = self._turn_args_to_list(
|
131
|
+
args
|
132
|
+
) # objects can also be passed comma-separated
|
133
|
+
|
134
|
+
current_objects, objects_key = self._get_current_objects_of_this_type(
|
135
|
+
passed_objects[0]
|
136
|
+
)
|
137
|
+
|
138
|
+
if not current_objects:
|
139
|
+
new_objects = passed_objects
|
140
|
+
else:
|
141
|
+
new_objects = self._merge_objects(passed_objects, current_objects)
|
142
|
+
|
143
|
+
setattr(self, objects_key, new_objects) # update the job
|
144
|
+
return self
|
145
|
+
|
146
|
+
def prompts(self) -> "Dataset":
|
147
|
+
"""Return a Dataset of prompts that will be used.
|
148
|
+
|
149
|
+
|
150
|
+
>>> from edsl.jobs import Jobs
|
151
|
+
>>> Jobs.example().prompts()
|
152
|
+
Dataset(...)
|
153
|
+
"""
|
154
|
+
from edsl import Coop
|
155
|
+
|
156
|
+
c = Coop()
|
157
|
+
price_lookup = c.fetch_prices()
|
158
|
+
|
159
|
+
interviews = self.interviews()
|
160
|
+
# data = []
|
161
|
+
interview_indices = []
|
162
|
+
question_names = []
|
163
|
+
user_prompts = []
|
164
|
+
system_prompts = []
|
165
|
+
scenario_indices = []
|
166
|
+
agent_indices = []
|
167
|
+
models = []
|
168
|
+
costs = []
|
169
|
+
from edsl.results.Dataset import Dataset
|
170
|
+
|
171
|
+
for interview_index, interview in enumerate(interviews):
|
172
|
+
invigilators = [
|
173
|
+
interview._get_invigilator(question)
|
174
|
+
for question in self.survey.questions
|
175
|
+
]
|
176
|
+
for _, invigilator in enumerate(invigilators):
|
177
|
+
prompts = invigilator.get_prompts()
|
178
|
+
user_prompt = prompts["user_prompt"]
|
179
|
+
system_prompt = prompts["system_prompt"]
|
180
|
+
user_prompts.append(user_prompt)
|
181
|
+
system_prompts.append(system_prompt)
|
182
|
+
agent_index = self.agents.index(invigilator.agent)
|
183
|
+
agent_indices.append(agent_index)
|
184
|
+
interview_indices.append(interview_index)
|
185
|
+
scenario_index = self.scenarios.index(invigilator.scenario)
|
186
|
+
scenario_indices.append(scenario_index)
|
187
|
+
models.append(invigilator.model.model)
|
188
|
+
question_names.append(invigilator.question.question_name)
|
189
|
+
|
190
|
+
prompt_cost = self.estimate_prompt_cost(
|
191
|
+
system_prompt=system_prompt,
|
192
|
+
user_prompt=user_prompt,
|
193
|
+
price_lookup=price_lookup,
|
194
|
+
inference_service=invigilator.model._inference_service_,
|
195
|
+
model=invigilator.model.model,
|
196
|
+
)
|
197
|
+
costs.append(prompt_cost["cost_usd"])
|
198
|
+
|
199
|
+
d = Dataset(
|
200
|
+
[
|
201
|
+
{"user_prompt": user_prompts},
|
202
|
+
{"system_prompt": system_prompts},
|
203
|
+
{"interview_index": interview_indices},
|
204
|
+
{"question_name": question_names},
|
205
|
+
{"scenario_index": scenario_indices},
|
206
|
+
{"agent_index": agent_indices},
|
207
|
+
{"model": models},
|
208
|
+
{"estimated_cost": costs},
|
209
|
+
]
|
210
|
+
)
|
211
|
+
return d
|
212
|
+
|
213
|
+
def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
    """Print the job's prompts as a rich table.

    :param all: When True, print every column returned by ``prompts()``;
        otherwise print only the ``user_prompt`` and ``system_prompt`` columns.
    :param max_rows: Forwarded to ``print`` to cap the number of rows shown.
    """
    table = self.prompts()
    if not all:
        table = table.select("user_prompt", "system_prompt")
    table.to_scenario_list().print(format="rich", max_rows=max_rows)
|
221
|
+
|
222
|
+
@staticmethod
def estimate_prompt_cost(
    system_prompt: str,
    user_prompt: str,
    price_lookup: dict,
    inference_service: str,
    model: str,
) -> dict:
    """Estimate token counts and USD cost for one system/user prompt pair.

    Heuristics used:

    - roughly 4 characters per token (integer division);
    - a prompt containing both ``{{`` and ``}}`` (piping) counts its
      characters twice;
    - output tokens are 75% of input tokens, rounded up.

    :param price_lookup: Mapping keyed by ``(inference_service, model)``.
        Each value has ``"input"`` and ``"output"`` entries holding
        ``"service_stated_token_price"`` and ``"service_stated_token_qty"``.
        If any of these keys is missing, a warning is emitted and default
        prices ($0.15 / 1M input tokens, $0.60 / 1M output tokens) are used.
    :return: dict with ``input_tokens``, ``output_tokens`` and ``cost_usd``.
    """
    import math

    def per_token(side: dict) -> float:
        # Prices are quoted for `service_stated_token_qty` tokens at once.
        return float(side["service_stated_token_price"]) / float(
            side["service_stated_token_qty"]
        )

    try:
        prices = price_lookup[(inference_service, model)]
        input_price_per_token = per_token(prices["input"])
        output_price_per_token = per_token(prices["output"])
    except KeyError:
        # Typically raised when no price data could be fetched (empty lookup).
        import warnings

        warnings.warn(
            "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
        )
        input_price_per_token = 0.00000015  # $0.15 / 1M tokens
        output_price_per_token = 0.00000060  # $0.60 / 1M tokens

    def weighted_chars(prompt) -> int:
        text = str(prompt)
        is_piped = "{{" in text and "}}" in text
        return len(text) * (2 if is_piped else 1)

    input_tokens = (weighted_chars(user_prompt) + weighted_chars(system_prompt)) // 4
    output_tokens = math.ceil(0.75 * input_tokens)
    cost = (
        input_tokens * input_price_per_token
        + output_tokens * output_price_per_token
    )

    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost_usd": cost,
    }
|
299
|
+
|
300
|
+
def estimate_job_cost_from_external_prices(
    self, price_lookup: dict, iterations: int = 1
) -> dict:
    """
    Estimates the cost of a job according to the following assumptions:

    - 1 token = 4 characters.
    - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.

    price_lookup is an external pricing dictionary.

    :param price_lookup: Pricing dict passed through to ``estimate_prompt_cost``.
    :param iterations: Number of times the job will be run; every token and
        cost figure is multiplied by it.
    :return: dict with ``estimated_total_cost_usd``,
        ``estimated_total_input_tokens``, ``estimated_total_output_tokens`` and
        ``model_costs`` (one record per ``(inference_service, model)`` pair).
        A job that produces no prompts (e.g. a survey with no questions)
        returns zero totals and an empty ``model_costs`` list.
    """

    import pandas as pd

    data = []
    for interview in self.interviews():
        invigilators = [
            interview._get_invigilator(question)
            for question in self.survey.questions
        ]
        for invigilator in invigilators:
            prompts = invigilator.get_prompts()

            # By this point, agent and scenario data has already been added to the prompts
            user_prompt = prompts["user_prompt"]
            system_prompt = prompts["system_prompt"]
            inference_service = invigilator.model._inference_service_
            model = invigilator.model.model

            prompt_cost = self.estimate_prompt_cost(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                price_lookup=price_lookup,
                inference_service=inference_service,
                model=model,
            )

            data.append(
                {
                    "user_prompt": user_prompt,
                    "system_prompt": system_prompt,
                    "estimated_input_tokens": prompt_cost["input_tokens"],
                    "estimated_output_tokens": prompt_cost["output_tokens"],
                    "estimated_cost_usd": prompt_cost["cost_usd"],
                    "inference_service": inference_service,
                    "model": model,
                }
            )

    if not data:
        # A DataFrame built from zero records has no columns, so the groupby
        # below would raise KeyError. Report an empty (zero-cost) estimate.
        return {
            "estimated_total_cost_usd": 0,
            "estimated_total_input_tokens": 0,
            "estimated_total_output_tokens": 0,
            "model_costs": [],
        }

    df = pd.DataFrame.from_records(data)

    df = (
        df.groupby(["inference_service", "model"])
        .agg(
            {
                "estimated_cost_usd": "sum",
                "estimated_input_tokens": "sum",
                "estimated_output_tokens": "sum",
            }
        )
        .reset_index()
    )
    df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
    df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
    df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations

    estimated_costs_by_model = df.to_dict("records")

    estimated_total_cost = sum(
        model["estimated_cost_usd"] for model in estimated_costs_by_model
    )
    estimated_total_input_tokens = sum(
        model["estimated_input_tokens"] for model in estimated_costs_by_model
    )
    estimated_total_output_tokens = sum(
        model["estimated_output_tokens"] for model in estimated_costs_by_model
    )

    output = {
        "estimated_total_cost_usd": estimated_total_cost,
        "estimated_total_input_tokens": estimated_total_input_tokens,
        "estimated_total_output_tokens": estimated_total_output_tokens,
        "model_costs": estimated_costs_by_model,
    }

    return output
|
387
|
+
|
388
|
+
def estimate_job_cost(self, iterations: int = 1) -> dict:
    """
    Estimates the cost of a job according to the following assumptions:

    - 1 token = 4 characters.
    - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.

    Current prices are fetched from Coop and passed to
    ``estimate_job_cost_from_external_prices``, whose result is returned.

    :param iterations: Number of times the job will be run.
    """
    from edsl import Coop

    prices_from_coop = Coop().fetch_prices()
    return self.estimate_job_cost_from_external_prices(
        price_lookup=prices_from_coop,
        iterations=iterations,
    )
|
405
|
+
|
406
|
+
@staticmethod
def compute_job_cost(job_results: "Results") -> float:
    """
    Computes the cost of a completed job in USD.

    Sums every ``<question_name>_cost`` entry of each result's
    ``raw_model_response`` that is numeric (int/float). Answers served from
    the cache (per ``cache_used_dict[question_name]``) are skipped.
    """
    total = 0
    for result in job_results:
        responses = result.raw_model_response
        for key in responses:
            if not key.endswith("_cost"):
                continue
            cost_value = responses[key]
            served_from_cache = result.cache_used_dict[key.removesuffix("_cost")]
            if isinstance(cost_value, (int, float)) and not served_from_cache:
                total += cost_value

    return total
|
424
|
+
|
425
|
+
@staticmethod
def _get_container_class(object):
    """Return the collection class used to wrap objects like *object*.

    Agents map to ``AgentList`` and scenarios to ``ScenarioList``. A
    ``ModelList`` instance maps to ``ModelList``. Anything else falls back to
    a plain ``list``.
    """
    from edsl.agents.AgentList import AgentList
    from edsl.agents.Agent import Agent
    from edsl.scenarios.Scenario import Scenario
    from edsl.scenarios.ScenarioList import ScenarioList
    from edsl.language_models.ModelList import ModelList

    # NOTE(review): callers pass a single element here (e.g. args[0][0]), so
    # the ModelList check only matches when the element is itself a
    # ModelList. A LanguageModel check may have been intended; behavior kept.
    dispatch = (
        (Agent, AgentList),
        (Scenario, ScenarioList),
        (ModelList, ModelList),
    )
    for element_type, container in dispatch:
        if isinstance(object, element_type):
            return container
    return list
|
441
|
+
|
442
|
+
@staticmethod
def _turn_args_to_list(args):
    """Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments.

    The container type (e.g. ``AgentList``/``ScenarioList``/``list``) is
    chosen from the first element via ``Jobs._get_container_class``.

    Example:

    >>> Jobs._turn_args_to_list([1,2,3])
    [1, 2, 3]

    """
    # A lone sequence argument (e.g. by([a, b])) is unwrapped; otherwise the
    # positional args themselves are the items (e.g. by(a, b)).
    single_sequence_passed = len(args) == 1 and isinstance(args[0], Sequence)
    items = args[0] if single_sequence_passed else args
    container_class = Jobs._get_container_class(items[0])
    return container_class(items)
|
472
|
+
|
473
|
+
def _get_current_objects_of_this_type(
    self, object: Union["Agent", "Scenario", "LanguageModel"]
) -> tuple[list, str]:
    """Return the current objects of the same type as the first argument.

    Returns a ``(current_objects, key)`` pair, where ``key`` is one of
    ``"agents"``, ``"scenarios"`` or ``"models"`` and ``current_objects`` is
    ``getattr(self, key, None)``.

    :raises ValueError: If *object* is not an Agent, Scenario or LanguageModel.

    >>> from edsl.jobs import Jobs
    >>> j = Jobs.example()
    >>> j._get_current_objects_of_this_type(j.agents[0])
    (AgentList([Agent(traits = {'status': 'Joyful'}), Agent(traits = {'status': 'Sad'})]), 'agents')
    """
    # The docstring must be the first statement of the function; placed after
    # the imports it was a no-op string and its doctest never ran.
    from edsl.agents.Agent import Agent
    from edsl.scenarios.Scenario import Scenario
    from edsl.language_models.LanguageModel import LanguageModel

    class_to_key = {
        Agent: "agents",
        Scenario: "scenarios",
        LanguageModel: "models",
    }
    for class_type in class_to_key:
        if isinstance(object, class_type) or issubclass(
            object.__class__, class_type
        ):
            key = class_to_key[class_type]
            break
    else:
        raise ValueError(
            f"First argument must be an Agent, Scenario, or LanguageModel, not {object}"
        )
    current_objects = getattr(self, key, None)
    return current_objects, key
|
504
|
+
|
505
|
+
@staticmethod
def _get_empty_container_object(object):
    """Return an empty collection suited to hold objects like *object*.

    Returns an empty ``AgentList`` for an Agent, an empty ``ScenarioList``
    for a Scenario, and a plain empty list otherwise.
    """
    from edsl import AgentList, Agent, Scenario, ScenarioList

    if isinstance(object, Agent):
        return AgentList([])
    if isinstance(object, Scenario):
        return ScenarioList([])
    return []
|
518
|
+
|
519
|
+
@staticmethod
def _merge_objects(passed_objects, current_objects) -> list:
    """
    Combine all the existing objects with the new objects.

    For example, if the user passes in 3 agents,
    and there are 2 existing agents, this will create 6 new agents

    Each result is ``existing + new``. Existing objects form the outer loop,
    so results are grouped by existing object. The container type is chosen
    from ``passed_objects[0]``.

    >>> Jobs(survey = [])._merge_objects([1,2,3], [4,5,6])
    [5, 6, 7, 6, 7, 8, 7, 8, 9]
    """
    merged = Jobs._get_empty_container_object(passed_objects[0])
    for existing in current_objects:
        for incoming in passed_objects:
            merged.append(existing + incoming)
    return merged
|
535
|
+
|
536
|
+
def interviews(self) -> list[Interview]:
    """
    Return a list of :class:`edsl.jobs.interviews.Interview` objects.

    It returns one Interview for each combination of Agent, Scenario, and LanguageModel.
    If any of Agents, Scenarios, or LanguageModels are missing, it fills in with defaults.
    Interviews previously stored on ``self._interviews`` (see
    ``from_interviews``) are returned as-is instead.

    >>> from edsl.jobs import Jobs
    >>> j = Jobs.example()
    >>> len(j.interviews())
    4
    >>> j.interviews()[0]
    Interview(agent = Agent(traits = {'status': 'Joyful'}), survey = Survey(...), scenario = Scenario({'period': 'morning'}), model = Model(...))
    """
    if not hasattr(self, "_interviews"):
        return list(self._create_interviews())
    return self._interviews
|
554
|
+
|
555
|
+
@classmethod
def from_interviews(cls, interview_list):
    """Return a Jobs instance from a list of interviews.

    This is useful when you have, say, a list of failed interviews and you want to create
    a new job with only those interviews.

    The survey comes from the first interview. ``models`` holds the distinct
    models used by the interviews, in first-seen order. The interviews
    themselves are stored on ``_interviews`` and reused as-is.

    :param interview_list: Non-empty list of interviews.
    :raises IndexError: If ``interview_list`` is empty.
    """
    survey = interview_list[0].survey
    # Deduplicate while preserving order. A set would discard the interviews'
    # ordering, leaving the model order arbitrary.
    models = list(dict.fromkeys(interview.model for interview in interview_list))
    jobs = cls(survey)
    jobs.models = models
    jobs._interviews = interview_list
    return jobs
|
569
|
+
|
570
|
+
def _create_interviews(self) -> Generator[Interview, None, None]:
    """
    Yield one Interview per (agent, scenario, model) combination.

    Side effect: when iteration starts, any falsy ``agents``, ``models`` or
    ``scenarios`` attribute is replaced on ``self`` with a one-element list
    holding a default ``Agent()``, ``Model()`` or ``Scenario()``. A job can
    therefore run without the user setting them explicitly.
    """
    from edsl.agents.Agent import Agent
    from edsl.language_models.registry import Model
    from edsl.scenarios.Scenario import Scenario

    # Fill in defaults (same order as agents -> models -> scenarios).
    for attribute, default_factory in (
        ("agents", Agent),
        ("models", Model),
        ("scenarios", Scenario),
    ):
        setattr(self, attribute, getattr(self, attribute) or [default_factory()])

    combinations = product(self.agents, self.scenarios, self.models)
    for current_agent, current_scenario, current_model in combinations:
        yield Interview(
            survey=self.survey,
            agent=current_agent,
            scenario=current_scenario,
            model=current_model,
            skip_retry=self.skip_retry,
            raise_validation_errors=self.raise_validation_errors,
        )
|
597
|
+
|
598
|
+
def create_bucket_collection(self) -> BucketCollection:
    """
    Create a collection of buckets for each model.

    These buckets are used to track API calls and token usage. Each model in
    ``self.models`` is registered via ``BucketCollection.add_model``.

    >>> from edsl.jobs import Jobs
    >>> from edsl import Model
    >>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
    >>> bc = j.create_bucket_collection()
    >>> bc
    BucketCollection(...)
    """
    collection = BucketCollection()
    for language_model in self.models:
        collection.add_model(language_model)
    return collection
|
615
|
+
|
616
|
+
@property
def bucket_collection(self) -> BucketCollection:
    """Return the job's bucket collection, creating and caching it on first access."""
    cached = self.__bucket_collection
    if cached is None:
        cached = self.create_bucket_collection()
        self.__bucket_collection = cached
    return cached
|
622
|
+
|
623
|
+
def html(self):
    """Return the HTML representations for each scenario.

    Calls ``self.survey.html`` once per scenario with ``return_link=True``
    and a ``"Scenario <index>"`` call-to-action label, collecting the results.
    """
    return [
        self.survey.html(
            scenario=current_scenario,
            return_link=True,
            cta=f"Scenario {position}",
        )
        for position, current_scenario in enumerate(self.scenarios)
    ]
|
633
|
+
|
634
|
+
def __hash__(self):
    """Allow the model to be used as a key in a dictionary.

    The hash is ``dict_hash`` of the job's ``_to_dict()`` serialization.

    >>> from edsl.jobs import Jobs
    >>> hash(Jobs.example())
    846655441787442972

    """
    from edsl.utilities.utilities import dict_hash

    serialized = self._to_dict()
    return dict_hash(serialized)
|
645
|
+
|
646
|
+
def _output(self, message) -> None:
|
647
|
+
"""Check if a Job is verbose. If so, print the message."""
|
648
|
+
if hasattr(self, "verbose") and self.verbose:
|
649
|
+
print(message)
|
650
|
+
|
651
|
+
def _check_parameters(self, strict=False, warn=False) -> None:
|
652
|
+
"""Check if the parameters in the survey and scenarios are consistent.
|
653
|
+
|
654
|
+
>>> from edsl import QuestionFreeText
|
655
|
+
>>> from edsl import Survey
|
656
|
+
>>> from edsl import Scenario
|
657
|
+
>>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
|
658
|
+
>>> j = Jobs(survey = Survey(questions=[q]))
|
659
|
+
>>> with warnings.catch_warnings(record=True) as w:
|
660
|
+
... j._check_parameters(warn = True)
|
661
|
+
... assert len(w) == 1
|
662
|
+
... assert issubclass(w[-1].category, UserWarning)
|
663
|
+
... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
|
664
|
+
|
665
|
+
>>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
|
666
|
+
>>> s = Scenario({'plop': "A", 'poo': "B"})
|
667
|
+
>>> j = Jobs(survey = Survey(questions=[q])).by(s)
|
668
|
+
>>> j._check_parameters(strict = True)
|
669
|
+
Traceback (most recent call last):
|
670
|
+
...
|
671
|
+
ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
|
672
|
+
|
673
|
+
>>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
|
674
|
+
>>> s = Scenario({'ugly_question': "B"})
|
675
|
+
>>> j = Jobs(survey = Survey(questions=[q])).by(s)
|
676
|
+
>>> j._check_parameters()
|
677
|
+
Traceback (most recent call last):
|
678
|
+
...
|
679
|
+
ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
|
680
|
+
"""
|
681
|
+
survey_parameters: set = self.survey.parameters
|
682
|
+
scenario_parameters: set = self.scenarios.parameters
|
683
|
+
|
684
|
+
msg0, msg1, msg2 = None, None, None
|
685
|
+
|
686
|
+
# look for key issues
|
687
|
+
if intersection := set(self.scenarios.parameters) & set(
|
688
|
+
self.survey.question_names
|
689
|
+
):
|
690
|
+
msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
|
691
|
+
|
692
|
+
raise ValueError(msg0)
|
693
|
+
|
694
|
+
if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
|
695
|
+
msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
|
696
|
+
if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
|
697
|
+
msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"
|
698
|
+
|
699
|
+
if msg1 or msg2:
|
700
|
+
message = "\n".join(filter(None, [msg1, msg2]))
|
701
|
+
if strict:
|
702
|
+
raise ValueError(message)
|
703
|
+
else:
|
704
|
+
if warn:
|
705
|
+
warnings.warn(message)
|
706
|
+
|
707
|
+
if self.scenarios.has_jinja_braces:
|
708
|
+
warnings.warn(
|
709
|
+
"The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
|
710
|
+
)
|
711
|
+
self.scenarios = self.scenarios.convert_jinja_braces()
|
712
|
+
|
713
|
+
@property
def skip_retry(self):
    """The ``skip_retry`` setting recorded by ``run()``; False if never set."""
    return getattr(self, "_skip_retry", False)
|
718
|
+
|
719
|
+
@property
def raise_validation_errors(self):
    """The ``raise_validation_errors`` setting recorded by ``run()``; False if never set."""
    return getattr(self, "_raise_validation_errors", False)
|
724
|
+
|
725
|
+
def create_remote_inference_job(
    self,
    iterations: int = 1,
    remote_inference_description: Optional[str] = None,
    remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
    verbose=False,
):
    """Submit this job to Coop for remote inference with status ``"queued"``.

    :param iterations: Number of times to run each interview remotely.
    :param remote_inference_description: Optional description for the job.
    :param remote_inference_results_visibility: Initial visibility of the
        Results object on Coop.
    :param verbose: Print progress messages. Messages are also printed when
        the job's ``verbose`` attribute (set by ``run``) is truthy.
    :return: The dict returned by ``Coop.remote_inference_create``. Its
        ``"uuid"`` key identifies the remote job.
    """
    from edsl.coop.coop import Coop

    # `self.verbose` only exists once run() has set it. Use getattr so this
    # public method also works when called directly, and honor the argument.
    show_output = verbose or getattr(self, "verbose", False)

    coop = Coop()
    if show_output:
        print("Remote inference activated. Sending job to server...")
    remote_job_creation_data = coop.remote_inference_create(
        self,
        description=remote_inference_description,
        status="queued",
        iterations=iterations,
        initial_results_visibility=remote_inference_results_visibility,
    )
    job_uuid = remote_job_creation_data.get("uuid")
    if show_output:
        print(f"Job sent to server. (Job uuid={job_uuid}).")
    return remote_job_creation_data
|
748
|
+
|
749
|
+
@staticmethod
def check_status(job_uuid):
    """Return Coop's current record for the remote inference job *job_uuid*."""
    from edsl.coop.coop import Coop

    return Coop().remote_inference_get(job_uuid)
|
755
|
+
|
756
|
+
def poll_remote_inference_job(
    self, remote_job_creation_data: dict, verbose=False, poll_interval=5
) -> Union[Results, None]:
    """Poll Coop until a remote inference job finishes.

    :param remote_job_creation_data: Dict returned by
        ``create_remote_inference_job``; only its ``"uuid"`` key is read.
    :param verbose: Print status messages and a spinner. Output is also shown
        when the job's ``verbose`` attribute (set by ``run``) is truthy.
    :param poll_interval: Seconds to wait between status requests.
    :return: The Results fetched from Coop once the status is
        ``"completed"``, or None if the job was ``"cancelled"`` or ``"failed"``.
    """
    from edsl.coop.coop import Coop
    import time
    from datetime import datetime
    from edsl.config import CONFIG

    expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")

    job_uuid = remote_job_creation_data.get("uuid")

    # `self.verbose` only exists once run() has set it; fall back gracefully
    # and honor the explicit argument as well.
    show_output = verbose or getattr(self, "verbose", False)
    clear_line = "\r" + " " * 80 + "\r"
    frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

    coop = Coop()
    while True:
        remote_job_data = coop.remote_inference_get(job_uuid)
        status = remote_job_data.get("status")
        if status == "cancelled":
            if show_output:
                print(clear_line, end="")
                print("Job cancelled by the user.")
                print(
                    f"See {expected_parrot_url}/home/remote-inference for more details."
                )
            return None
        elif status == "failed":
            if show_output:
                print(clear_line, end="")
                print("Job failed.")
                print(
                    f"See {expected_parrot_url}/home/remote-inference for more details."
                )
            return None
        elif status == "completed":
            results_uuid = remote_job_data.get("results_uuid")
            results = coop.get(results_uuid, expected_object_type="results")
            if show_output:
                print(clear_line, end="")
                url = f"{expected_parrot_url}/content/{results_uuid}"
                print(f"Job completed and Results stored on Coop: {url}.")
            return results

        # Any other status: wait `poll_interval` seconds, animating a spinner
        # when output is enabled, then poll again.
        time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
        start_time = time.time()
        i = 0
        while time.time() - start_time < poll_interval:
            if show_output:
                print(
                    f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
                    end="",
                    flush=True,
                )
            time.sleep(0.1)
            i += 1
|
812
|
+
|
813
|
+
def use_remote_inference(self, disable_remote_inference: bool) -> bool:
    """Decide whether this job should run remotely.

    Returns False when *disable_remote_inference* is set. Otherwise returns
    the user's ``remote_inference`` setting from Coop (default False). A
    connection error or Coop server error is treated as False.
    """
    if disable_remote_inference:
        return False
    try:
        from edsl import Coop

        settings = Coop().edsl_settings
        return settings.get("remote_inference", False)
    except (requests.ConnectionError, CoopServerResponseError):
        # Unreachable/erroring server: fall back to local inference.
        return False
|
828
|
+
|
829
|
+
def use_remote_cache(self, disable_remote_cache: bool) -> bool:
    """Decide whether this job should sync with the remote cache.

    Returns False when *disable_remote_cache* is set. Otherwise returns the
    user's ``remote_caching`` setting from Coop (default False). A
    connection error or Coop server error is treated as False.
    """
    if disable_remote_cache:
        return False
    try:
        from edsl import Coop

        settings = Coop().edsl_settings
        return settings.get("remote_caching", False)
    except (requests.ConnectionError, CoopServerResponseError):
        # Unreachable/erroring server: run without the remote cache.
        return False
|
844
|
+
|
845
|
+
def check_api_keys(self) -> None:
    """Raise MissingAPIKeyError for the first model lacking a valid API key.

    Checks every model in ``self.models`` plus a default ``Model()``.
    """
    from edsl import Model

    models_to_check = self.models + [Model()]
    for candidate in models_to_check:
        if candidate.has_valid_api_key():
            continue
        raise MissingAPIKeyError(
            model_name=str(candidate.model),
            inference_service=candidate._inference_service_,
        )
|
854
|
+
|
855
|
+
def get_missing_api_keys(self) -> set:
    """
    Return the set of API key names the user needs to run this job but lacks.

    Checks every model in ``self.models`` plus a default ``Model()``. For
    each one without a valid key, the key name is looked up in
    ``service_to_api_keyname`` by the model's inference service. Unknown
    services map to the placeholder ``"NOT FOUND"``.
    """
    from edsl import Model
    from edsl.enums import service_to_api_keyname

    return {
        service_to_api_keyname.get(model._inference_service_, "NOT FOUND")
        for model in self.models + [Model()]
        if not model.has_valid_api_key()
    }
|
873
|
+
|
874
|
+
def user_has_all_model_keys(self):
    """
    Returns True if the user has all model keys required to run their job.

    Otherwise (``check_api_keys`` raised MissingAPIKeyError), returns False.
    Any other exception propagates.
    """
    try:
        self.check_api_keys()
    except MissingAPIKeyError:
        return False
    return True
|
888
|
+
|
889
|
+
def user_has_ep_api_key(self) -> bool:
    """
    Returns True if EXPECTED_PARROT_API_KEY is set in the environment (even
    to an empty string); otherwise returns False.
    """
    import os

    return os.getenv("EXPECTED_PARROT_API_KEY") is not None
|
904
|
+
|
905
|
+
def needs_external_llms(self) -> bool:
    """
    Report whether running this job would call an external LLM.

    Returns False if any of the following holds, so that doctests can skip
    the API-key check:

    * there is at least one agent and every agent defines
      ``answer_question_directly`` (covers ``Results.example()``);
    * the only model name in use is ``"test"`` (covers
      ``InterviewExceptionEntry.example()``);
    * every question has ``question_type == "functional"`` (covers
      ``Survey.__call__``).

    Otherwise returns True.
    """
    agent_pool = self.agents
    agents_self_answer = len(agent_pool) > 0 and all(
        [hasattr(agent, "answer_question_directly") for agent in agent_pool]
    )
    test_model_only = {m.model for m in self.models} == {"test"}
    functional_only = {
        question.question_type for question in self.survey.questions
    } == {"functional"}

    return not (agents_self_answer or test_model_only or functional_only)
|
934
|
+
|
935
|
+
def run(
    self,
    n: int = 1,
    progress_bar: bool = False,
    stop_on_exception: bool = False,
    cache: Union[Cache, bool] = None,
    check_api_keys: bool = False,
    sidecar_model: Optional[LanguageModel] = None,
    verbose: bool = False,
    print_exceptions=True,
    remote_cache_description: Optional[str] = None,
    remote_inference_description: Optional[str] = None,
    remote_inference_results_visibility: Optional[
        Literal["private", "public", "unlisted"]
    ] = "unlisted",
    skip_retry: bool = False,
    raise_validation_errors: bool = False,
    disable_remote_cache: bool = False,
    disable_remote_inference: bool = False,
) -> Results:
    """
    Runs the Job: conducts Interviews and returns their results.

    :param n: How many times to run each interview
    :param progress_bar: Whether to show a progress bar
    :param stop_on_exception: Stops the job if an exception is raised
    :param cache: A Cache object to store results. None or True uses the
        cache from ``CacheHandler().get_cache()``; False uses a fresh ``Cache()``.
    :param check_api_keys: Raises an error if API keys are invalid
    :param sidecar_model: Passed through to ``_run_local``.
    :param verbose: Prints extra messages
    :param print_exceptions: Passed through to ``_run_local``.
    :param remote_cache_description: Specifies a description for this group of entries in the remote cache
    :param remote_inference_description: Specifies a description for the remote inference job
    :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
    :param skip_retry: Stored on the job and passed to each Interview.
    :param raise_validation_errors: Stored on the job and passed to each
        Interview and to ``_run_local``.
    :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
    :param disable_remote_inference: If True, the job will not use remote inference
    :return: The Results. Returns None if the interactive login times out or
        if a remote job is cancelled or fails.
    """
    from edsl.coop.coop import Coop

    self._check_parameters()
    # Recorded on the job so the skip_retry / raise_validation_errors
    # properties (read when interviews are created) reflect this run.
    self._skip_retry = skip_retry
    self._raise_validation_errors = raise_validation_errors

    self.verbose = verbose

    # Interactive fallback: some model keys are missing, there is no Expected
    # Parrot key, and the job really needs an external LLM. Offer a login
    # that retrieves an Expected Parrot API key so the job can run remotely.
    if (
        not self.user_has_all_model_keys()
        and not self.user_has_ep_api_key()
        and self.needs_external_llms()
    ):
        import secrets
        from dotenv import load_dotenv
        from edsl import CONFIG
        from edsl.coop.coop import Coop
        from edsl.utilities.utilities import write_api_key_to_env

        missing_api_keys = self.get_missing_api_keys()

        edsl_auth_token = secrets.token_urlsafe(16)

        print("You're missing some of the API keys needed to run this job:")
        for api_key in missing_api_keys:
            print(f" 🔑 {api_key}")
        print(
            "\nYou can either add the missing keys to your .env file, or use remote inference."
        )
        print("Remote inference allows you to run jobs on our server.")
        print("\n🚀 To use remote inference, sign up at the following link:")

        coop = Coop()
        coop._display_login_url(edsl_auth_token=edsl_auth_token)

        print(
            "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
        )

        api_key = coop._poll_for_api_key(edsl_auth_token)

        if api_key is None:
            print("\nTimed out waiting for login. Please try again.")
            # NOTE(review): bare return yields None despite the `-> Results`
            # annotation; callers must handle None.
            return

        write_api_key_to_env(api_key)
        print("✨ API key retrieved and written to .env file.\n")

        # Retrieve API key so we can continue running the job
        load_dotenv()

    # Remote path: submit the job to Coop and block until it finishes.
    # (`remote_inference` is bound by the walrus but not used afterwards.)
    if remote_inference := self.use_remote_inference(disable_remote_inference):
        remote_job_creation_data = self.create_remote_inference_job(
            iterations=n,
            remote_inference_description=remote_inference_description,
            remote_inference_results_visibility=remote_inference_results_visibility,
        )
        results = self.poll_remote_inference_job(remote_job_creation_data)
        if results is None:
            self._output("Job failed.")
        return results

    if check_api_keys:
        self.check_api_keys()

    # handle cache
    if cache is None or cache is True:
        from edsl.data.CacheHandler import CacheHandler

        cache = CacheHandler().get_cache()
    if cache is False:
        from edsl.data.Cache import Cache

        cache = Cache()

    remote_cache = self.use_remote_cache(disable_remote_cache)
    # NOTE(review): RemoteCacheSync presumably syncs `cache` with Coop on
    # enter/exit when `remote_cache` is True; `r` is unused. Confirm.
    with RemoteCacheSync(
        coop=Coop(),
        cache=cache,
        output_func=self._output,
        remote_cache=remote_cache,
        remote_cache_description=remote_cache_description,
    ) as r:
        results = self._run_local(
            n=n,
            progress_bar=progress_bar,
            cache=cache,
            stop_on_exception=stop_on_exception,
            sidecar_model=sidecar_model,
            print_exceptions=print_exceptions,
            raise_validation_errors=raise_validation_errors,
        )

    # Presumably only the entries added during this run, judging by the
    # method name; confirm against Cache.new_entries_cache.
    results.cache = cache.new_entries_cache()
    return results
|
1065
|
+
|
1066
|
+
async def create_and_poll_remote_job(
    self,
    iterations: int = 1,
    remote_inference_description: Optional[str] = None,
    remote_inference_results_visibility: Optional[
        Literal["private", "public", "unlisted"]
    ] = "unlisted",
) -> Optional["Results"]:
    """
    Create and poll a remote inference job without blocking the event loop.

    ``create_remote_inference_job`` and ``poll_remote_inference_job`` are
    ordinary synchronous methods; each is run in the running loop's default
    executor so other coroutines can make progress while they block.

    :param iterations: Number of times to run each interview
    :param remote_inference_description: Optional description for the remote job
    :param remote_inference_results_visibility: Visibility setting for results
    :return: Whatever ``poll_remote_inference_job`` returns for the created
        job (a Results object if successful, None if the job fails or is cancelled)
    """
    import asyncio
    from functools import partial

    # Inside a coroutine, get_running_loop() is the recommended way to obtain
    # the loop (get_event_loop() has policy-dependent behavior).
    loop = asyncio.get_running_loop()

    # Job creation is a blocking call; run it in a worker thread.
    remote_job_creation_data = await loop.run_in_executor(
        None,
        partial(
            self.create_remote_inference_job,
            iterations=iterations,
            remote_inference_description=remote_inference_description,
            remote_inference_results_visibility=remote_inference_results_visibility,
        ),
    )

    # Polling is also synchronous, so it too runs in a worker thread; the
    # coroutine simply awaits its completion.
    return await loop.run_in_executor(
        None, self.poll_remote_inference_job, remote_job_creation_data
    )
|
1102
|
+
|
1103
|
+
async def run_async(
    self,
    cache=None,
    n=1,
    disable_remote_inference: bool = False,
    remote_inference_description: Optional[str] = None,
    remote_inference_results_visibility: Optional[
        Literal["private", "public", "unlisted"]
    ] = "unlisted",
    **kwargs,
):
    """Run the job asynchronously, either locally or remotely.

    If ``self.use_remote_inference(disable_remote_inference)`` is truthy, the
    job is created and polled remotely via ``create_and_poll_remote_job``; a
    ``None`` result is reported through ``self._output("Job failed.")`` and
    returned unchanged. Otherwise the job runs locally through
    ``JobsRunnerAsyncio(self).run_async``.

    :param cache: Cache object or boolean, forwarded to the local runner
    :param n: Number of iterations
    :param disable_remote_inference: If True, forces local execution
    :param remote_inference_description: Description for remote jobs
    :param remote_inference_results_visibility: Visibility setting for remote results
    :param kwargs: Additional arguments passed to local execution
    :return: Results object, or None if a remote job failed
    """
    if self.use_remote_inference(disable_remote_inference):
        results = await self.create_and_poll_remote_job(
            iterations=n,
            remote_inference_description=remote_inference_description,
            remote_inference_results_visibility=remote_inference_results_visibility,
        )
        if results is None:
            self._output("Job failed.")
        return results

    # Not using remote inference: run locally on the current event loop.
    # NOTE(review): `cache` is forwarded as-is here, whereas the synchronous
    # run() first resolves None/True/False into a Cache object — confirm that
    # JobsRunnerAsyncio.run_async handles the raw values.
    return await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
|
1137
|
+
|
1138
|
+
def _run_local(self, *args, **kwargs):
    """Execute the job in-process and return what the runner produces.

    Every positional and keyword argument is handed unchanged to
    ``JobsRunnerAsyncio.run``.
    """
    runner = JobsRunnerAsyncio(self)
    return runner.run(*args, **kwargs)
|
1143
|
+
|
1144
|
+
def all_question_parameters(self):
    """Return all the fields in the questions in the survey.

    The result is the union of every question's ``parameters``; a survey with
    no questions yields an empty set.

    >>> from edsl.jobs import Jobs
    >>> Jobs.example().all_question_parameters()
    {'period'}
    """
    # set().union(...) accepts zero iterables, whereas set.union(*[]) raises
    # TypeError when the survey has no questions.
    return set().union(
        *(question.parameters for question in self.survey.questions)
    )
|
1151
|
+
|
1152
|
+
#######################
|
1153
|
+
# Dunder methods
|
1154
|
+
#######################
|
1155
|
+
def print(self):
    """Pretty-print the job's ``to_dict()`` output as JSON via rich's ``print_json``."""
    import json
    from rich import print_json

    serialized = json.dumps(self.to_dict())
    print_json(serialized)
|
1160
|
+
|
1161
|
+
def __repr__(self) -> str:
    """Build an eval-able ``Jobs(...)`` string from the survey, agents, models and scenarios."""
    fields = ", ".join(
        f"{name}={getattr(self, name)!r}"
        for name in ("survey", "agents", "models", "scenarios")
    )
    return f"Jobs({fields})"
|
1164
|
+
|
1165
|
+
def _repr_html_(self) -> str:
    """Return an HTML rendering of the job for IPython/Jupyter rich display.

    The job's ``to_dict()`` output is serialized as indented JSON,
    HTML-escaped, and wrapped in a ``<pre>`` element.
    """
    import html
    import json

    # IPython displays the *returned* string. Printing here and returning None
    # bypasses the rich-display mechanism and contradicts the ``-> str`` hint.
    body = json.dumps(self.to_dict(), indent=2)
    return f"<pre>{html.escape(body)}</pre>"
|
1170
|
+
|
1171
|
+
def __len__(self) -> int:
    """Return an upper bound on the number of questions this job will ask.

    The bound is agents x scenarios x models x survey questions, where an
    empty (falsy) agents/scenarios/models collection counts as one. Skip logic
    can make the real number of asked questions smaller.

    >>> from edsl.jobs import Jobs
    >>> len(Jobs.example())
    8
    """
    combinations = 1
    for component in (self.agents, self.scenarios, self.models):
        combinations *= len(component) if component else 1
    return combinations * len(self.survey)
|
1186
|
+
|
1187
|
+
#######################
|
1188
|
+
# Serialization methods
|
1189
|
+
#######################
|
1190
|
+
|
1191
|
+
def _to_dict(self):
    """Serialize the survey and each agent, model and scenario via their ``_to_dict``.

    Keys appear in the order survey, agents, models, scenarios.
    """
    serialized = {"survey": self.survey._to_dict()}
    for key in ("agents", "models", "scenarios"):
        serialized[key] = [member._to_dict() for member in getattr(self, key)]
    return serialized
|
1198
|
+
|
1199
|
+
@add_edsl_version
def to_dict(self) -> dict:
    """Serialize the Jobs instance to a dictionary.

    The body delegates to ``_to_dict``; the ``add_edsl_version`` decorator then
    post-processes the result (presumably stamping edsl version metadata —
    confirm against its definition).
    """
    payload = self._to_dict()
    return payload
|
1203
|
+
|
1204
|
+
@classmethod
@remove_edsl_version
def from_dict(cls, data: dict) -> Jobs:
    """Rebuild a Jobs instance from a dictionary with the keys written by ``_to_dict``.

    ``data`` must contain "survey", "agents", "models" and "scenarios"; each
    entry is rehydrated with the matching class's ``from_dict``. The
    ``remove_edsl_version`` decorator presumably strips version metadata before
    this body runs — confirm against its definition.
    """
    from edsl import Survey
    from edsl.agents.Agent import Agent
    from edsl.language_models.LanguageModel import LanguageModel
    from edsl.scenarios.Scenario import Scenario

    survey = Survey.from_dict(data["survey"])
    agents = [Agent.from_dict(entry) for entry in data["agents"]]
    models = [LanguageModel.from_dict(entry) for entry in data["models"]]
    scenarios = [Scenario.from_dict(entry) for entry in data["scenarios"]]
    return cls(survey=survey, agents=agents, models=models, scenarios=scenarios)
|
1219
|
+
|
1220
|
+
def __eq__(self, other: object) -> bool:
    """Return True if ``other`` is a Jobs instance with the same serialized form.

    Equality is decided by comparing ``to_dict()`` output. For objects that are
    not Jobs instances, ``NotImplemented`` is returned so Python falls back to
    its default handling (normally ``False``) instead of raising AttributeError
    on ``other.to_dict()``.

    >>> from edsl.jobs import Jobs
    >>> Jobs.example() == Jobs.example()
    True
    >>> Jobs.example() == "not a job"
    False

    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.to_dict() == other.to_dict()
|
1229
|
+
|
1230
|
+
#######################
|
1231
|
+
# Example methods
|
1232
|
+
#######################
|
1233
|
+
@classmethod
def example(
    cls,
    throw_exception_probability: float = 0.0,
    randomize: bool = False,
    test_model=False,
) -> Jobs:
    """Return an example Jobs instance.

    Two agents (status "Joyful" and "Sad") take a two-question multiple-choice
    survey under two scenarios (morning and afternoon). Each agent is given a
    direct question-answering method that looks answers up in a fixed table.

    :param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
    :param randomize: whether to randomize the job by adding a random string to the morning period
    :param test_model: whether to use a test model

    >>> Jobs.example()
    Jobs(...)

    """
    import random
    from uuid import uuid4
    from edsl.questions import QuestionMultipleChoice
    from edsl.agents.Agent import Agent
    from edsl.scenarios.Scenario import Scenario

    addition = "" if not randomize else str(uuid4())
    # The scenario's period and the answer-table keys must be the same string.
    # Building both from this value keeps the lookup working when a random
    # suffix is appended; otherwise every morning lookup raises KeyError.
    morning = f"morning{addition}"

    if test_model:
        from edsl.language_models import LanguageModel

        m = LanguageModel.example(test_model=True)

    # (status, question, period)
    agent_answers = {
        ("Joyful", "how_feeling", morning): "OK",
        ("Joyful", "how_feeling", "afternoon"): "Great",
        ("Joyful", "how_feeling_yesterday", morning): "Great",
        ("Joyful", "how_feeling_yesterday", "afternoon"): "Good",
        ("Sad", "how_feeling", morning): "Terrible",
        ("Sad", "how_feeling", "afternoon"): "OK",
        ("Sad", "how_feeling_yesterday", morning): "OK",
        ("Sad", "how_feeling_yesterday", "afternoon"): "Terrible",
    }

    def answer_question_directly(self, question, scenario):
        """Return the answer to a question. This is a method that can be added to an agent."""
        # Fail at random so callers can exercise error handling.
        if random.random() < throw_exception_probability:
            raise Exception("Error!")
        return agent_answers[
            (self.traits["status"], question.question_name, scenario["period"])
        ]

    sad_agent = Agent(traits={"status": "Sad"})
    joy_agent = Agent(traits={"status": "Joyful"})

    sad_agent.add_direct_question_answering_method(answer_question_directly)
    joy_agent.add_direct_question_answering_method(answer_question_directly)

    q1 = QuestionMultipleChoice(
        question_text="How are you this {{ period }}?",
        question_options=["Good", "Great", "OK", "Terrible"],
        question_name="how_feeling",
    )
    q2 = QuestionMultipleChoice(
        question_text="How were you feeling yesterday {{ period }}?",
        question_options=["Good", "Great", "OK", "Terrible"],
        question_name="how_feeling_yesterday",
    )
    from edsl import Survey, ScenarioList

    base_survey = Survey(questions=[q1, q2])

    scenario_list = ScenarioList(
        [
            Scenario({"period": morning}),
            Scenario({"period": "afternoon"}),
        ]
    )
    if test_model:
        job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
    else:
        job = base_survey.by(scenario_list).by(joy_agent, sad_agent)

    return job
|
1316
|
+
|
1317
|
+
def rich_print(self):
    """Build and return a one-column rich ``Table`` titled "Jobs".

    The table's single row holds the survey's own ``rich_print()`` output.
    Nothing is printed here; the caller decides how to render the table.
    """
    from rich.table import Table

    jobs_table = Table(title="Jobs")
    jobs_table.add_column("Jobs")
    survey_view = self.survey.rich_print()
    jobs_table.add_row(survey_view)
    return jobs_table
|
1325
|
+
|
1326
|
+
def code(self):
    """Return the code to create this instance (not supported yet).

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
|
1329
|
+
|
1330
|
+
|
1331
|
+
def main():
    """Smoke-test the example job: check its size, run it, and return the results.

    The job is run with a newly constructed ``Cache()``. Doctests are not run
    here; the module's ``__main__`` guard does that.
    """
    from edsl.jobs import Jobs
    from edsl.data.Cache import Cache

    job = Jobs.example()
    # 2 agents x 2 scenarios x 1 (no models) x 2 questions, matching the
    # Jobs.__len__ doctest.
    assert len(job) == 8
    results = job.run(cache=Cache())
    # NOTE(review): the number of results is not asserted. It is presumably
    # one Result per agent/scenario/model combination (4 here), not one per
    # question; confirm before adding a check.
    return results
|
1341
|
+
|
1342
|
+
|
1343
|
+
if __name__ == "__main__":
    # Run the module's doctests. ELLIPSIS lets expected outputs such as
    # ``Jobs(...)`` match the full repr.
    from doctest import ELLIPSIS, testmod

    testmod(optionflags=ELLIPSIS)
|