edsl 0.1.38.dev1__py3-none-any.whl → 0.1.38.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (263) hide show
  1. edsl/Base.py +303 -303
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +49 -48
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +858 -855
  7. edsl/agents/AgentList.py +362 -350
  8. edsl/agents/Invigilator.py +222 -222
  9. edsl/agents/InvigilatorBase.py +284 -284
  10. edsl/agents/PromptConstructor.py +353 -353
  11. edsl/agents/__init__.py +3 -3
  12. edsl/agents/descriptors.py +99 -99
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +279 -289
  26. edsl/config.py +149 -149
  27. edsl/conversation/Conversation.py +290 -290
  28. edsl/conversation/car_buying.py +58 -58
  29. edsl/conversation/chips.py +95 -95
  30. edsl/conversation/mug_negotiation.py +81 -81
  31. edsl/conversation/next_speaker_utilities.py +93 -93
  32. edsl/coop/PriceFetcher.py +54 -54
  33. edsl/coop/__init__.py +2 -2
  34. edsl/coop/coop.py +961 -958
  35. edsl/coop/utils.py +131 -131
  36. edsl/data/Cache.py +530 -527
  37. edsl/data/CacheEntry.py +228 -228
  38. edsl/data/CacheHandler.py +149 -149
  39. edsl/data/RemoteCacheSync.py +97 -97
  40. edsl/data/SQLiteDict.py +292 -292
  41. edsl/data/__init__.py +4 -4
  42. edsl/data/orm.py +10 -10
  43. edsl/data_transfer_models.py +73 -73
  44. edsl/enums.py +173 -173
  45. edsl/exceptions/BaseException.py +21 -21
  46. edsl/exceptions/__init__.py +54 -54
  47. edsl/exceptions/agents.py +42 -38
  48. edsl/exceptions/cache.py +5 -0
  49. edsl/exceptions/configuration.py +16 -16
  50. edsl/exceptions/coop.py +10 -10
  51. edsl/exceptions/data.py +14 -14
  52. edsl/exceptions/general.py +34 -34
  53. edsl/exceptions/jobs.py +33 -33
  54. edsl/exceptions/language_models.py +63 -63
  55. edsl/exceptions/prompts.py +15 -15
  56. edsl/exceptions/questions.py +91 -91
  57. edsl/exceptions/results.py +29 -29
  58. edsl/exceptions/scenarios.py +22 -22
  59. edsl/exceptions/surveys.py +37 -37
  60. edsl/inference_services/AnthropicService.py +87 -87
  61. edsl/inference_services/AwsBedrock.py +120 -120
  62. edsl/inference_services/AzureAI.py +217 -217
  63. edsl/inference_services/DeepInfraService.py +18 -18
  64. edsl/inference_services/GoogleService.py +156 -156
  65. edsl/inference_services/GroqService.py +20 -20
  66. edsl/inference_services/InferenceServiceABC.py +147 -147
  67. edsl/inference_services/InferenceServicesCollection.py +97 -97
  68. edsl/inference_services/MistralAIService.py +123 -123
  69. edsl/inference_services/OllamaService.py +18 -18
  70. edsl/inference_services/OpenAIService.py +224 -224
  71. edsl/inference_services/TestService.py +89 -89
  72. edsl/inference_services/TogetherAIService.py +170 -170
  73. edsl/inference_services/models_available_cache.py +118 -118
  74. edsl/inference_services/rate_limits_cache.py +25 -25
  75. edsl/inference_services/registry.py +39 -39
  76. edsl/inference_services/write_available.py +10 -10
  77. edsl/jobs/Answers.py +56 -56
  78. edsl/jobs/Jobs.py +1358 -1347
  79. edsl/jobs/__init__.py +1 -1
  80. edsl/jobs/buckets/BucketCollection.py +63 -63
  81. edsl/jobs/buckets/ModelBuckets.py +65 -65
  82. edsl/jobs/buckets/TokenBucket.py +251 -248
  83. edsl/jobs/interviews/Interview.py +661 -661
  84. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  85. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  86. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  87. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  88. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  89. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  90. edsl/jobs/interviews/ReportErrors.py +66 -66
  91. edsl/jobs/interviews/interview_status_enum.py +9 -9
  92. edsl/jobs/runners/JobsRunnerAsyncio.py +361 -338
  93. edsl/jobs/runners/JobsRunnerStatus.py +332 -332
  94. edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
  95. edsl/jobs/tasks/TaskCreators.py +64 -64
  96. edsl/jobs/tasks/TaskHistory.py +451 -442
  97. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  98. edsl/jobs/tasks/task_status_enum.py +163 -163
  99. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  100. edsl/jobs/tokens/TokenUsage.py +34 -34
  101. edsl/language_models/KeyLookup.py +30 -30
  102. edsl/language_models/LanguageModel.py +708 -706
  103. edsl/language_models/ModelList.py +109 -102
  104. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  105. edsl/language_models/__init__.py +3 -3
  106. edsl/language_models/fake_openai_call.py +15 -15
  107. edsl/language_models/fake_openai_service.py +61 -61
  108. edsl/language_models/registry.py +137 -137
  109. edsl/language_models/repair.py +156 -156
  110. edsl/language_models/unused/ReplicateBase.py +83 -83
  111. edsl/language_models/utilities.py +64 -64
  112. edsl/notebooks/Notebook.py +258 -259
  113. edsl/notebooks/__init__.py +1 -1
  114. edsl/prompts/Prompt.py +357 -357
  115. edsl/prompts/__init__.py +2 -2
  116. edsl/questions/AnswerValidatorMixin.py +289 -289
  117. edsl/questions/QuestionBase.py +660 -656
  118. edsl/questions/QuestionBaseGenMixin.py +161 -161
  119. edsl/questions/QuestionBasePromptsMixin.py +217 -234
  120. edsl/questions/QuestionBudget.py +227 -227
  121. edsl/questions/QuestionCheckBox.py +359 -359
  122. edsl/questions/QuestionExtract.py +183 -183
  123. edsl/questions/QuestionFreeText.py +114 -114
  124. edsl/questions/QuestionFunctional.py +166 -159
  125. edsl/questions/QuestionList.py +231 -231
  126. edsl/questions/QuestionMultipleChoice.py +286 -286
  127. edsl/questions/QuestionNumerical.py +153 -153
  128. edsl/questions/QuestionRank.py +324 -324
  129. edsl/questions/Quick.py +41 -41
  130. edsl/questions/RegisterQuestionsMeta.py +71 -71
  131. edsl/questions/ResponseValidatorABC.py +174 -174
  132. edsl/questions/SimpleAskMixin.py +73 -73
  133. edsl/questions/__init__.py +26 -26
  134. edsl/questions/compose_questions.py +98 -98
  135. edsl/questions/decorators.py +21 -21
  136. edsl/questions/derived/QuestionLikertFive.py +76 -76
  137. edsl/questions/derived/QuestionLinearScale.py +87 -87
  138. edsl/questions/derived/QuestionTopK.py +93 -91
  139. edsl/questions/derived/QuestionYesNo.py +82 -82
  140. edsl/questions/descriptors.py +413 -413
  141. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  142. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  143. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  144. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  145. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  146. edsl/questions/prompt_templates/question_list.jinja +17 -17
  147. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  148. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  149. edsl/questions/question_registry.py +147 -147
  150. edsl/questions/settings.py +12 -12
  151. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  152. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  153. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  154. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  155. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  156. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  157. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  158. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  159. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  160. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  161. edsl/questions/templates/list/question_presentation.jinja +5 -5
  162. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  163. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  164. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  165. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  166. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  167. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  168. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  169. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  170. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  171. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  172. edsl/results/Dataset.py +293 -293
  173. edsl/results/DatasetExportMixin.py +717 -717
  174. edsl/results/DatasetTree.py +145 -145
  175. edsl/results/Result.py +456 -450
  176. edsl/results/Results.py +1071 -1071
  177. edsl/results/ResultsDBMixin.py +238 -238
  178. edsl/results/ResultsExportMixin.py +43 -43
  179. edsl/results/ResultsFetchMixin.py +33 -33
  180. edsl/results/ResultsGGMixin.py +121 -121
  181. edsl/results/ResultsToolsMixin.py +98 -98
  182. edsl/results/Selector.py +135 -135
  183. edsl/results/__init__.py +2 -2
  184. edsl/results/tree_explore.py +115 -115
  185. edsl/scenarios/FileStore.py +458 -458
  186. edsl/scenarios/Scenario.py +544 -546
  187. edsl/scenarios/ScenarioHtmlMixin.py +64 -64
  188. edsl/scenarios/ScenarioList.py +1112 -1112
  189. edsl/scenarios/ScenarioListExportMixin.py +52 -52
  190. edsl/scenarios/ScenarioListPdfMixin.py +261 -261
  191. edsl/scenarios/__init__.py +4 -4
  192. edsl/shared.py +1 -1
  193. edsl/study/ObjectEntry.py +173 -173
  194. edsl/study/ProofOfWork.py +113 -113
  195. edsl/study/SnapShot.py +80 -80
  196. edsl/study/Study.py +528 -528
  197. edsl/study/__init__.py +4 -4
  198. edsl/surveys/DAG.py +148 -148
  199. edsl/surveys/Memory.py +31 -31
  200. edsl/surveys/MemoryPlan.py +244 -244
  201. edsl/surveys/Rule.py +326 -330
  202. edsl/surveys/RuleCollection.py +387 -387
  203. edsl/surveys/Survey.py +1787 -1795
  204. edsl/surveys/SurveyCSS.py +261 -261
  205. edsl/surveys/SurveyExportMixin.py +259 -259
  206. edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
  207. edsl/surveys/SurveyQualtricsImport.py +284 -284
  208. edsl/surveys/__init__.py +3 -3
  209. edsl/surveys/base.py +53 -53
  210. edsl/surveys/descriptors.py +56 -56
  211. edsl/surveys/instructions/ChangeInstruction.py +49 -47
  212. edsl/surveys/instructions/Instruction.py +53 -51
  213. edsl/surveys/instructions/InstructionCollection.py +77 -77
  214. edsl/templates/error_reporting/base.html +23 -23
  215. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  216. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  217. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  218. edsl/templates/error_reporting/interview_details.html +115 -115
  219. edsl/templates/error_reporting/interviews.html +9 -9
  220. edsl/templates/error_reporting/overview.html +4 -4
  221. edsl/templates/error_reporting/performance_plot.html +1 -1
  222. edsl/templates/error_reporting/report.css +73 -73
  223. edsl/templates/error_reporting/report.html +117 -117
  224. edsl/templates/error_reporting/report.js +25 -25
  225. edsl/tools/__init__.py +1 -1
  226. edsl/tools/clusters.py +192 -192
  227. edsl/tools/embeddings.py +27 -27
  228. edsl/tools/embeddings_plotting.py +118 -118
  229. edsl/tools/plotting.py +112 -112
  230. edsl/tools/summarize.py +18 -18
  231. edsl/utilities/SystemInfo.py +28 -28
  232. edsl/utilities/__init__.py +22 -22
  233. edsl/utilities/ast_utilities.py +25 -25
  234. edsl/utilities/data/Registry.py +6 -6
  235. edsl/utilities/data/__init__.py +1 -1
  236. edsl/utilities/data/scooter_results.json +1 -1
  237. edsl/utilities/decorators.py +77 -77
  238. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  239. edsl/utilities/interface.py +627 -627
  240. edsl/{conjure → utilities}/naming_utilities.py +263 -263
  241. edsl/utilities/repair_functions.py +28 -28
  242. edsl/utilities/restricted_python.py +70 -70
  243. edsl/utilities/utilities.py +409 -409
  244. {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev3.dist-info}/LICENSE +21 -21
  245. {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev3.dist-info}/METADATA +1 -1
  246. edsl-0.1.38.dev3.dist-info/RECORD +269 -0
  247. edsl/conjure/AgentConstructionMixin.py +0 -160
  248. edsl/conjure/Conjure.py +0 -62
  249. edsl/conjure/InputData.py +0 -659
  250. edsl/conjure/InputDataCSV.py +0 -48
  251. edsl/conjure/InputDataMixinQuestionStats.py +0 -182
  252. edsl/conjure/InputDataPyRead.py +0 -91
  253. edsl/conjure/InputDataSPSS.py +0 -8
  254. edsl/conjure/InputDataStata.py +0 -8
  255. edsl/conjure/QuestionOptionMixin.py +0 -76
  256. edsl/conjure/QuestionTypeMixin.py +0 -23
  257. edsl/conjure/RawQuestion.py +0 -65
  258. edsl/conjure/SurveyResponses.py +0 -7
  259. edsl/conjure/__init__.py +0 -9
  260. edsl/conjure/examples/placeholder.txt +0 -0
  261. edsl/conjure/utilities.py +0 -201
  262. edsl-0.1.38.dev1.dist-info/RECORD +0 -283
  263. {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev3.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py CHANGED
@@ -1,1347 +1,1358 @@
1
- # """The Jobs class is a collection of agents, scenarios and models and one survey."""
2
- from __future__ import annotations
3
- import warnings
4
- import requests
5
- from itertools import product
6
- from typing import Literal, Optional, Union, Sequence, Generator
7
-
8
- from edsl.Base import Base
9
-
10
- from edsl.exceptions import MissingAPIKeyError
11
- from edsl.jobs.buckets.BucketCollection import BucketCollection
12
- from edsl.jobs.interviews.Interview import Interview
13
- from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
14
- from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
15
-
16
- from edsl.data.RemoteCacheSync import RemoteCacheSync
17
- from edsl.exceptions.coop import CoopServerResponseError
18
-
19
-
20
- class Jobs(Base):
21
- """
22
- A collection of agents, scenarios and models and one survey.
23
- The actual running of a job is done by a `JobsRunner`, which is a subclass of `JobsRunner`.
24
- The `JobsRunner` is chosen by the user, and is stored in the `jobs_runner_name` attribute.
25
- """
26
-
27
- def __init__(
28
- self,
29
- survey: "Survey",
30
- agents: Optional[list["Agent"]] = None,
31
- models: Optional[list["LanguageModel"]] = None,
32
- scenarios: Optional[list["Scenario"]] = None,
33
- ):
34
- """Initialize a Jobs instance.
35
-
36
- :param survey: the survey to be used in the job
37
- :param agents: a list of agents
38
- :param models: a list of models
39
- :param scenarios: a list of scenarios
40
- """
41
- self.survey = survey
42
- self.agents: "AgentList" = agents
43
- self.scenarios: "ScenarioList" = scenarios
44
- self.models = models
45
-
46
- self.__bucket_collection = None
47
-
48
- # these setters and getters are used to ensure that the agents, models, and scenarios are stored as AgentList, ModelList, and ScenarioList objects
49
-
50
- @property
51
- def models(self):
52
- return self._models
53
-
54
- @models.setter
55
- def models(self, value):
56
- from edsl import ModelList
57
-
58
- if value:
59
- if not isinstance(value, ModelList):
60
- self._models = ModelList(value)
61
- else:
62
- self._models = value
63
- else:
64
- self._models = ModelList([])
65
-
66
- @property
67
- def agents(self):
68
- return self._agents
69
-
70
- @agents.setter
71
- def agents(self, value):
72
- from edsl import AgentList
73
-
74
- if value:
75
- if not isinstance(value, AgentList):
76
- self._agents = AgentList(value)
77
- else:
78
- self._agents = value
79
- else:
80
- self._agents = AgentList([])
81
-
82
- @property
83
- def scenarios(self):
84
- return self._scenarios
85
-
86
- @scenarios.setter
87
- def scenarios(self, value):
88
- from edsl import ScenarioList
89
-
90
- if value:
91
- if not isinstance(value, ScenarioList):
92
- self._scenarios = ScenarioList(value)
93
- else:
94
- self._scenarios = value
95
- else:
96
- self._scenarios = ScenarioList([])
97
-
98
- def by(
99
- self,
100
- *args: Union[
101
- "Agent",
102
- "Scenario",
103
- "LanguageModel",
104
- Sequence[Union["Agent", "Scenario", "LanguageModel"]],
105
- ],
106
- ) -> Jobs:
107
- """
108
- Add Agents, Scenarios and LanguageModels to a job. If no objects of this type exist in the Jobs instance, it stores the new objects as a list in the corresponding attribute. Otherwise, it combines the new objects with existing objects using the object's `__add__` method.
109
-
110
- This 'by' is intended to create a fluent interface.
111
-
112
- >>> from edsl import Survey
113
- >>> from edsl import QuestionFreeText
114
- >>> q = QuestionFreeText(question_name="name", question_text="What is your name?")
115
- >>> j = Jobs(survey = Survey(questions=[q]))
116
- >>> j
117
- Jobs(survey=Survey(...), agents=AgentList([]), models=ModelList([]), scenarios=ScenarioList([]))
118
- >>> from edsl import Agent; a = Agent(traits = {"status": "Sad"})
119
- >>> j.by(a).agents
120
- AgentList([Agent(traits = {'status': 'Sad'})])
121
-
122
- :param args: objects or a sequence (list, tuple, ...) of objects of the same type
123
-
124
- Notes:
125
- - all objects must implement the 'get_value', 'set_value', and `__add__` methods
126
- - agents: traits of new agents are combined with traits of existing agents. New and existing agents should not have overlapping traits, and do not increase the # agents in the instance
127
- - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
128
- - models: new models overwrite old models.
129
- """
130
- passed_objects = self._turn_args_to_list(
131
- args
132
- ) # objects can also be passed comma-separated
133
-
134
- current_objects, objects_key = self._get_current_objects_of_this_type(
135
- passed_objects[0]
136
- )
137
-
138
- if not current_objects:
139
- new_objects = passed_objects
140
- else:
141
- new_objects = self._merge_objects(passed_objects, current_objects)
142
-
143
- setattr(self, objects_key, new_objects) # update the job
144
- return self
145
-
146
- def prompts(self) -> "Dataset":
147
- """Return a Dataset of prompts that will be used.
148
-
149
-
150
- >>> from edsl.jobs import Jobs
151
- >>> Jobs.example().prompts()
152
- Dataset(...)
153
- """
154
- from edsl import Coop
155
-
156
- c = Coop()
157
- price_lookup = c.fetch_prices()
158
-
159
- interviews = self.interviews()
160
- # data = []
161
- interview_indices = []
162
- question_names = []
163
- user_prompts = []
164
- system_prompts = []
165
- scenario_indices = []
166
- agent_indices = []
167
- models = []
168
- costs = []
169
- from edsl.results.Dataset import Dataset
170
-
171
- for interview_index, interview in enumerate(interviews):
172
- invigilators = [
173
- interview._get_invigilator(question)
174
- for question in self.survey.questions
175
- ]
176
- for _, invigilator in enumerate(invigilators):
177
- prompts = invigilator.get_prompts()
178
- user_prompt = prompts["user_prompt"]
179
- system_prompt = prompts["system_prompt"]
180
- user_prompts.append(user_prompt)
181
- system_prompts.append(system_prompt)
182
- agent_index = self.agents.index(invigilator.agent)
183
- agent_indices.append(agent_index)
184
- interview_indices.append(interview_index)
185
- scenario_index = self.scenarios.index(invigilator.scenario)
186
- scenario_indices.append(scenario_index)
187
- models.append(invigilator.model.model)
188
- question_names.append(invigilator.question.question_name)
189
-
190
- prompt_cost = self.estimate_prompt_cost(
191
- system_prompt=system_prompt,
192
- user_prompt=user_prompt,
193
- price_lookup=price_lookup,
194
- inference_service=invigilator.model._inference_service_,
195
- model=invigilator.model.model,
196
- )
197
- costs.append(prompt_cost["cost_usd"])
198
-
199
- d = Dataset(
200
- [
201
- {"user_prompt": user_prompts},
202
- {"system_prompt": system_prompts},
203
- {"interview_index": interview_indices},
204
- {"question_name": question_names},
205
- {"scenario_index": scenario_indices},
206
- {"agent_index": agent_indices},
207
- {"model": models},
208
- {"estimated_cost": costs},
209
- ]
210
- )
211
- return d
212
-
213
- def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
214
- """Print the prompts."""
215
- if all:
216
- self.prompts().to_scenario_list().print(format="rich", max_rows=max_rows)
217
- else:
218
- self.prompts().select(
219
- "user_prompt", "system_prompt"
220
- ).to_scenario_list().print(format="rich", max_rows=max_rows)
221
-
222
- @staticmethod
223
- def estimate_prompt_cost(
224
- system_prompt: str,
225
- user_prompt: str,
226
- price_lookup: dict,
227
- inference_service: str,
228
- model: str,
229
- ) -> dict:
230
- """Estimates the cost of a prompt. Takes piping into account."""
231
- import math
232
-
233
- def get_piping_multiplier(prompt: str):
234
- """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
235
-
236
- if "{{" in prompt and "}}" in prompt:
237
- return 2
238
- return 1
239
-
240
- # Look up prices per token
241
- key = (inference_service, model)
242
-
243
- try:
244
- relevant_prices = price_lookup[key]
245
-
246
- service_input_token_price = float(
247
- relevant_prices["input"]["service_stated_token_price"]
248
- )
249
- service_input_token_qty = float(
250
- relevant_prices["input"]["service_stated_token_qty"]
251
- )
252
- input_price_per_token = service_input_token_price / service_input_token_qty
253
-
254
- service_output_token_price = float(
255
- relevant_prices["output"]["service_stated_token_price"]
256
- )
257
- service_output_token_qty = float(
258
- relevant_prices["output"]["service_stated_token_qty"]
259
- )
260
- output_price_per_token = (
261
- service_output_token_price / service_output_token_qty
262
- )
263
-
264
- except KeyError:
265
- # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
266
- # Use a sensible default
267
-
268
- import warnings
269
-
270
- warnings.warn(
271
- "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
272
- )
273
- input_price_per_token = 0.00000015 # $0.15 / 1M tokens
274
- output_price_per_token = 0.00000060 # $0.60 / 1M tokens
275
-
276
- # Compute the number of characters (double if the question involves piping)
277
- user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
278
- str(user_prompt)
279
- )
280
- system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
281
- str(system_prompt)
282
- )
283
-
284
- # Convert into tokens (1 token approx. equals 4 characters)
285
- input_tokens = (user_prompt_chars + system_prompt_chars) // 4
286
-
287
- output_tokens = math.ceil(0.75 * input_tokens)
288
-
289
- cost = (
290
- input_tokens * input_price_per_token
291
- + output_tokens * output_price_per_token
292
- )
293
-
294
- return {
295
- "input_tokens": input_tokens,
296
- "output_tokens": output_tokens,
297
- "cost_usd": cost,
298
- }
299
-
300
- def estimate_job_cost_from_external_prices(
301
- self, price_lookup: dict, iterations: int = 1
302
- ) -> dict:
303
- """
304
- Estimates the cost of a job according to the following assumptions:
305
-
306
- - 1 token = 4 characters.
307
- - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
308
-
309
- price_lookup is an external pricing dictionary.
310
- """
311
-
312
- import pandas as pd
313
-
314
- interviews = self.interviews()
315
- data = []
316
- for interview in interviews:
317
- invigilators = [
318
- interview._get_invigilator(question)
319
- for question in self.survey.questions
320
- ]
321
- for invigilator in invigilators:
322
- prompts = invigilator.get_prompts()
323
-
324
- # By this point, agent and scenario data has already been added to the prompts
325
- user_prompt = prompts["user_prompt"]
326
- system_prompt = prompts["system_prompt"]
327
- inference_service = invigilator.model._inference_service_
328
- model = invigilator.model.model
329
-
330
- prompt_cost = self.estimate_prompt_cost(
331
- system_prompt=system_prompt,
332
- user_prompt=user_prompt,
333
- price_lookup=price_lookup,
334
- inference_service=inference_service,
335
- model=model,
336
- )
337
-
338
- data.append(
339
- {
340
- "user_prompt": user_prompt,
341
- "system_prompt": system_prompt,
342
- "estimated_input_tokens": prompt_cost["input_tokens"],
343
- "estimated_output_tokens": prompt_cost["output_tokens"],
344
- "estimated_cost_usd": prompt_cost["cost_usd"],
345
- "inference_service": inference_service,
346
- "model": model,
347
- }
348
- )
349
-
350
- df = pd.DataFrame.from_records(data)
351
-
352
- df = (
353
- df.groupby(["inference_service", "model"])
354
- .agg(
355
- {
356
- "estimated_cost_usd": "sum",
357
- "estimated_input_tokens": "sum",
358
- "estimated_output_tokens": "sum",
359
- }
360
- )
361
- .reset_index()
362
- )
363
- df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
364
- df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
365
- df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
366
-
367
- estimated_costs_by_model = df.to_dict("records")
368
-
369
- estimated_total_cost = sum(
370
- model["estimated_cost_usd"] for model in estimated_costs_by_model
371
- )
372
- estimated_total_input_tokens = sum(
373
- model["estimated_input_tokens"] for model in estimated_costs_by_model
374
- )
375
- estimated_total_output_tokens = sum(
376
- model["estimated_output_tokens"] for model in estimated_costs_by_model
377
- )
378
-
379
- output = {
380
- "estimated_total_cost_usd": estimated_total_cost,
381
- "estimated_total_input_tokens": estimated_total_input_tokens,
382
- "estimated_total_output_tokens": estimated_total_output_tokens,
383
- "model_costs": estimated_costs_by_model,
384
- }
385
-
386
- return output
387
-
388
- def estimate_job_cost(self, iterations: int = 1) -> dict:
389
- """
390
- Estimates the cost of a job according to the following assumptions:
391
-
392
- - 1 token = 4 characters.
393
- - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
394
-
395
- Fetches prices from Coop.
396
- """
397
- from edsl import Coop
398
-
399
- c = Coop()
400
- price_lookup = c.fetch_prices()
401
-
402
- return self.estimate_job_cost_from_external_prices(
403
- price_lookup=price_lookup, iterations=iterations
404
- )
405
-
406
- @staticmethod
407
- def compute_job_cost(job_results: "Results") -> float:
408
- """
409
- Computes the cost of a completed job in USD.
410
- """
411
- total_cost = 0
412
- for result in job_results:
413
- for key in result.raw_model_response:
414
- if key.endswith("_cost"):
415
- result_cost = result.raw_model_response[key]
416
-
417
- question_name = key.removesuffix("_cost")
418
- cache_used = result.cache_used_dict[question_name]
419
-
420
- if isinstance(result_cost, (int, float)) and not cache_used:
421
- total_cost += result_cost
422
-
423
- return total_cost
424
-
425
- @staticmethod
426
- def _get_container_class(object):
427
- from edsl.agents.AgentList import AgentList
428
- from edsl.agents.Agent import Agent
429
- from edsl.scenarios.Scenario import Scenario
430
- from edsl.scenarios.ScenarioList import ScenarioList
431
- from edsl.language_models.ModelList import ModelList
432
-
433
- if isinstance(object, Agent):
434
- return AgentList
435
- elif isinstance(object, Scenario):
436
- return ScenarioList
437
- elif isinstance(object, ModelList):
438
- return ModelList
439
- else:
440
- return list
441
-
442
- @staticmethod
443
- def _turn_args_to_list(args):
444
- """Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments.
445
-
446
- Example:
447
-
448
- >>> Jobs._turn_args_to_list([1,2,3])
449
- [1, 2, 3]
450
-
451
- """
452
-
453
- def did_user_pass_a_sequence(args):
454
- """Return True if the user passed a sequence, False otherwise.
455
-
456
- Example:
457
-
458
- >>> did_user_pass_a_sequence([1,2,3])
459
- True
460
-
461
- >>> did_user_pass_a_sequence(1)
462
- False
463
- """
464
- return len(args) == 1 and isinstance(args[0], Sequence)
465
-
466
- if did_user_pass_a_sequence(args):
467
- container_class = Jobs._get_container_class(args[0][0])
468
- return container_class(args[0])
469
- else:
470
- container_class = Jobs._get_container_class(args[0])
471
- return container_class(args)
472
-
473
- def _get_current_objects_of_this_type(
474
- self, object: Union["Agent", "Scenario", "LanguageModel"]
475
- ) -> tuple[list, str]:
476
- from edsl.agents.Agent import Agent
477
- from edsl.scenarios.Scenario import Scenario
478
- from edsl.language_models.LanguageModel import LanguageModel
479
-
480
- """Return the current objects of the same type as the first argument.
481
-
482
- >>> from edsl.jobs import Jobs
483
- >>> j = Jobs.example()
484
- >>> j._get_current_objects_of_this_type(j.agents[0])
485
- (AgentList([Agent(traits = {'status': 'Joyful'}), Agent(traits = {'status': 'Sad'})]), 'agents')
486
- """
487
- class_to_key = {
488
- Agent: "agents",
489
- Scenario: "scenarios",
490
- LanguageModel: "models",
491
- }
492
- for class_type in class_to_key:
493
- if isinstance(object, class_type) or issubclass(
494
- object.__class__, class_type
495
- ):
496
- key = class_to_key[class_type]
497
- break
498
- else:
499
- raise ValueError(
500
- f"First argument must be an Agent, Scenario, or LanguageModel, not {object}"
501
- )
502
- current_objects = getattr(self, key, None)
503
- return current_objects, key
504
-
505
- @staticmethod
506
- def _get_empty_container_object(object):
507
- from edsl import AgentList
508
- from edsl import Agent
509
- from edsl import Scenario
510
- from edsl import ScenarioList
511
-
512
- if isinstance(object, Agent):
513
- return AgentList([])
514
- elif isinstance(object, Scenario):
515
- return ScenarioList([])
516
- else:
517
- return []
518
-
519
- @staticmethod
520
- def _merge_objects(passed_objects, current_objects) -> list:
521
- """
522
- Combine all the existing objects with the new objects.
523
-
524
- For example, if the user passes in 3 agents,
525
- and there are 2 existing agents, this will create 6 new agents
526
-
527
- >>> Jobs(survey = [])._merge_objects([1,2,3], [4,5,6])
528
- [5, 6, 7, 6, 7, 8, 7, 8, 9]
529
- """
530
- new_objects = Jobs._get_empty_container_object(passed_objects[0])
531
- for current_object in current_objects:
532
- for new_object in passed_objects:
533
- new_objects.append(current_object + new_object)
534
- return new_objects
535
-
536
- def interviews(self) -> list[Interview]:
537
- """
538
- Return a list of :class:`edsl.jobs.interviews.Interview` objects.
539
-
540
- It returns one Interview for each combination of Agent, Scenario, and LanguageModel.
541
- If any of Agents, Scenarios, or LanguageModels are missing, it fills in with defaults.
542
-
543
- >>> from edsl.jobs import Jobs
544
- >>> j = Jobs.example()
545
- >>> len(j.interviews())
546
- 4
547
- >>> j.interviews()[0]
548
- Interview(agent = Agent(traits = {'status': 'Joyful'}), survey = Survey(...), scenario = Scenario({'period': 'morning'}), model = Model(...))
549
- """
550
- if hasattr(self, "_interviews"):
551
- return self._interviews
552
- else:
553
- return list(self._create_interviews())
554
-
555
- @classmethod
556
- def from_interviews(cls, interview_list):
557
- """Return a Jobs instance from a list of interviews.
558
-
559
- This is useful when you have, say, a list of failed interviews and you want to create
560
- a new job with only those interviews.
561
- """
562
- survey = interview_list[0].survey
563
- # get all the models
564
- models = list(set([interview.model for interview in interview_list]))
565
- jobs = cls(survey)
566
- jobs.models = models
567
- jobs._interviews = interview_list
568
- return jobs
569
-
570
- def _create_interviews(self) -> Generator[Interview, None, None]:
571
- """
572
- Generate interviews.
573
-
574
- Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
575
- This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
576
- with us filling in defaults.
577
-
578
-
579
- """
580
- # if no agents, models, or scenarios are set, set them to defaults
581
- from edsl.agents.Agent import Agent
582
- from edsl.language_models.registry import Model
583
- from edsl.scenarios.Scenario import Scenario
584
-
585
- self.agents = self.agents or [Agent()]
586
- self.models = self.models or [Model()]
587
- self.scenarios = self.scenarios or [Scenario()]
588
- for agent, scenario, model in product(self.agents, self.scenarios, self.models):
589
- yield Interview(
590
- survey=self.survey,
591
- agent=agent,
592
- scenario=scenario,
593
- model=model,
594
- skip_retry=self.skip_retry,
595
- raise_validation_errors=self.raise_validation_errors,
596
- )
597
-
598
- def create_bucket_collection(self) -> BucketCollection:
599
- """
600
- Create a collection of buckets for each model.
601
-
602
- These buckets are used to track API calls and token usage.
603
-
604
- >>> from edsl.jobs import Jobs
605
- >>> from edsl import Model
606
- >>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
607
- >>> bc = j.create_bucket_collection()
608
- >>> bc
609
- BucketCollection(...)
610
- """
611
- bucket_collection = BucketCollection()
612
- for model in self.models:
613
- bucket_collection.add_model(model)
614
- return bucket_collection
615
-
616
- @property
617
- def bucket_collection(self) -> BucketCollection:
618
- """Return the bucket collection. If it does not exist, create it."""
619
- if self.__bucket_collection is None:
620
- self.__bucket_collection = self.create_bucket_collection()
621
- return self.__bucket_collection
622
-
623
- def html(self):
624
- """Return the HTML representations for each scenario"""
625
- links = []
626
- for index, scenario in enumerate(self.scenarios):
627
- links.append(
628
- self.survey.html(
629
- scenario=scenario, return_link=True, cta=f"Scenario {index}"
630
- )
631
- )
632
- return links
633
-
634
- def __hash__(self):
635
- """Allow the model to be used as a key in a dictionary.
636
-
637
- >>> from edsl.jobs import Jobs
638
- >>> hash(Jobs.example())
639
- 846655441787442972
640
-
641
- """
642
- from edsl.utilities.utilities import dict_hash
643
-
644
- return dict_hash(self._to_dict())
645
-
646
- def _output(self, message) -> None:
647
- """Check if a Job is verbose. If so, print the message."""
648
- if hasattr(self, "verbose") and self.verbose:
649
- print(message)
650
-
651
- def _check_parameters(self, strict=False, warn=False) -> None:
652
- """Check if the parameters in the survey and scenarios are consistent.
653
-
654
- >>> from edsl import QuestionFreeText
655
- >>> from edsl import Survey
656
- >>> from edsl import Scenario
657
- >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
658
- >>> j = Jobs(survey = Survey(questions=[q]))
659
- >>> with warnings.catch_warnings(record=True) as w:
660
- ... j._check_parameters(warn = True)
661
- ... assert len(w) == 1
662
- ... assert issubclass(w[-1].category, UserWarning)
663
- ... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
664
-
665
- >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
666
- >>> s = Scenario({'plop': "A", 'poo': "B"})
667
- >>> j = Jobs(survey = Survey(questions=[q])).by(s)
668
- >>> j._check_parameters(strict = True)
669
- Traceback (most recent call last):
670
- ...
671
- ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
672
-
673
- >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
674
- >>> s = Scenario({'ugly_question': "B"})
675
- >>> j = Jobs(survey = Survey(questions=[q])).by(s)
676
- >>> j._check_parameters()
677
- Traceback (most recent call last):
678
- ...
679
- ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
680
- """
681
- survey_parameters: set = self.survey.parameters
682
- scenario_parameters: set = self.scenarios.parameters
683
-
684
- msg0, msg1, msg2 = None, None, None
685
-
686
- # look for key issues
687
- if intersection := set(self.scenarios.parameters) & set(
688
- self.survey.question_names
689
- ):
690
- msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
691
-
692
- raise ValueError(msg0)
693
-
694
- if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
695
- msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
696
- if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
697
- msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"
698
-
699
- if msg1 or msg2:
700
- message = "\n".join(filter(None, [msg1, msg2]))
701
- if strict:
702
- raise ValueError(message)
703
- else:
704
- if warn:
705
- warnings.warn(message)
706
-
707
- if self.scenarios.has_jinja_braces:
708
- warnings.warn(
709
- "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
710
- )
711
- self.scenarios = self.scenarios.convert_jinja_braces()
712
-
713
- @property
714
- def skip_retry(self):
715
- if not hasattr(self, "_skip_retry"):
716
- return False
717
- return self._skip_retry
718
-
719
- @property
720
- def raise_validation_errors(self):
721
- if not hasattr(self, "_raise_validation_errors"):
722
- return False
723
- return self._raise_validation_errors
724
-
725
- def create_remote_inference_job(
726
- self,
727
- iterations: int = 1,
728
- remote_inference_description: Optional[str] = None,
729
- remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
730
- verbose=False,
731
- ):
732
- """ """
733
- from edsl.coop.coop import Coop
734
-
735
- coop = Coop()
736
- self._output("Remote inference activated. Sending job to server...")
737
- remote_job_creation_data = coop.remote_inference_create(
738
- self,
739
- description=remote_inference_description,
740
- status="queued",
741
- iterations=iterations,
742
- initial_results_visibility=remote_inference_results_visibility,
743
- )
744
- job_uuid = remote_job_creation_data.get("uuid")
745
- if self.verbose:
746
- print(f"Job sent to server. (Job uuid={job_uuid}).")
747
- return remote_job_creation_data
748
-
749
- @staticmethod
750
- def check_status(job_uuid):
751
- from edsl.coop.coop import Coop
752
-
753
- coop = Coop()
754
- return coop.remote_inference_get(job_uuid)
755
-
756
- def poll_remote_inference_job(
757
- self, remote_job_creation_data: dict, verbose=False, poll_interval=5
758
- ) -> Union[Results, None]:
759
- from edsl.coop.coop import Coop
760
- import time
761
- from datetime import datetime
762
- from edsl.config import CONFIG
763
-
764
- expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
765
-
766
- job_uuid = remote_job_creation_data.get("uuid")
767
-
768
- coop = Coop()
769
- job_in_queue = True
770
- while job_in_queue:
771
- remote_job_data = coop.remote_inference_get(job_uuid)
772
- status = remote_job_data.get("status")
773
- if status == "cancelled":
774
- if self.verbose:
775
- print("\r" + " " * 80 + "\r", end="")
776
- print("Job cancelled by the user.")
777
- print(
778
- f"See {expected_parrot_url}/home/remote-inference for more details."
779
- )
780
- return None
781
- elif status == "failed":
782
- if self.verbose:
783
- print("\r" + " " * 80 + "\r", end="")
784
- print("Job failed.")
785
- print(
786
- f"See {expected_parrot_url}/home/remote-inference for more details."
787
- )
788
- return None
789
- elif status == "completed":
790
- results_uuid = remote_job_data.get("results_uuid")
791
- results = coop.get(results_uuid, expected_object_type="results")
792
- if self.verbose:
793
- print("\r" + " " * 80 + "\r", end="")
794
- url = f"{expected_parrot_url}/content/{results_uuid}"
795
- print(f"Job completed and Results stored on Coop: {url}.")
796
- return results
797
- else:
798
- duration = poll_interval
799
- time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
800
- frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
801
- start_time = time.time()
802
- i = 0
803
- while time.time() - start_time < duration:
804
- if self.verbose:
805
- print(
806
- f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
807
- end="",
808
- flush=True,
809
- )
810
- time.sleep(0.1)
811
- i += 1
812
-
813
- def use_remote_inference(self, disable_remote_inference: bool) -> bool:
814
- if disable_remote_inference:
815
- return False
816
- if not disable_remote_inference:
817
- try:
818
- from edsl import Coop
819
-
820
- user_edsl_settings = Coop().edsl_settings
821
- return user_edsl_settings.get("remote_inference", False)
822
- except requests.ConnectionError:
823
- pass
824
- except CoopServerResponseError as e:
825
- pass
826
-
827
- return False
828
-
829
- def use_remote_cache(self, disable_remote_cache: bool) -> bool:
830
- if disable_remote_cache:
831
- return False
832
- if not disable_remote_cache:
833
- try:
834
- from edsl import Coop
835
-
836
- user_edsl_settings = Coop().edsl_settings
837
- return user_edsl_settings.get("remote_caching", False)
838
- except requests.ConnectionError:
839
- pass
840
- except CoopServerResponseError as e:
841
- pass
842
-
843
- return False
844
-
845
- def check_api_keys(self) -> None:
846
- from edsl import Model
847
-
848
- for model in self.models + [Model()]:
849
- if not model.has_valid_api_key():
850
- raise MissingAPIKeyError(
851
- model_name=str(model.model),
852
- inference_service=model._inference_service_,
853
- )
854
-
855
- def get_missing_api_keys(self) -> set:
856
- """
857
- Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
858
- """
859
-
860
- missing_api_keys = set()
861
-
862
- from edsl import Model
863
- from edsl.enums import service_to_api_keyname
864
-
865
- for model in self.models + [Model()]:
866
- if not model.has_valid_api_key():
867
- key_name = service_to_api_keyname.get(
868
- model._inference_service_, "NOT FOUND"
869
- )
870
- missing_api_keys.add(key_name)
871
-
872
- return missing_api_keys
873
-
874
- def user_has_all_model_keys(self):
875
- """
876
- Returns True if the user has all model keys required to run their job.
877
-
878
- Otherwise, returns False.
879
- """
880
-
881
- try:
882
- self.check_api_keys()
883
- return True
884
- except MissingAPIKeyError:
885
- return False
886
- except Exception:
887
- raise
888
-
889
- def user_has_ep_api_key(self) -> bool:
890
- """
891
- Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
892
-
893
- Otherwise, returns False.
894
- """
895
-
896
- import os
897
-
898
- coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
899
-
900
- if coop_api_key is not None:
901
- return True
902
- else:
903
- return False
904
-
905
- def needs_external_llms(self) -> bool:
906
- """
907
- Returns True if the job needs external LLMs to run.
908
-
909
- Otherwise, returns False.
910
- """
911
- # These cases are necessary to skip the API key check during doctests
912
-
913
- # Accounts for Results.example()
914
- all_agents_answer_questions_directly = len(self.agents) > 0 and all(
915
- [hasattr(a, "answer_question_directly") for a in self.agents]
916
- )
917
-
918
- # Accounts for InterviewExceptionEntry.example()
919
- only_model_is_test = set([m.model for m in self.models]) == set(["test"])
920
-
921
- # Accounts for Survey.__call__
922
- all_questions_are_functional = set(
923
- [q.question_type for q in self.survey.questions]
924
- ) == set(["functional"])
925
-
926
- if (
927
- all_agents_answer_questions_directly
928
- or only_model_is_test
929
- or all_questions_are_functional
930
- ):
931
- return False
932
- else:
933
- return True
934
-
935
- def run(
936
- self,
937
- n: int = 1,
938
- progress_bar: bool = False,
939
- stop_on_exception: bool = False,
940
- cache: Union[Cache, bool] = None,
941
- check_api_keys: bool = False,
942
- sidecar_model: Optional[LanguageModel] = None,
943
- verbose: bool = False,
944
- print_exceptions=True,
945
- remote_cache_description: Optional[str] = None,
946
- remote_inference_description: Optional[str] = None,
947
- remote_inference_results_visibility: Optional[
948
- Literal["private", "public", "unlisted"]
949
- ] = "unlisted",
950
- skip_retry: bool = False,
951
- raise_validation_errors: bool = False,
952
- disable_remote_cache: bool = False,
953
- disable_remote_inference: bool = False,
954
- ) -> Results:
955
- """
956
- Runs the Job: conducts Interviews and returns their results.
957
-
958
- :param n: How many times to run each interview
959
- :param progress_bar: Whether to show a progress bar
960
- :param stop_on_exception: Stops the job if an exception is raised
961
- :param cache: A Cache object to store results
962
- :param check_api_keys: Raises an error if API keys are invalid
963
- :param verbose: Prints extra messages
964
- :param remote_cache_description: Specifies a description for this group of entries in the remote cache
965
- :param remote_inference_description: Specifies a description for the remote inference job
966
- :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
967
- :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
968
- :param disable_remote_inference: If True, the job will not use remote inference
969
- """
970
- from edsl.coop.coop import Coop
971
-
972
- self._check_parameters()
973
- self._skip_retry = skip_retry
974
- self._raise_validation_errors = raise_validation_errors
975
-
976
- self.verbose = verbose
977
-
978
- if (
979
- not self.user_has_all_model_keys()
980
- and not self.user_has_ep_api_key()
981
- and self.needs_external_llms()
982
- ):
983
- import secrets
984
- from dotenv import load_dotenv
985
- from edsl import CONFIG
986
- from edsl.coop.coop import Coop
987
- from edsl.utilities.utilities import write_api_key_to_env
988
-
989
- missing_api_keys = self.get_missing_api_keys()
990
-
991
- edsl_auth_token = secrets.token_urlsafe(16)
992
-
993
- print("You're missing some of the API keys needed to run this job:")
994
- for api_key in missing_api_keys:
995
- print(f" 🔑 {api_key}")
996
- print(
997
- "\nYou can either add the missing keys to your .env file, or use remote inference."
998
- )
999
- print("Remote inference allows you to run jobs on our server.")
1000
- print("\n🚀 To use remote inference, sign up at the following link:")
1001
-
1002
- coop = Coop()
1003
- coop._display_login_url(edsl_auth_token=edsl_auth_token)
1004
-
1005
- print(
1006
- "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
1007
- )
1008
-
1009
- api_key = coop._poll_for_api_key(edsl_auth_token)
1010
-
1011
- if api_key is None:
1012
- print("\nTimed out waiting for login. Please try again.")
1013
- return
1014
-
1015
- write_api_key_to_env(api_key)
1016
- print("✨ API key retrieved and written to .env file.\n")
1017
-
1018
- # Retrieve API key so we can continue running the job
1019
- load_dotenv()
1020
-
1021
- if remote_inference := self.use_remote_inference(disable_remote_inference):
1022
- remote_job_creation_data = self.create_remote_inference_job(
1023
- iterations=n,
1024
- remote_inference_description=remote_inference_description,
1025
- remote_inference_results_visibility=remote_inference_results_visibility,
1026
- )
1027
- results = self.poll_remote_inference_job(remote_job_creation_data)
1028
- if results is None:
1029
- self._output("Job failed.")
1030
- return results
1031
-
1032
- if check_api_keys:
1033
- self.check_api_keys()
1034
-
1035
- # handle cache
1036
- if cache is None or cache is True:
1037
- from edsl.data.CacheHandler import CacheHandler
1038
-
1039
- cache = CacheHandler().get_cache()
1040
- if cache is False:
1041
- from edsl.data.Cache import Cache
1042
-
1043
- cache = Cache()
1044
-
1045
- remote_cache = self.use_remote_cache(disable_remote_cache)
1046
- with RemoteCacheSync(
1047
- coop=Coop(),
1048
- cache=cache,
1049
- output_func=self._output,
1050
- remote_cache=remote_cache,
1051
- remote_cache_description=remote_cache_description,
1052
- ) as r:
1053
- results = self._run_local(
1054
- n=n,
1055
- progress_bar=progress_bar,
1056
- cache=cache,
1057
- stop_on_exception=stop_on_exception,
1058
- sidecar_model=sidecar_model,
1059
- print_exceptions=print_exceptions,
1060
- raise_validation_errors=raise_validation_errors,
1061
- )
1062
-
1063
- results.cache = cache.new_entries_cache()
1064
- return results
1065
-
1066
- async def create_and_poll_remote_job(
1067
- self,
1068
- iterations: int = 1,
1069
- remote_inference_description: Optional[str] = None,
1070
- remote_inference_results_visibility: Optional[
1071
- Literal["private", "public", "unlisted"]
1072
- ] = "unlisted",
1073
- ) -> Union[Results, None]:
1074
- """
1075
- Creates and polls a remote inference job asynchronously.
1076
- Reuses existing synchronous methods but runs them in an async context.
1077
-
1078
- :param iterations: Number of times to run each interview
1079
- :param remote_inference_description: Optional description for the remote job
1080
- :param remote_inference_results_visibility: Visibility setting for results
1081
- :return: Results object if successful, None if job fails or is cancelled
1082
- """
1083
- import asyncio
1084
- from functools import partial
1085
-
1086
- # Create job using existing method
1087
- loop = asyncio.get_event_loop()
1088
- remote_job_creation_data = await loop.run_in_executor(
1089
- None,
1090
- partial(
1091
- self.create_remote_inference_job,
1092
- iterations=iterations,
1093
- remote_inference_description=remote_inference_description,
1094
- remote_inference_results_visibility=remote_inference_results_visibility,
1095
- ),
1096
- )
1097
-
1098
- # Poll using existing method but with async sleep
1099
- return await loop.run_in_executor(
1100
- None, partial(self.poll_remote_inference_job, remote_job_creation_data)
1101
- )
1102
-
1103
- async def run_async(
1104
- self,
1105
- cache=None,
1106
- n=1,
1107
- disable_remote_inference: bool = False,
1108
- remote_inference_description: Optional[str] = None,
1109
- remote_inference_results_visibility: Optional[
1110
- Literal["private", "public", "unlisted"]
1111
- ] = "unlisted",
1112
- **kwargs,
1113
- ):
1114
- """Run the job asynchronously, either locally or remotely.
1115
-
1116
- :param cache: Cache object or boolean
1117
- :param n: Number of iterations
1118
- :param disable_remote_inference: If True, forces local execution
1119
- :param remote_inference_description: Description for remote jobs
1120
- :param remote_inference_results_visibility: Visibility setting for remote results
1121
- :param kwargs: Additional arguments passed to local execution
1122
- :return: Results object
1123
- """
1124
- # Check if we should use remote inference
1125
- if remote_inference := self.use_remote_inference(disable_remote_inference):
1126
- results = await self.create_and_poll_remote_job(
1127
- iterations=n,
1128
- remote_inference_description=remote_inference_description,
1129
- remote_inference_results_visibility=remote_inference_results_visibility,
1130
- )
1131
- if results is None:
1132
- self._output("Job failed.")
1133
- return results
1134
-
1135
- # If not using remote inference, run locally with async
1136
- return await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
1137
-
1138
- def _run_local(self, *args, **kwargs):
1139
- """Run the job locally."""
1140
-
1141
- results = JobsRunnerAsyncio(self).run(*args, **kwargs)
1142
- return results
1143
-
1144
- def all_question_parameters(self):
1145
- """Return all the fields in the questions in the survey.
1146
- >>> from edsl.jobs import Jobs
1147
- >>> Jobs.example().all_question_parameters()
1148
- {'period'}
1149
- """
1150
- return set.union(*[question.parameters for question in self.survey.questions])
1151
-
1152
- #######################
1153
- # Dunder methods
1154
- #######################
1155
- def print(self):
1156
- from rich import print_json
1157
- import json
1158
-
1159
- print_json(json.dumps(self.to_dict()))
1160
-
1161
- def __repr__(self) -> str:
1162
- """Return an eval-able string representation of the Jobs instance."""
1163
- return f"Jobs(survey={repr(self.survey)}, agents={repr(self.agents)}, models={repr(self.models)}, scenarios={repr(self.scenarios)})"
1164
-
1165
- def _repr_html_(self) -> str:
1166
- from rich import print_json
1167
- import json
1168
-
1169
- print_json(json.dumps(self.to_dict()))
1170
-
1171
- def __len__(self) -> int:
1172
- """Return the maximum number of questions that will be asked while running this job.
1173
- Note that this is the maximum number of questions, not the actual number of questions that will be asked, as some questions may be skipped.
1174
-
1175
- >>> from edsl.jobs import Jobs
1176
- >>> len(Jobs.example())
1177
- 8
1178
- """
1179
- number_of_questions = (
1180
- len(self.agents or [1])
1181
- * len(self.scenarios or [1])
1182
- * len(self.models or [1])
1183
- * len(self.survey)
1184
- )
1185
- return number_of_questions
1186
-
1187
- #######################
1188
- # Serialization methods
1189
- #######################
1190
-
1191
- def _to_dict(self):
1192
- return {
1193
- "survey": self.survey._to_dict(),
1194
- "agents": [agent._to_dict() for agent in self.agents],
1195
- "models": [model._to_dict() for model in self.models],
1196
- "scenarios": [scenario._to_dict() for scenario in self.scenarios],
1197
- }
1198
-
1199
- @add_edsl_version
1200
- def to_dict(self) -> dict:
1201
- """Convert the Jobs instance to a dictionary."""
1202
- return self._to_dict()
1203
-
1204
- @classmethod
1205
- @remove_edsl_version
1206
- def from_dict(cls, data: dict) -> Jobs:
1207
- """Creates a Jobs instance from a dictionary."""
1208
- from edsl import Survey
1209
- from edsl.agents.Agent import Agent
1210
- from edsl.language_models.LanguageModel import LanguageModel
1211
- from edsl.scenarios.Scenario import Scenario
1212
-
1213
- return cls(
1214
- survey=Survey.from_dict(data["survey"]),
1215
- agents=[Agent.from_dict(agent) for agent in data["agents"]],
1216
- models=[LanguageModel.from_dict(model) for model in data["models"]],
1217
- scenarios=[Scenario.from_dict(scenario) for scenario in data["scenarios"]],
1218
- )
1219
-
1220
- def __eq__(self, other: Jobs) -> bool:
1221
- """Return True if the Jobs instance is equal to another Jobs instance.
1222
-
1223
- >>> from edsl.jobs import Jobs
1224
- >>> Jobs.example() == Jobs.example()
1225
- True
1226
-
1227
- """
1228
- return self.to_dict() == other.to_dict()
1229
-
1230
- #######################
1231
- # Example methods
1232
- #######################
1233
- @classmethod
1234
- def example(
1235
- cls,
1236
- throw_exception_probability: float = 0.0,
1237
- randomize: bool = False,
1238
- test_model=False,
1239
- ) -> Jobs:
1240
- """Return an example Jobs instance.
1241
-
1242
- :param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
1243
- :param randomize: whether to randomize the job by adding a random string to the period
1244
- :param test_model: whether to use a test model
1245
-
1246
- >>> Jobs.example()
1247
- Jobs(...)
1248
-
1249
- """
1250
- import random
1251
- from uuid import uuid4
1252
- from edsl.questions import QuestionMultipleChoice
1253
- from edsl.agents.Agent import Agent
1254
- from edsl.scenarios.Scenario import Scenario
1255
-
1256
- addition = "" if not randomize else str(uuid4())
1257
-
1258
- if test_model:
1259
- from edsl.language_models import LanguageModel
1260
-
1261
- m = LanguageModel.example(test_model=True)
1262
-
1263
- # (status, question, period)
1264
- agent_answers = {
1265
- ("Joyful", "how_feeling", "morning"): "OK",
1266
- ("Joyful", "how_feeling", "afternoon"): "Great",
1267
- ("Joyful", "how_feeling_yesterday", "morning"): "Great",
1268
- ("Joyful", "how_feeling_yesterday", "afternoon"): "Good",
1269
- ("Sad", "how_feeling", "morning"): "Terrible",
1270
- ("Sad", "how_feeling", "afternoon"): "OK",
1271
- ("Sad", "how_feeling_yesterday", "morning"): "OK",
1272
- ("Sad", "how_feeling_yesterday", "afternoon"): "Terrible",
1273
- }
1274
-
1275
- def answer_question_directly(self, question, scenario):
1276
- """Return the answer to a question. This is a method that can be added to an agent."""
1277
-
1278
- if random.random() < throw_exception_probability:
1279
- raise Exception("Error!")
1280
- return agent_answers[
1281
- (self.traits["status"], question.question_name, scenario["period"])
1282
- ]
1283
-
1284
- sad_agent = Agent(traits={"status": "Sad"})
1285
- joy_agent = Agent(traits={"status": "Joyful"})
1286
-
1287
- sad_agent.add_direct_question_answering_method(answer_question_directly)
1288
- joy_agent.add_direct_question_answering_method(answer_question_directly)
1289
-
1290
- q1 = QuestionMultipleChoice(
1291
- question_text="How are you this {{ period }}?",
1292
- question_options=["Good", "Great", "OK", "Terrible"],
1293
- question_name="how_feeling",
1294
- )
1295
- q2 = QuestionMultipleChoice(
1296
- question_text="How were you feeling yesterday {{ period }}?",
1297
- question_options=["Good", "Great", "OK", "Terrible"],
1298
- question_name="how_feeling_yesterday",
1299
- )
1300
- from edsl import Survey, ScenarioList
1301
-
1302
- base_survey = Survey(questions=[q1, q2])
1303
-
1304
- scenario_list = ScenarioList(
1305
- [
1306
- Scenario({"period": f"morning{addition}"}),
1307
- Scenario({"period": "afternoon"}),
1308
- ]
1309
- )
1310
- if test_model:
1311
- job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
1312
- else:
1313
- job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
1314
-
1315
- return job
1316
-
1317
- def rich_print(self):
1318
- """Print a rich representation of the Jobs instance."""
1319
- from rich.table import Table
1320
-
1321
- table = Table(title="Jobs")
1322
- table.add_column("Jobs")
1323
- table.add_row(self.survey.rich_print())
1324
- return table
1325
-
1326
- def code(self):
1327
- """Return the code to create this instance."""
1328
- raise NotImplementedError
1329
-
1330
-
1331
- def main():
1332
- """Run the module's doctests."""
1333
- from edsl.jobs import Jobs
1334
- from edsl.data.Cache import Cache
1335
-
1336
- job = Jobs.example()
1337
- len(job) == 8
1338
- results = job.run(cache=Cache())
1339
- len(results) == 8
1340
- results
1341
-
1342
-
1343
- if __name__ == "__main__":
1344
- """Run the module's doctests."""
1345
- import doctest
1346
-
1347
- doctest.testmod(optionflags=doctest.ELLIPSIS)
1
+ # """The Jobs class is a collection of agents, scenarios and models and one survey."""
2
+ from __future__ import annotations
3
+ import warnings
4
+ import requests
5
+ from itertools import product
6
+ from typing import Literal, Optional, Union, Sequence, Generator
7
+
8
+ from edsl.Base import Base
9
+
10
+ from edsl.exceptions import MissingAPIKeyError
11
+ from edsl.jobs.buckets.BucketCollection import BucketCollection
12
+ from edsl.jobs.interviews.Interview import Interview
13
+ from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
14
+ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
15
+
16
+ from edsl.data.RemoteCacheSync import RemoteCacheSync
17
+ from edsl.exceptions.coop import CoopServerResponseError
18
+
19
+
20
+ class Jobs(Base):
21
+ """
22
+ A collection of agents, scenarios and models and one survey.
23
+ The actual running of a job is done by a `JobsRunner`, which is a subclass of `JobsRunner`.
24
+ The `JobsRunner` is chosen by the user, and is stored in the `jobs_runner_name` attribute.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ survey: "Survey",
30
+ agents: Optional[list["Agent"]] = None,
31
+ models: Optional[list["LanguageModel"]] = None,
32
+ scenarios: Optional[list["Scenario"]] = None,
33
+ ):
34
+ """Initialize a Jobs instance.
35
+
36
+ :param survey: the survey to be used in the job
37
+ :param agents: a list of agents
38
+ :param models: a list of models
39
+ :param scenarios: a list of scenarios
40
+ """
41
+ self.survey = survey
42
+ self.agents: "AgentList" = agents
43
+ self.scenarios: "ScenarioList" = scenarios
44
+ self.models = models
45
+
46
+ self.__bucket_collection = None
47
+
48
+ # these setters and getters are used to ensure that the agents, models, and scenarios are stored as AgentList, ModelList, and ScenarioList objects
49
+
50
+ @property
51
+ def models(self):
52
+ return self._models
53
+
54
+ @models.setter
55
+ def models(self, value):
56
+ from edsl import ModelList
57
+
58
+ if value:
59
+ if not isinstance(value, ModelList):
60
+ self._models = ModelList(value)
61
+ else:
62
+ self._models = value
63
+ else:
64
+ self._models = ModelList([])
65
+
66
+ @property
67
+ def agents(self):
68
+ return self._agents
69
+
70
+ @agents.setter
71
+ def agents(self, value):
72
+ from edsl import AgentList
73
+
74
+ if value:
75
+ if not isinstance(value, AgentList):
76
+ self._agents = AgentList(value)
77
+ else:
78
+ self._agents = value
79
+ else:
80
+ self._agents = AgentList([])
81
+
82
+ @property
83
+ def scenarios(self):
84
+ return self._scenarios
85
+
86
+ @scenarios.setter
87
+ def scenarios(self, value):
88
+ from edsl import ScenarioList
89
+
90
+ if value:
91
+ if not isinstance(value, ScenarioList):
92
+ self._scenarios = ScenarioList(value)
93
+ else:
94
+ self._scenarios = value
95
+ else:
96
+ self._scenarios = ScenarioList([])
97
+
98
    def by(
        self,
        *args: Union[
            "Agent",
            "Scenario",
            "LanguageModel",
            Sequence[Union["Agent", "Scenario", "LanguageModel"]],
        ],
    ) -> Jobs:
        """
        Add Agents, Scenarios and LanguageModels to a job. If no objects of this type exist in the Jobs instance, it stores the new objects as a list in the corresponding attribute. Otherwise, it combines the new objects with existing objects using the object's `__add__` method.

        This 'by' is intended to create a fluent interface.

        >>> from edsl import Survey
        >>> from edsl import QuestionFreeText
        >>> q = QuestionFreeText(question_name="name", question_text="What is your name?")
        >>> j = Jobs(survey = Survey(questions=[q]))
        >>> j
        Jobs(survey=Survey(...), agents=AgentList([]), models=ModelList([]), scenarios=ScenarioList([]))
        >>> from edsl import Agent; a = Agent(traits = {"status": "Sad"})
        >>> j.by(a).agents
        AgentList([Agent(traits = {'status': 'Sad'})])

        :param args: objects or a sequence (list, tuple, ...) of objects of the same type

        Notes:
        - all objects must implement the 'get_value', 'set_value', and `__add__` methods
        - agents: traits of new agents are combined with traits of existing agents. New and existing agents should not have overlapping traits, and do not increase the # agents in the instance
        - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
        - models: new models overwrite old models.
        """
        passed_objects = self._turn_args_to_list(
            args
        )  # objects can also be passed comma-separated

        # Identify which attribute ('agents' / 'scenarios' / 'models') the
        # passed objects belong to, based on the type of the first one.
        current_objects, objects_key = self._get_current_objects_of_this_type(
            passed_objects[0]
        )

        if not current_objects:
            new_objects = passed_objects
        else:
            # Cross-combine: every existing object is merged with every new one.
            new_objects = self._merge_objects(passed_objects, current_objects)

        setattr(self, objects_key, new_objects)  # update the job
        return self
145
+
146
+ def prompts(self) -> "Dataset":
147
+ """Return a Dataset of prompts that will be used.
148
+
149
+
150
+ >>> from edsl.jobs import Jobs
151
+ >>> Jobs.example().prompts()
152
+ Dataset(...)
153
+ """
154
+ from edsl import Coop
155
+
156
+ c = Coop()
157
+ price_lookup = c.fetch_prices()
158
+
159
+ interviews = self.interviews()
160
+ # data = []
161
+ interview_indices = []
162
+ question_names = []
163
+ user_prompts = []
164
+ system_prompts = []
165
+ scenario_indices = []
166
+ agent_indices = []
167
+ models = []
168
+ costs = []
169
+ from edsl.results.Dataset import Dataset
170
+
171
+ for interview_index, interview in enumerate(interviews):
172
+ invigilators = [
173
+ interview._get_invigilator(question)
174
+ for question in self.survey.questions
175
+ ]
176
+ for _, invigilator in enumerate(invigilators):
177
+ prompts = invigilator.get_prompts()
178
+ user_prompt = prompts["user_prompt"]
179
+ system_prompt = prompts["system_prompt"]
180
+ user_prompts.append(user_prompt)
181
+ system_prompts.append(system_prompt)
182
+ agent_index = self.agents.index(invigilator.agent)
183
+ agent_indices.append(agent_index)
184
+ interview_indices.append(interview_index)
185
+ scenario_index = self.scenarios.index(invigilator.scenario)
186
+ scenario_indices.append(scenario_index)
187
+ models.append(invigilator.model.model)
188
+ question_names.append(invigilator.question.question_name)
189
+
190
+ prompt_cost = self.estimate_prompt_cost(
191
+ system_prompt=system_prompt,
192
+ user_prompt=user_prompt,
193
+ price_lookup=price_lookup,
194
+ inference_service=invigilator.model._inference_service_,
195
+ model=invigilator.model.model,
196
+ )
197
+ costs.append(prompt_cost["cost_usd"])
198
+
199
+ d = Dataset(
200
+ [
201
+ {"user_prompt": user_prompts},
202
+ {"system_prompt": system_prompts},
203
+ {"interview_index": interview_indices},
204
+ {"question_name": question_names},
205
+ {"scenario_index": scenario_indices},
206
+ {"agent_index": agent_indices},
207
+ {"model": models},
208
+ {"estimated_cost": costs},
209
+ ]
210
+ )
211
+ return d
212
+
213
    def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
        """Print the prompts.

        :param all: If True, print every column of the prompts Dataset;
            otherwise only the user and system prompts.
            NOTE(review): the name shadows the ``all`` builtin, but renaming
            would break keyword callers, so it is kept.
        :param max_rows: Maximum number of rows to print.
        """
        if all:
            self.prompts().to_scenario_list().print(format="rich", max_rows=max_rows)
        else:
            self.prompts().select(
                "user_prompt", "system_prompt"
            ).to_scenario_list().print(format="rich", max_rows=max_rows)
221
+
222
+ @staticmethod
223
+ def estimate_prompt_cost(
224
+ system_prompt: str,
225
+ user_prompt: str,
226
+ price_lookup: dict,
227
+ inference_service: str,
228
+ model: str,
229
+ ) -> dict:
230
+ """Estimates the cost of a prompt. Takes piping into account."""
231
+ import math
232
+
233
+ def get_piping_multiplier(prompt: str):
234
+ """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
235
+
236
+ if "{{" in prompt and "}}" in prompt:
237
+ return 2
238
+ return 1
239
+
240
+ # Look up prices per token
241
+ key = (inference_service, model)
242
+
243
+ try:
244
+ relevant_prices = price_lookup[key]
245
+
246
+ service_input_token_price = float(
247
+ relevant_prices["input"]["service_stated_token_price"]
248
+ )
249
+ service_input_token_qty = float(
250
+ relevant_prices["input"]["service_stated_token_qty"]
251
+ )
252
+ input_price_per_token = service_input_token_price / service_input_token_qty
253
+
254
+ service_output_token_price = float(
255
+ relevant_prices["output"]["service_stated_token_price"]
256
+ )
257
+ service_output_token_qty = float(
258
+ relevant_prices["output"]["service_stated_token_qty"]
259
+ )
260
+ output_price_per_token = (
261
+ service_output_token_price / service_output_token_qty
262
+ )
263
+
264
+ except KeyError:
265
+ # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
266
+ # Use a sensible default
267
+
268
+ import warnings
269
+
270
+ warnings.warn(
271
+ "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
272
+ )
273
+ input_price_per_token = 0.00000015 # $0.15 / 1M tokens
274
+ output_price_per_token = 0.00000060 # $0.60 / 1M tokens
275
+
276
+ # Compute the number of characters (double if the question involves piping)
277
+ user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
278
+ str(user_prompt)
279
+ )
280
+ system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
281
+ str(system_prompt)
282
+ )
283
+
284
+ # Convert into tokens (1 token approx. equals 4 characters)
285
+ input_tokens = (user_prompt_chars + system_prompt_chars) // 4
286
+
287
+ output_tokens = math.ceil(0.75 * input_tokens)
288
+
289
+ cost = (
290
+ input_tokens * input_price_per_token
291
+ + output_tokens * output_price_per_token
292
+ )
293
+
294
+ return {
295
+ "input_tokens": input_tokens,
296
+ "output_tokens": output_tokens,
297
+ "cost_usd": cost,
298
+ }
299
+
300
    def estimate_job_cost_from_external_prices(
        self, price_lookup: dict, iterations: int = 1
    ) -> dict:
        """
        Estimates the cost of a job according to the following assumptions:

        - 1 token = 4 characters.
        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.

        price_lookup is an external pricing dictionary.

        :param price_lookup: mapping of (inference_service, model) -> price info,
            in the shape consumed by :meth:`estimate_prompt_cost`.
        :param iterations: multiplier applied to the aggregated token/cost totals.
        :return: dict with total cost/tokens plus a per-(service, model) breakdown.
        """

        import pandas as pd

        interviews = self.interviews()
        data = []
        for interview in interviews:
            invigilators = [
                interview._get_invigilator(question)
                for question in self.survey.questions
            ]
            for invigilator in invigilators:
                prompts = invigilator.get_prompts()

                # By this point, agent and scenario data has already been added to the prompts
                user_prompt = prompts["user_prompt"]
                system_prompt = prompts["system_prompt"]
                inference_service = invigilator.model._inference_service_
                model = invigilator.model.model

                prompt_cost = self.estimate_prompt_cost(
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    price_lookup=price_lookup,
                    inference_service=inference_service,
                    model=model,
                )

                data.append(
                    {
                        "user_prompt": user_prompt,
                        "system_prompt": system_prompt,
                        "estimated_input_tokens": prompt_cost["input_tokens"],
                        "estimated_output_tokens": prompt_cost["output_tokens"],
                        "estimated_cost_usd": prompt_cost["cost_usd"],
                        "inference_service": inference_service,
                        "model": model,
                    }
                )

        df = pd.DataFrame.from_records(data)

        # Aggregate per (service, model) pair.
        df = (
            df.groupby(["inference_service", "model"])
            .agg(
                {
                    "estimated_cost_usd": "sum",
                    "estimated_input_tokens": "sum",
                    "estimated_output_tokens": "sum",
                }
            )
            .reset_index()
        )
        # Scale by the number of iterations the job will run.
        df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
        df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
        df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations

        estimated_costs_by_model = df.to_dict("records")

        estimated_total_cost = sum(
            model["estimated_cost_usd"] for model in estimated_costs_by_model
        )
        estimated_total_input_tokens = sum(
            model["estimated_input_tokens"] for model in estimated_costs_by_model
        )
        estimated_total_output_tokens = sum(
            model["estimated_output_tokens"] for model in estimated_costs_by_model
        )

        output = {
            "estimated_total_cost_usd": estimated_total_cost,
            "estimated_total_input_tokens": estimated_total_input_tokens,
            "estimated_total_output_tokens": estimated_total_output_tokens,
            "model_costs": estimated_costs_by_model,
        }

        return output
387
+
388
+ def estimate_job_cost(self, iterations: int = 1) -> dict:
389
+ """
390
+ Estimates the cost of a job according to the following assumptions:
391
+
392
+ - 1 token = 4 characters.
393
+ - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
394
+
395
+ Fetches prices from Coop.
396
+ """
397
+ from edsl import Coop
398
+
399
+ c = Coop()
400
+ price_lookup = c.fetch_prices()
401
+
402
+ return self.estimate_job_cost_from_external_prices(
403
+ price_lookup=price_lookup, iterations=iterations
404
+ )
405
+
406
+ @staticmethod
407
+ def compute_job_cost(job_results: "Results") -> float:
408
+ """
409
+ Computes the cost of a completed job in USD.
410
+ """
411
+ total_cost = 0
412
+ for result in job_results:
413
+ for key in result.raw_model_response:
414
+ if key.endswith("_cost"):
415
+ result_cost = result.raw_model_response[key]
416
+
417
+ question_name = key.removesuffix("_cost")
418
+ cache_used = result.cache_used_dict[question_name]
419
+
420
+ if isinstance(result_cost, (int, float)) and not cache_used:
421
+ total_cost += result_cost
422
+
423
+ return total_cost
424
+
425
    @staticmethod
    def _get_container_class(object):
        """Return the container class used to hold objects of *object*'s type.

        Agents map to AgentList, Scenarios to ScenarioList; anything else falls
        through to a plain ``list``.

        NOTE(review): the third branch tests ``isinstance(object, ModelList)``
        rather than ``LanguageModel``, so a single model instance gets a plain
        list. This looks deliberate (the ``models`` setter coerces to ModelList
        later) — confirm before changing.
        """
        from edsl.agents.AgentList import AgentList
        from edsl.agents.Agent import Agent
        from edsl.scenarios.Scenario import Scenario
        from edsl.scenarios.ScenarioList import ScenarioList
        from edsl.language_models.ModelList import ModelList

        if isinstance(object, Agent):
            return AgentList
        elif isinstance(object, Scenario):
            return ScenarioList
        elif isinstance(object, ModelList):
            return ModelList
        else:
            return list
441
+
442
    @staticmethod
    def _turn_args_to_list(args):
        """Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments.

        Example:

        >>> Jobs._turn_args_to_list([1,2,3])
        [1, 2, 3]

        """

        def did_user_pass_a_sequence(args):
            """Return True if the user passed a sequence, False otherwise.

            Example:

            >>> did_user_pass_a_sequence([1,2,3])
            True

            >>> did_user_pass_a_sequence(1)
            False
            """
            return len(args) == 1 and isinstance(args[0], Sequence)

        # The result is wrapped in the container class matching the element
        # type (AgentList / ScenarioList / list), per _get_container_class.
        if did_user_pass_a_sequence(args):
            container_class = Jobs._get_container_class(args[0][0])
            return container_class(args[0])
        else:
            container_class = Jobs._get_container_class(args[0])
            return container_class(args)
472
+
473
+ def _get_current_objects_of_this_type(
474
+ self, object: Union["Agent", "Scenario", "LanguageModel"]
475
+ ) -> tuple[list, str]:
476
+ from edsl.agents.Agent import Agent
477
+ from edsl.scenarios.Scenario import Scenario
478
+ from edsl.language_models.LanguageModel import LanguageModel
479
+
480
+ """Return the current objects of the same type as the first argument.
481
+
482
+ >>> from edsl.jobs import Jobs
483
+ >>> j = Jobs.example()
484
+ >>> j._get_current_objects_of_this_type(j.agents[0])
485
+ (AgentList([Agent(traits = {'status': 'Joyful'}), Agent(traits = {'status': 'Sad'})]), 'agents')
486
+ """
487
+ class_to_key = {
488
+ Agent: "agents",
489
+ Scenario: "scenarios",
490
+ LanguageModel: "models",
491
+ }
492
+ for class_type in class_to_key:
493
+ if isinstance(object, class_type) or issubclass(
494
+ object.__class__, class_type
495
+ ):
496
+ key = class_to_key[class_type]
497
+ break
498
+ else:
499
+ raise ValueError(
500
+ f"First argument must be an Agent, Scenario, or LanguageModel, not {object}"
501
+ )
502
+ current_objects = getattr(self, key, None)
503
+ return current_objects, key
504
+
505
+ @staticmethod
506
+ def _get_empty_container_object(object):
507
+ from edsl import AgentList
508
+ from edsl import Agent
509
+ from edsl import Scenario
510
+ from edsl import ScenarioList
511
+
512
+ if isinstance(object, Agent):
513
+ return AgentList([])
514
+ elif isinstance(object, Scenario):
515
+ return ScenarioList([])
516
+ else:
517
+ return []
518
+
519
    @staticmethod
    def _merge_objects(passed_objects, current_objects) -> list:
        """
        Combine all the existing objects with the new objects.

        For example, if the user passes in 3 agents,
        and there are 2 existing agents, this will create 6 new agents

        >>> Jobs(survey = [])._merge_objects([1,2,3], [4,5,6])
        [5, 6, 7, 6, 7, 8, 7, 8, 9]
        """
        # Cross-product merge: each (current, new) pair is combined with the
        # objects' __add__; result goes in a container matching the new type.
        new_objects = Jobs._get_empty_container_object(passed_objects[0])
        for current_object in current_objects:
            for new_object in passed_objects:
                new_objects.append(current_object + new_object)
        return new_objects
535
+
536
    def interviews(self) -> list[Interview]:
        """
        Return a list of :class:`edsl.jobs.interviews.Interview` objects.

        It returns one Interview for each combination of Agent, Scenario, and LanguageModel.
        If any of Agents, Scenarios, or LanguageModels are missing, it fills in with defaults.

        >>> from edsl.jobs import Jobs
        >>> j = Jobs.example()
        >>> len(j.interviews())
        4
        >>> j.interviews()[0]
        Interview(agent = Agent(traits = {'status': 'Joyful'}), survey = Survey(...), scenario = Scenario({'period': 'morning'}), model = Model(...))
        """
        # _interviews is only present when set by from_interviews(); otherwise
        # interviews are generated fresh from the agent/scenario/model product.
        if hasattr(self, "_interviews"):
            return self._interviews
        else:
            return list(self._create_interviews())
554
+
555
+ @classmethod
556
+ def from_interviews(cls, interview_list):
557
+ """Return a Jobs instance from a list of interviews.
558
+
559
+ This is useful when you have, say, a list of failed interviews and you want to create
560
+ a new job with only those interviews.
561
+ """
562
+ survey = interview_list[0].survey
563
+ # get all the models
564
+ models = list(set([interview.model for interview in interview_list]))
565
+ jobs = cls(survey)
566
+ jobs.models = models
567
+ jobs._interviews = interview_list
568
+ return jobs
569
+
570
    def _create_interviews(self) -> Generator[Interview, None, None]:
        """
        Generate interviews.

        Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
        This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
        with us filling in defaults.
        """
        # if no agents, models, or scenarios are set, set them to defaults
        from edsl.agents.Agent import Agent
        from edsl.language_models.registry import Model
        from edsl.scenarios.Scenario import Scenario

        self.agents = self.agents or [Agent()]
        self.models = self.models or [Model()]
        self.scenarios = self.scenarios or [Scenario()]
        # One Interview per (agent, scenario, model) combination.
        for agent, scenario, model in product(self.agents, self.scenarios, self.models):
            yield Interview(
                survey=self.survey,
                agent=agent,
                scenario=scenario,
                model=model,
                skip_retry=self.skip_retry,
                raise_validation_errors=self.raise_validation_errors,
            )
597
+
598
+ def create_bucket_collection(self) -> BucketCollection:
599
+ """
600
+ Create a collection of buckets for each model.
601
+
602
+ These buckets are used to track API calls and token usage.
603
+
604
+ >>> from edsl.jobs import Jobs
605
+ >>> from edsl import Model
606
+ >>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
607
+ >>> bc = j.create_bucket_collection()
608
+ >>> bc
609
+ BucketCollection(...)
610
+ """
611
+ bucket_collection = BucketCollection()
612
+ for model in self.models:
613
+ bucket_collection.add_model(model)
614
+ return bucket_collection
615
+
616
    @property
    def bucket_collection(self) -> BucketCollection:
        """Return the bucket collection. If it does not exist, create it.

        Note: ``self.__bucket_collection`` is name-mangled to
        ``_Jobs__bucket_collection``; presumably initialized to None in
        ``__init__`` (not visible here) — confirm.
        """
        if self.__bucket_collection is None:
            self.__bucket_collection = self.create_bucket_collection()
        return self.__bucket_collection
622
+
623
+ def html(self):
624
+ """Return the HTML representations for each scenario"""
625
+ links = []
626
+ for index, scenario in enumerate(self.scenarios):
627
+ links.append(
628
+ self.survey.html(
629
+ scenario=scenario, return_link=True, cta=f"Scenario {index}"
630
+ )
631
+ )
632
+ return links
633
+
634
    def __hash__(self):
        """Allow the model to be used as a key in a dictionary.

        >>> from edsl.jobs import Jobs
        >>> hash(Jobs.example())
        846655441787442972

        """
        from edsl.utilities.utilities import dict_hash

        # Hash the version-independent serialization so equal jobs hash equally
        # across edsl versions.
        return dict_hash(self.to_dict(add_edsl_version=False))
645
+
646
+ def _output(self, message) -> None:
647
+ """Check if a Job is verbose. If so, print the message."""
648
+ if hasattr(self, "verbose") and self.verbose:
649
+ print(message)
650
+
651
    def _check_parameters(self, strict=False, warn=False) -> None:
        """Check if the parameters in the survey and scenarios are consistent.

        >>> from edsl import QuestionFreeText
        >>> from edsl import Survey
        >>> from edsl import Scenario
        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
        >>> j = Jobs(survey = Survey(questions=[q]))
        >>> with warnings.catch_warnings(record=True) as w:
        ...     j._check_parameters(warn = True)
        ...     assert len(w) == 1
        ...     assert issubclass(w[-1].category, UserWarning)
        ...     assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)

        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
        >>> s = Scenario({'plop': "A", 'poo': "B"})
        >>> j = Jobs(survey = Survey(questions=[q])).by(s)
        >>> j._check_parameters(strict = True)
        Traceback (most recent call last):
        ...
        ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}

        >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
        >>> s = Scenario({'ugly_question': "B"})
        >>> j = Jobs(survey = Survey(questions=[q])).by(s)
        >>> j._check_parameters()
        Traceback (most recent call last):
        ...
        ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
        """
        survey_parameters: set = self.survey.parameters
        scenario_parameters: set = self.scenarios.parameters

        msg0, msg1, msg2 = None, None, None

        # look for key issues: a scenario key colliding with a question name is
        # always fatal, regardless of strict/warn.
        if intersection := set(self.scenarios.parameters) & set(
            self.survey.question_names
        ):
            msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."

            raise ValueError(msg0)

        if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
            msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
        if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
            msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"

        # Mismatches are an error only in strict mode; otherwise optionally warn.
        if msg1 or msg2:
            message = "\n".join(filter(None, [msg1, msg2]))
            if strict:
                raise ValueError(message)
            else:
                if warn:
                    warnings.warn(message)

        # Side effect: scenarios containing Jinja braces are converted so they
        # do not collide with prompt templating.
        if self.scenarios.has_jinja_braces:
            warnings.warn(
                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
            )
            self.scenarios = self.scenarios.convert_jinja_braces()
712
+
713
+ @property
714
+ def skip_retry(self):
715
+ if not hasattr(self, "_skip_retry"):
716
+ return False
717
+ return self._skip_retry
718
+
719
+ @property
720
+ def raise_validation_errors(self):
721
+ if not hasattr(self, "_raise_validation_errors"):
722
+ return False
723
+ return self._raise_validation_errors
724
+
725
    def create_remote_inference_job(
        self,
        iterations: int = 1,
        remote_inference_description: Optional[str] = None,
        remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
        verbose=False,
    ):
        """Send this job to the Coop server for remote inference.

        :param iterations: number of times to run each interview.
        :param remote_inference_description: optional description stored with the job.
        :param remote_inference_results_visibility: initial visibility of the Results on Coop.
        :param verbose: NOTE(review): unused — the method reads ``self.verbose`` instead.
        :return: the job-creation payload returned by Coop (contains the job ``uuid``).
        """
        from edsl.coop.coop import Coop

        coop = Coop()
        self._output("Remote inference activated. Sending job to server...")
        remote_job_creation_data = coop.remote_inference_create(
            self,
            description=remote_inference_description,
            status="queued",
            iterations=iterations,
            initial_results_visibility=remote_inference_results_visibility,
        )
        job_uuid = remote_job_creation_data.get("uuid")
        if self.verbose:
            print(f"Job sent to server. (Job uuid={job_uuid}).")
        return remote_job_creation_data
748
+
749
+ @staticmethod
750
+ def check_status(job_uuid):
751
+ from edsl.coop.coop import Coop
752
+
753
+ coop = Coop()
754
+ return coop.remote_inference_get(job_uuid)
755
+
756
    def poll_remote_inference_job(
        self, remote_job_creation_data: dict, verbose=False, poll_interval=5
    ) -> Union[Results, None]:
        """Poll Coop until a remote job finishes, then fetch and return its Results.

        :param remote_job_creation_data: payload from create_remote_inference_job
            (must contain the job ``uuid``).
        :param verbose: NOTE(review): unused — the method reads ``self.verbose``.
        :param poll_interval: seconds to wait (with spinner) between status checks.
        :return: Results on completion; None if the job was cancelled or failed.
        """
        from edsl.coop.coop import Coop
        import time
        from datetime import datetime
        from edsl.config import CONFIG

        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")

        job_uuid = remote_job_creation_data.get("uuid")

        coop = Coop()
        job_in_queue = True
        # Loop exits only via the return statements below.
        while job_in_queue:
            remote_job_data = coop.remote_inference_get(job_uuid)
            status = remote_job_data.get("status")
            if status == "cancelled":
                if self.verbose:
                    print("\r" + " " * 80 + "\r", end="")
                    print("Job cancelled by the user.")
                    print(
                        f"See {expected_parrot_url}/home/remote-inference for more details."
                    )
                return None
            elif status == "failed":
                if self.verbose:
                    print("\r" + " " * 80 + "\r", end="")
                    print("Job failed.")
                    print(
                        f"See {expected_parrot_url}/home/remote-inference for more details."
                    )
                return None
            elif status == "completed":
                results_uuid = remote_job_data.get("results_uuid")
                results = coop.get(results_uuid, expected_object_type="results")
                if self.verbose:
                    print("\r" + " " * 80 + "\r", end="")
                    url = f"{expected_parrot_url}/content/{results_uuid}"
                    print(f"Job completed and Results stored on Coop: {url}.")
                return results
            else:
                # Still queued/running: animate a spinner for poll_interval
                # seconds, then loop back and re-check the status.
                duration = poll_interval
                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
                start_time = time.time()
                i = 0
                while time.time() - start_time < duration:
                    if self.verbose:
                        print(
                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
                            end="",
                            flush=True,
                        )
                    time.sleep(0.1)
                    i += 1
813
+ def use_remote_inference(self, disable_remote_inference: bool) -> bool:
814
+ if disable_remote_inference:
815
+ return False
816
+ if not disable_remote_inference:
817
+ try:
818
+ from edsl import Coop
819
+
820
+ user_edsl_settings = Coop().edsl_settings
821
+ return user_edsl_settings.get("remote_inference", False)
822
+ except requests.ConnectionError:
823
+ pass
824
+ except CoopServerResponseError as e:
825
+ pass
826
+
827
+ return False
828
+
829
+ def use_remote_cache(self, disable_remote_cache: bool) -> bool:
830
+ if disable_remote_cache:
831
+ return False
832
+ if not disable_remote_cache:
833
+ try:
834
+ from edsl import Coop
835
+
836
+ user_edsl_settings = Coop().edsl_settings
837
+ return user_edsl_settings.get("remote_caching", False)
838
+ except requests.ConnectionError:
839
+ pass
840
+ except CoopServerResponseError as e:
841
+ pass
842
+
843
+ return False
844
+
845
+ def check_api_keys(self) -> None:
846
+ from edsl import Model
847
+
848
+ for model in self.models + [Model()]:
849
+ if not model.has_valid_api_key():
850
+ raise MissingAPIKeyError(
851
+ model_name=str(model.model),
852
+ inference_service=model._inference_service_,
853
+ )
854
+
855
+ def get_missing_api_keys(self) -> set:
856
+ """
857
+ Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
858
+ """
859
+
860
+ missing_api_keys = set()
861
+
862
+ from edsl import Model
863
+ from edsl.enums import service_to_api_keyname
864
+
865
+ for model in self.models + [Model()]:
866
+ if not model.has_valid_api_key():
867
+ key_name = service_to_api_keyname.get(
868
+ model._inference_service_, "NOT FOUND"
869
+ )
870
+ missing_api_keys.add(key_name)
871
+
872
+ return missing_api_keys
873
+
874
+ def user_has_all_model_keys(self):
875
+ """
876
+ Returns True if the user has all model keys required to run their job.
877
+
878
+ Otherwise, returns False.
879
+ """
880
+
881
+ try:
882
+ self.check_api_keys()
883
+ return True
884
+ except MissingAPIKeyError:
885
+ return False
886
+ except Exception:
887
+ raise
888
+
889
+ def user_has_ep_api_key(self) -> bool:
890
+ """
891
+ Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
892
+
893
+ Otherwise, returns False.
894
+ """
895
+
896
+ import os
897
+
898
+ coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
899
+
900
+ if coop_api_key is not None:
901
+ return True
902
+ else:
903
+ return False
904
+
905
+ def needs_external_llms(self) -> bool:
906
+ """
907
+ Returns True if the job needs external LLMs to run.
908
+
909
+ Otherwise, returns False.
910
+ """
911
+ # These cases are necessary to skip the API key check during doctests
912
+
913
+ # Accounts for Results.example()
914
+ all_agents_answer_questions_directly = len(self.agents) > 0 and all(
915
+ [hasattr(a, "answer_question_directly") for a in self.agents]
916
+ )
917
+
918
+ # Accounts for InterviewExceptionEntry.example()
919
+ only_model_is_test = set([m.model for m in self.models]) == set(["test"])
920
+
921
+ # Accounts for Survey.__call__
922
+ all_questions_are_functional = set(
923
+ [q.question_type for q in self.survey.questions]
924
+ ) == set(["functional"])
925
+
926
+ if (
927
+ all_agents_answer_questions_directly
928
+ or only_model_is_test
929
+ or all_questions_are_functional
930
+ ):
931
+ return False
932
+ else:
933
+ return True
934
+
935
    def run(
        self,
        n: int = 1,
        progress_bar: bool = False,
        stop_on_exception: bool = False,
        cache: Union[Cache, bool] = None,
        check_api_keys: bool = False,
        sidecar_model: Optional[LanguageModel] = None,
        verbose: bool = False,
        print_exceptions=True,
        remote_cache_description: Optional[str] = None,
        remote_inference_description: Optional[str] = None,
        remote_inference_results_visibility: Optional[
            Literal["private", "public", "unlisted"]
        ] = "unlisted",
        skip_retry: bool = False,
        raise_validation_errors: bool = False,
        disable_remote_cache: bool = False,
        disable_remote_inference: bool = False,
    ) -> Results:
        """
        Runs the Job: conducts Interviews and returns their results.

        :param n: How many times to run each interview
        :param progress_bar: Whether to show a progress bar
        :param stop_on_exception: Stops the job if an exception is raised
        :param cache: A Cache object to store results
        :param check_api_keys: Raises an error if API keys are invalid
        :param verbose: Prints extra messages
        :param remote_cache_description: Specifies a description for this group of entries in the remote cache
        :param remote_inference_description: Specifies a description for the remote inference job
        :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
        :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
        :param disable_remote_inference: If True, the job will not use remote inference
        """
        from edsl.coop.coop import Coop

        self._check_parameters()
        self._skip_retry = skip_retry
        self._raise_validation_errors = raise_validation_errors

        self.verbose = verbose

        # If the user lacks model keys AND an Expected Parrot key AND the job
        # actually needs external LLMs, walk them through a remote-inference
        # login flow before proceeding.
        if (
            not self.user_has_all_model_keys()
            and not self.user_has_ep_api_key()
            and self.needs_external_llms()
        ):
            import secrets
            from dotenv import load_dotenv
            from edsl import CONFIG
            from edsl.coop.coop import Coop
            from edsl.utilities.utilities import write_api_key_to_env

            missing_api_keys = self.get_missing_api_keys()

            edsl_auth_token = secrets.token_urlsafe(16)

            print("You're missing some of the API keys needed to run this job:")
            for api_key in missing_api_keys:
                print(f" 🔑 {api_key}")
            print(
                "\nYou can either add the missing keys to your .env file, or use remote inference."
            )
            print("Remote inference allows you to run jobs on our server.")
            print("\n🚀 To use remote inference, sign up at the following link:")

            coop = Coop()
            coop._display_login_url(edsl_auth_token=edsl_auth_token)

            print(
                "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
            )

            api_key = coop._poll_for_api_key(edsl_auth_token)

            if api_key is None:
                print("\nTimed out waiting for login. Please try again.")
                return

            write_api_key_to_env(api_key)
            print("✨ API key retrieved and written to .env file.\n")

            # Retrieve API key so we can continue running the job
            load_dotenv()

        # Remote path: submit, poll, and return the remote Results.
        if remote_inference := self.use_remote_inference(disable_remote_inference):
            remote_job_creation_data = self.create_remote_inference_job(
                iterations=n,
                remote_inference_description=remote_inference_description,
                remote_inference_results_visibility=remote_inference_results_visibility,
            )
            results = self.poll_remote_inference_job(remote_job_creation_data)
            if results is None:
                self._output("Job failed.")
            return results

        if check_api_keys:
            self.check_api_keys()

        # handle cache: None/True -> user's default cache; False -> fresh cache
        if cache is None or cache is True:
            from edsl.data.CacheHandler import CacheHandler

            cache = CacheHandler().get_cache()
        if cache is False:
            from edsl.data.Cache import Cache

            cache = Cache()

        # Local path: run interviews inside a remote-cache sync context.
        remote_cache = self.use_remote_cache(disable_remote_cache)
        with RemoteCacheSync(
            coop=Coop(),
            cache=cache,
            output_func=self._output,
            remote_cache=remote_cache,
            remote_cache_description=remote_cache_description,
        ) as r:
            results = self._run_local(
                n=n,
                progress_bar=progress_bar,
                cache=cache,
                stop_on_exception=stop_on_exception,
                sidecar_model=sidecar_model,
                print_exceptions=print_exceptions,
                raise_validation_errors=raise_validation_errors,
            )

        # Attach only the cache entries generated by this run.
        results.cache = cache.new_entries_cache()
        return results
1065
+
1066
+ async def create_and_poll_remote_job(
1067
+ self,
1068
+ iterations: int = 1,
1069
+ remote_inference_description: Optional[str] = None,
1070
+ remote_inference_results_visibility: Optional[
1071
+ Literal["private", "public", "unlisted"]
1072
+ ] = "unlisted",
1073
+ ) -> Union[Results, None]:
1074
+ """
1075
+ Creates and polls a remote inference job asynchronously.
1076
+ Reuses existing synchronous methods but runs them in an async context.
1077
+
1078
+ :param iterations: Number of times to run each interview
1079
+ :param remote_inference_description: Optional description for the remote job
1080
+ :param remote_inference_results_visibility: Visibility setting for results
1081
+ :return: Results object if successful, None if job fails or is cancelled
1082
+ """
1083
+ import asyncio
1084
+ from functools import partial
1085
+
1086
+ # Create job using existing method
1087
+ loop = asyncio.get_event_loop()
1088
+ remote_job_creation_data = await loop.run_in_executor(
1089
+ None,
1090
+ partial(
1091
+ self.create_remote_inference_job,
1092
+ iterations=iterations,
1093
+ remote_inference_description=remote_inference_description,
1094
+ remote_inference_results_visibility=remote_inference_results_visibility,
1095
+ ),
1096
+ )
1097
+
1098
+ # Poll using existing method but with async sleep
1099
+ return await loop.run_in_executor(
1100
+ None, partial(self.poll_remote_inference_job, remote_job_creation_data)
1101
+ )
1102
+
1103
+ async def run_async(
1104
+ self,
1105
+ cache=None,
1106
+ n=1,
1107
+ disable_remote_inference: bool = False,
1108
+ remote_inference_description: Optional[str] = None,
1109
+ remote_inference_results_visibility: Optional[
1110
+ Literal["private", "public", "unlisted"]
1111
+ ] = "unlisted",
1112
+ **kwargs,
1113
+ ):
1114
+ """Run the job asynchronously, either locally or remotely.
1115
+
1116
+ :param cache: Cache object or boolean
1117
+ :param n: Number of iterations
1118
+ :param disable_remote_inference: If True, forces local execution
1119
+ :param remote_inference_description: Description for remote jobs
1120
+ :param remote_inference_results_visibility: Visibility setting for remote results
1121
+ :param kwargs: Additional arguments passed to local execution
1122
+ :return: Results object
1123
+ """
1124
+ # Check if we should use remote inference
1125
+ if remote_inference := self.use_remote_inference(disable_remote_inference):
1126
+ results = await self.create_and_poll_remote_job(
1127
+ iterations=n,
1128
+ remote_inference_description=remote_inference_description,
1129
+ remote_inference_results_visibility=remote_inference_results_visibility,
1130
+ )
1131
+ if results is None:
1132
+ self._output("Job failed.")
1133
+ return results
1134
+
1135
+ # If not using remote inference, run locally with async
1136
+ return await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
1137
+
1138
+ def _run_local(self, *args, **kwargs):
1139
+ """Run the job locally."""
1140
+
1141
+ results = JobsRunnerAsyncio(self).run(*args, **kwargs)
1142
+ return results
1143
+
1144
+ def all_question_parameters(self):
1145
+ """Return all the fields in the questions in the survey.
1146
+ >>> from edsl.jobs import Jobs
1147
+ >>> Jobs.example().all_question_parameters()
1148
+ {'period'}
1149
+ """
1150
+ return set.union(*[question.parameters for question in self.survey.questions])
1151
+
1152
+ #######################
1153
+ # Dunder methods
1154
+ #######################
1155
+ def print(self):
1156
+ from rich import print_json
1157
+ import json
1158
+
1159
+ print_json(json.dumps(self.to_dict()))
1160
+
1161
+ def __repr__(self) -> str:
1162
+ """Return an eval-able string representation of the Jobs instance."""
1163
+ return f"Jobs(survey={repr(self.survey)}, agents={repr(self.agents)}, models={repr(self.models)}, scenarios={repr(self.scenarios)})"
1164
+
1165
+ def _repr_html_(self) -> str:
1166
+ from rich import print_json
1167
+ import json
1168
+
1169
+ print_json(json.dumps(self.to_dict()))
1170
+
1171
+ def __len__(self) -> int:
1172
+ """Return the maximum number of questions that will be asked while running this job.
1173
+ Note that this is the maximum number of questions, not the actual number of questions that will be asked, as some questions may be skipped.
1174
+
1175
+ >>> from edsl.jobs import Jobs
1176
+ >>> len(Jobs.example())
1177
+ 8
1178
+ """
1179
+ number_of_questions = (
1180
+ len(self.agents or [1])
1181
+ * len(self.scenarios or [1])
1182
+ * len(self.models or [1])
1183
+ * len(self.survey)
1184
+ )
1185
+ return number_of_questions
1186
+
1187
+ #######################
1188
+ # Serialization methods
1189
+ #######################
1190
+
1191
+ def to_dict(self, add_edsl_version=True):
1192
+ d = {
1193
+ "survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
1194
+ "agents": [
1195
+ agent.to_dict(add_edsl_version=add_edsl_version)
1196
+ for agent in self.agents
1197
+ ],
1198
+ "models": [
1199
+ model.to_dict(add_edsl_version=add_edsl_version)
1200
+ for model in self.models
1201
+ ],
1202
+ "scenarios": [
1203
+ scenario.to_dict(add_edsl_version=add_edsl_version)
1204
+ for scenario in self.scenarios
1205
+ ],
1206
+ }
1207
+ if add_edsl_version:
1208
+ from edsl import __version__
1209
+
1210
+ d["edsl_version"] = __version__
1211
+ d["edsl_class_name"] = "Jobs"
1212
+
1213
+ return d
1214
+
1215
+ @classmethod
1216
+ @remove_edsl_version
1217
+ def from_dict(cls, data: dict) -> Jobs:
1218
+ """Creates a Jobs instance from a dictionary."""
1219
+ from edsl import Survey
1220
+ from edsl.agents.Agent import Agent
1221
+ from edsl.language_models.LanguageModel import LanguageModel
1222
+ from edsl.scenarios.Scenario import Scenario
1223
+
1224
+ return cls(
1225
+ survey=Survey.from_dict(data["survey"]),
1226
+ agents=[Agent.from_dict(agent) for agent in data["agents"]],
1227
+ models=[LanguageModel.from_dict(model) for model in data["models"]],
1228
+ scenarios=[Scenario.from_dict(scenario) for scenario in data["scenarios"]],
1229
+ )
1230
+
1231
+ def __eq__(self, other: Jobs) -> bool:
1232
+ """Return True if the Jobs instance is equal to another Jobs instance.
1233
+
1234
+ >>> from edsl.jobs import Jobs
1235
+ >>> Jobs.example() == Jobs.example()
1236
+ True
1237
+
1238
+ """
1239
+ return self.to_dict() == other.to_dict()
1240
+
1241
+ #######################
1242
+ # Example methods
1243
+ #######################
1244
    @classmethod
    def example(
        cls,
        throw_exception_probability: float = 0.0,
        randomize: bool = False,
        test_model=False,
    ) -> Jobs:
        """Return an example Jobs instance.

        Builds a two-question survey, two scenarios (morning/afternoon), and two
        agents (Sad/Joyful) that answer directly from a fixed lookup table.

        :param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
        :param randomize: whether to randomize the job by adding a random string to the period
        :param test_model: whether to use a test model

        >>> Jobs.example()
        Jobs(...)

        """
        import random
        from uuid import uuid4
        from edsl.questions import QuestionMultipleChoice
        from edsl.agents.Agent import Agent
        from edsl.scenarios.Scenario import Scenario

        # Optional uniqueness suffix so repeated examples don't collide in caches.
        addition = "" if not randomize else str(uuid4())

        if test_model:
            from edsl.language_models import LanguageModel

            m = LanguageModel.example(test_model=True)

        # Canned answers keyed by (agent status, question name, scenario period).
        agent_answers = {
            ("Joyful", "how_feeling", "morning"): "OK",
            ("Joyful", "how_feeling", "afternoon"): "Great",
            ("Joyful", "how_feeling_yesterday", "morning"): "Great",
            ("Joyful", "how_feeling_yesterday", "afternoon"): "Good",
            ("Sad", "how_feeling", "morning"): "Terrible",
            ("Sad", "how_feeling", "afternoon"): "OK",
            ("Sad", "how_feeling_yesterday", "morning"): "OK",
            ("Sad", "how_feeling_yesterday", "afternoon"): "Terrible",
        }

        def answer_question_directly(self, question, scenario):
            """Return the answer to a question. This is a method that can be added to an agent."""

            # Closure over throw_exception_probability enables error-path testing.
            if random.random() < throw_exception_probability:
                raise Exception("Error!")
            return agent_answers[
                (self.traits["status"], question.question_name, scenario["period"])
            ]

        sad_agent = Agent(traits={"status": "Sad"})
        joy_agent = Agent(traits={"status": "Joyful"})

        # Both agents answer from the lookup table instead of calling a model.
        sad_agent.add_direct_question_answering_method(answer_question_directly)
        joy_agent.add_direct_question_answering_method(answer_question_directly)

        q1 = QuestionMultipleChoice(
            question_text="How are you this {{ period }}?",
            question_options=["Good", "Great", "OK", "Terrible"],
            question_name="how_feeling",
        )
        q2 = QuestionMultipleChoice(
            question_text="How were you feeling yesterday {{ period }}?",
            question_options=["Good", "Great", "OK", "Terrible"],
            question_name="how_feeling_yesterday",
        )
        from edsl import Survey, ScenarioList

        base_survey = Survey(questions=[q1, q2])

        scenario_list = ScenarioList(
            [
                Scenario({"period": f"morning{addition}"}),
                Scenario({"period": "afternoon"}),
            ]
        )
        # Only attach the test model when requested; otherwise use the default.
        if test_model:
            job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
        else:
            job = base_survey.by(scenario_list).by(joy_agent, sad_agent)

        return job
1327
+
1328
+ def rich_print(self):
1329
+ """Print a rich representation of the Jobs instance."""
1330
+ from rich.table import Table
1331
+
1332
+ table = Table(title="Jobs")
1333
+ table.add_column("Jobs")
1334
+ table.add_row(self.survey.rich_print())
1335
+ return table
1336
+
1337
+ def code(self):
1338
+ """Return the code to create this instance."""
1339
+ raise NotImplementedError
1340
+
1341
+
1342
def main():
    """Run a small smoke test of the Jobs pipeline.

    Builds the example job, runs it against a fresh local cache, and verifies
    the expected job and result sizes.
    """
    from edsl.jobs import Jobs
    from edsl.data.Cache import Cache

    job = Jobs.example()
    # These comparisons were previously bare no-op expressions; assert so a
    # regression actually fails the smoke test.
    assert len(job) == 8
    results = job.run(cache=Cache())
    assert len(results) == 8
    return results
1352
+
1353
+
1354
+ if __name__ == "__main__":
1355
+ """Run the module's doctests."""
1356
+ import doctest
1357
+
1358
+ doctest.testmod(optionflags=doctest.ELLIPSIS)