edsl 0.1.36.dev5__py3-none-any.whl → 0.1.36.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +48 -47
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +804 -804
- edsl/agents/AgentList.py +337 -337
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +298 -294
- edsl/agents/PromptConstructor.py +320 -312
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +86 -86
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +289 -289
- edsl/config.py +149 -149
- edsl/conjure/AgentConstructionMixin.py +152 -152
- edsl/conjure/Conjure.py +62 -62
- edsl/conjure/InputData.py +659 -659
- edsl/conjure/InputDataCSV.py +48 -48
- edsl/conjure/InputDataMixinQuestionStats.py +182 -182
- edsl/conjure/InputDataPyRead.py +91 -91
- edsl/conjure/InputDataSPSS.py +8 -8
- edsl/conjure/InputDataStata.py +8 -8
- edsl/conjure/QuestionOptionMixin.py +76 -76
- edsl/conjure/QuestionTypeMixin.py +23 -23
- edsl/conjure/RawQuestion.py +65 -65
- edsl/conjure/SurveyResponses.py +7 -7
- edsl/conjure/__init__.py +9 -9
- edsl/conjure/naming_utilities.py +263 -263
- edsl/conjure/utilities.py +201 -201
- edsl/conversation/Conversation.py +238 -238
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +849 -849
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +527 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +83 -83
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +40 -40
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +26 -26
- edsl/exceptions/surveys.py +34 -34
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +115 -115
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +74 -68
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -94
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1112 -1112
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +248 -248
- edsl/jobs/interviews/Interview.py +661 -651
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +189 -182
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +337 -337
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +441 -441
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/LanguageModel.py +718 -718
- edsl/language_models/ModelList.py +102 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +2 -2
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +259 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +358 -358
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +616 -616
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +266 -266
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +113 -113
- edsl/questions/QuestionFunctional.py +159 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +91 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +418 -418
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +693 -693
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +433 -433
- edsl/results/Results.py +1158 -1158
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +118 -118
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -443
- edsl/scenarios/Scenario.py +510 -507
- edsl/scenarios/ScenarioHtmlMixin.py +59 -59
- edsl/scenarios/ScenarioList.py +1101 -1101
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -2
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +324 -324
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1772 -1772
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +47 -47
- edsl/surveys/instructions/Instruction.py +51 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +391 -391
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev7.dist-info}/LICENSE +21 -21
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev7.dist-info}/METADATA +1 -1
- edsl-0.1.36.dev7.dist-info/RECORD +279 -0
- edsl-0.1.36.dev5.dist-info/RECORD +0 -279
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev7.dist-info}/WHEEL +0 -0
@@ -1,337 +1,337 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
import time
|
3
|
-
import asyncio
|
4
|
-
import threading
|
5
|
-
from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
|
6
|
-
from contextlib import contextmanager
|
7
|
-
from collections import UserList
|
8
|
-
|
9
|
-
from edsl.results.Results import Results
|
10
|
-
from edsl.jobs.interviews.Interview import Interview
|
11
|
-
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
|
12
|
-
|
13
|
-
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
14
|
-
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
15
|
-
from edsl.utilities.decorators import jupyter_nb_handler
|
16
|
-
from edsl.data.Cache import Cache
|
17
|
-
from edsl.results.Result import Result
|
18
|
-
from edsl.results.Results import Results
|
19
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
20
|
-
from edsl.data.Cache import Cache
|
21
|
-
|
22
|
-
class StatusTracker(UserList):
|
23
|
-
def __init__(self, total_tasks: int):
|
24
|
-
self.total_tasks = total_tasks
|
25
|
-
super().__init__()
|
26
|
-
|
27
|
-
def current_status(self):
|
28
|
-
return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
|
29
|
-
|
30
|
-
|
31
|
-
class JobsRunnerAsyncio:
|
32
|
-
"""A class for running a collection of interviews asynchronously.
|
33
|
-
|
34
|
-
It gets instaniated from a Jobs object.
|
35
|
-
The Jobs object is a collection of interviews that are to be run.
|
36
|
-
"""
|
37
|
-
|
38
|
-
def __init__(self, jobs: "Jobs"):
|
39
|
-
self.jobs = jobs
|
40
|
-
self.interviews: List["Interview"] = jobs.interviews()
|
41
|
-
self.bucket_collection: "BucketCollection" = jobs.bucket_collection
|
42
|
-
self.total_interviews: List["Interview"] = []
|
43
|
-
|
44
|
-
async def run_async_generator(
|
45
|
-
self,
|
46
|
-
cache: Cache,
|
47
|
-
n: int = 1,
|
48
|
-
stop_on_exception: bool = False,
|
49
|
-
sidecar_model: Optional[LanguageModel] = None,
|
50
|
-
total_interviews: Optional[List["Interview"]] = None,
|
51
|
-
raise_validation_errors: bool = False,
|
52
|
-
) -> AsyncGenerator["Result", None]:
|
53
|
-
"""Creates the tasks, runs them asynchronously, and returns the results as a Results object.
|
54
|
-
|
55
|
-
Completed tasks are yielded as they are completed.
|
56
|
-
|
57
|
-
:param n: how many times to run each interview
|
58
|
-
:param stop_on_exception: Whether to stop the interview if an exception is raised
|
59
|
-
:param sidecar_model: a language model to use in addition to the interview's model
|
60
|
-
:param total_interviews: A list of interviews to run can be provided instead.
|
61
|
-
:param raise_validation_errors: Whether to raise validation errors
|
62
|
-
"""
|
63
|
-
tasks = []
|
64
|
-
if total_interviews: # was already passed in total interviews
|
65
|
-
self.total_interviews = total_interviews
|
66
|
-
else:
|
67
|
-
self.total_interviews = list(
|
68
|
-
self._populate_total_interviews(n=n)
|
69
|
-
) # Populate self.total_interviews before creating tasks
|
70
|
-
|
71
|
-
for interview in self.total_interviews:
|
72
|
-
interviewing_task = self._build_interview_task(
|
73
|
-
interview=interview,
|
74
|
-
stop_on_exception=stop_on_exception,
|
75
|
-
sidecar_model=sidecar_model,
|
76
|
-
raise_validation_errors=raise_validation_errors,
|
77
|
-
)
|
78
|
-
tasks.append(asyncio.create_task(interviewing_task))
|
79
|
-
|
80
|
-
for task in asyncio.as_completed(tasks):
|
81
|
-
result = await task
|
82
|
-
self.jobs_runner_status.add_completed_interview(result)
|
83
|
-
yield result
|
84
|
-
|
85
|
-
def _populate_total_interviews(
|
86
|
-
self, n: int = 1
|
87
|
-
) -> Generator["Interview", None, None]:
|
88
|
-
"""Populates self.total_interviews with n copies of each interview.
|
89
|
-
|
90
|
-
:param n: how many times to run each interview.
|
91
|
-
"""
|
92
|
-
for interview in self.interviews:
|
93
|
-
for iteration in range(n):
|
94
|
-
if iteration > 0:
|
95
|
-
yield interview.duplicate(iteration=iteration, cache=self.cache)
|
96
|
-
else:
|
97
|
-
interview.cache = self.cache
|
98
|
-
yield interview
|
99
|
-
|
100
|
-
async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
|
101
|
-
"""Used for some other modules that have a non-standard way of running interviews."""
|
102
|
-
self.jobs_runner_status = JobsRunnerStatus(self, n=n)
|
103
|
-
self.cache = Cache() if cache is None else cache
|
104
|
-
data = []
|
105
|
-
async for result in self.run_async_generator(cache=self.cache, n=n):
|
106
|
-
data.append(result)
|
107
|
-
return Results(survey=self.jobs.survey, data=data)
|
108
|
-
|
109
|
-
def simple_run(self):
|
110
|
-
data = asyncio.run(self.run_async())
|
111
|
-
return Results(survey=self.jobs.survey, data=data)
|
112
|
-
|
113
|
-
async def _build_interview_task(
|
114
|
-
self,
|
115
|
-
*,
|
116
|
-
interview: Interview,
|
117
|
-
stop_on_exception: bool = False,
|
118
|
-
sidecar_model: Optional["LanguageModel"] = None,
|
119
|
-
raise_validation_errors: bool = False,
|
120
|
-
) -> "Result":
|
121
|
-
"""Conducts an interview and returns the result.
|
122
|
-
|
123
|
-
:param interview: the interview to conduct
|
124
|
-
:param stop_on_exception: stops the interview if an exception is raised
|
125
|
-
:param sidecar_model: a language model to use in addition to the interview's model
|
126
|
-
"""
|
127
|
-
# the model buckets are used to track usage rates
|
128
|
-
model_buckets = self.bucket_collection[interview.model]
|
129
|
-
|
130
|
-
# get the results of the interview
|
131
|
-
answer, valid_results = await interview.async_conduct_interview(
|
132
|
-
model_buckets=model_buckets,
|
133
|
-
stop_on_exception=stop_on_exception,
|
134
|
-
sidecar_model=sidecar_model,
|
135
|
-
raise_validation_errors=raise_validation_errors,
|
136
|
-
)
|
137
|
-
|
138
|
-
question_results = {}
|
139
|
-
for result in valid_results:
|
140
|
-
question_results[result.question_name] = result
|
141
|
-
|
142
|
-
answer_key_names = list(question_results.keys())
|
143
|
-
|
144
|
-
generated_tokens_dict = {
|
145
|
-
k + "_generated_tokens": question_results[k].generated_tokens
|
146
|
-
for k in answer_key_names
|
147
|
-
}
|
148
|
-
comments_dict = {
|
149
|
-
k + "_comment": question_results[k].comment for k in answer_key_names
|
150
|
-
}
|
151
|
-
|
152
|
-
# we should have a valid result for each question
|
153
|
-
answer_dict = {k: answer[k] for k in answer_key_names}
|
154
|
-
assert len(valid_results) == len(answer_key_names)
|
155
|
-
|
156
|
-
# TODO: move this down into Interview
|
157
|
-
question_name_to_prompts = dict({})
|
158
|
-
for result in valid_results:
|
159
|
-
question_name = result.question_name
|
160
|
-
question_name_to_prompts[question_name] = {
|
161
|
-
"user_prompt": result.prompts["user_prompt"],
|
162
|
-
"system_prompt": result.prompts["system_prompt"],
|
163
|
-
}
|
164
|
-
|
165
|
-
prompt_dictionary = {}
|
166
|
-
for answer_key_name in answer_key_names:
|
167
|
-
prompt_dictionary[answer_key_name + "_user_prompt"] = (
|
168
|
-
question_name_to_prompts[answer_key_name]["user_prompt"]
|
169
|
-
)
|
170
|
-
prompt_dictionary[answer_key_name + "_system_prompt"] = (
|
171
|
-
question_name_to_prompts[answer_key_name]["system_prompt"]
|
172
|
-
)
|
173
|
-
|
174
|
-
raw_model_results_dictionary = {}
|
175
|
-
cache_used_dictionary = {}
|
176
|
-
for result in valid_results:
|
177
|
-
question_name = result.question_name
|
178
|
-
raw_model_results_dictionary[question_name + "_raw_model_response"] = (
|
179
|
-
result.raw_model_response
|
180
|
-
)
|
181
|
-
raw_model_results_dictionary[question_name + "_cost"] = result.cost
|
182
|
-
one_use_buys = (
|
183
|
-
"NA"
|
184
|
-
if isinstance(result.cost, str)
|
185
|
-
or result.cost == 0
|
186
|
-
or result.cost is None
|
187
|
-
else 1.0 / result.cost
|
188
|
-
)
|
189
|
-
raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
|
190
|
-
cache_used_dictionary[question_name] = result.cache_used
|
191
|
-
|
192
|
-
result = Result(
|
193
|
-
agent=interview.agent,
|
194
|
-
scenario=interview.scenario,
|
195
|
-
model=interview.model,
|
196
|
-
iteration=interview.iteration,
|
197
|
-
answer=answer_dict,
|
198
|
-
prompt=prompt_dictionary,
|
199
|
-
raw_model_response=raw_model_results_dictionary,
|
200
|
-
survey=interview.survey,
|
201
|
-
generated_tokens=generated_tokens_dict,
|
202
|
-
comments_dict=comments_dict,
|
203
|
-
cache_used_dict=cache_used_dictionary,
|
204
|
-
)
|
205
|
-
result.interview_hash = hash(interview)
|
206
|
-
|
207
|
-
return result
|
208
|
-
|
209
|
-
@property
|
210
|
-
def elapsed_time(self):
|
211
|
-
return time.monotonic() - self.start_time
|
212
|
-
|
213
|
-
def process_results(
|
214
|
-
self, raw_results: Results, cache: Cache, print_exceptions: bool
|
215
|
-
):
|
216
|
-
interview_lookup = {
|
217
|
-
hash(interview): index
|
218
|
-
for index, interview in enumerate(self.total_interviews)
|
219
|
-
}
|
220
|
-
interview_hashes = list(interview_lookup.keys())
|
221
|
-
|
222
|
-
task_history = TaskHistory(self.total_interviews, include_traceback=False)
|
223
|
-
|
224
|
-
results = Results(
|
225
|
-
survey=self.jobs.survey,
|
226
|
-
data=sorted(
|
227
|
-
raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
|
228
|
-
),
|
229
|
-
task_history=task_history,
|
230
|
-
cache=cache,
|
231
|
-
)
|
232
|
-
results.bucket_collection = self.bucket_collection
|
233
|
-
|
234
|
-
if results.has_unfixed_exceptions and print_exceptions:
|
235
|
-
from edsl.scenarios.FileStore import HTMLFileStore
|
236
|
-
from edsl.config import CONFIG
|
237
|
-
from edsl.coop.coop import Coop
|
238
|
-
|
239
|
-
msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
|
240
|
-
|
241
|
-
if len(results.task_history.indices) > 5:
|
242
|
-
msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
|
243
|
-
|
244
|
-
print(msg)
|
245
|
-
# this is where exceptions are opening up
|
246
|
-
filepath = results.task_history.html(
|
247
|
-
cta="Open report to see details.",
|
248
|
-
open_in_browser=True,
|
249
|
-
return_link=True,
|
250
|
-
)
|
251
|
-
|
252
|
-
try:
|
253
|
-
coop = Coop()
|
254
|
-
user_edsl_settings = coop.edsl_settings
|
255
|
-
remote_logging = user_edsl_settings["remote_logging"]
|
256
|
-
except Exception as e:
|
257
|
-
print(e)
|
258
|
-
remote_logging = False
|
259
|
-
|
260
|
-
if remote_logging:
|
261
|
-
filestore = HTMLFileStore(filepath)
|
262
|
-
coop_details = filestore.push(description="Error report")
|
263
|
-
print(coop_details)
|
264
|
-
|
265
|
-
print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
|
266
|
-
|
267
|
-
return results
|
268
|
-
|
269
|
-
@jupyter_nb_handler
|
270
|
-
async def run(
|
271
|
-
self,
|
272
|
-
cache: Union[Cache, False, None],
|
273
|
-
n: int = 1,
|
274
|
-
stop_on_exception: bool = False,
|
275
|
-
progress_bar: bool = False,
|
276
|
-
sidecar_model: Optional[LanguageModel] = None,
|
277
|
-
print_exceptions: bool = True,
|
278
|
-
raise_validation_errors: bool = False,
|
279
|
-
) -> "Coroutine":
|
280
|
-
"""Runs a collection of interviews, handling both async and sync contexts."""
|
281
|
-
|
282
|
-
self.results = []
|
283
|
-
self.start_time = time.monotonic()
|
284
|
-
self.completed = False
|
285
|
-
self.cache = cache
|
286
|
-
self.sidecar_model = sidecar_model
|
287
|
-
|
288
|
-
self.jobs_runner_status = JobsRunnerStatus(self, n=n)
|
289
|
-
|
290
|
-
stop_event = threading.Event()
|
291
|
-
|
292
|
-
async def process_results(cache):
|
293
|
-
"""Processes results from interviews."""
|
294
|
-
async for result in self.run_async_generator(
|
295
|
-
n=n,
|
296
|
-
stop_on_exception=stop_on_exception,
|
297
|
-
cache=cache,
|
298
|
-
sidecar_model=sidecar_model,
|
299
|
-
raise_validation_errors=raise_validation_errors,
|
300
|
-
):
|
301
|
-
self.results.append(result)
|
302
|
-
self.completed = True
|
303
|
-
|
304
|
-
def run_progress_bar(stop_event):
|
305
|
-
"""Runs the progress bar in a separate thread."""
|
306
|
-
self.jobs_runner_status.update_progress(stop_event)
|
307
|
-
|
308
|
-
if progress_bar:
|
309
|
-
progress_thread = threading.Thread(
|
310
|
-
target=run_progress_bar, args=(stop_event,)
|
311
|
-
)
|
312
|
-
progress_thread.start()
|
313
|
-
|
314
|
-
exception_to_raise = None
|
315
|
-
try:
|
316
|
-
with cache as c:
|
317
|
-
await process_results(cache=c)
|
318
|
-
except KeyboardInterrupt:
|
319
|
-
print("Keyboard interrupt received. Stopping gracefully...")
|
320
|
-
stop_event.set()
|
321
|
-
except Exception as e:
|
322
|
-
if stop_on_exception:
|
323
|
-
exception_to_raise = e
|
324
|
-
stop_event.set()
|
325
|
-
finally:
|
326
|
-
stop_event.set()
|
327
|
-
if progress_bar:
|
328
|
-
# self.jobs_runner_status.stop_event.set()
|
329
|
-
if progress_thread:
|
330
|
-
progress_thread.join()
|
331
|
-
|
332
|
-
if exception_to_raise:
|
333
|
-
raise exception_to_raise
|
334
|
-
|
335
|
-
return self.process_results(
|
336
|
-
raw_results=self.results, cache=cache, print_exceptions=print_exceptions
|
337
|
-
)
|
1
|
+
from __future__ import annotations
|
2
|
+
import time
|
3
|
+
import asyncio
|
4
|
+
import threading
|
5
|
+
from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
|
6
|
+
from contextlib import contextmanager
|
7
|
+
from collections import UserList
|
8
|
+
|
9
|
+
from edsl.results.Results import Results
|
10
|
+
from edsl.jobs.interviews.Interview import Interview
|
11
|
+
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
|
12
|
+
|
13
|
+
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
14
|
+
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
15
|
+
from edsl.utilities.decorators import jupyter_nb_handler
|
16
|
+
from edsl.data.Cache import Cache
|
17
|
+
from edsl.results.Result import Result
|
18
|
+
from edsl.results.Results import Results
|
19
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
20
|
+
from edsl.data.Cache import Cache
|
21
|
+
|
22
|
+
class StatusTracker(UserList):
|
23
|
+
def __init__(self, total_tasks: int):
|
24
|
+
self.total_tasks = total_tasks
|
25
|
+
super().__init__()
|
26
|
+
|
27
|
+
def current_status(self):
|
28
|
+
return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
|
29
|
+
|
30
|
+
|
31
|
+
class JobsRunnerAsyncio:
|
32
|
+
"""A class for running a collection of interviews asynchronously.
|
33
|
+
|
34
|
+
It gets instaniated from a Jobs object.
|
35
|
+
The Jobs object is a collection of interviews that are to be run.
|
36
|
+
"""
|
37
|
+
|
38
|
+
def __init__(self, jobs: "Jobs"):
|
39
|
+
self.jobs = jobs
|
40
|
+
self.interviews: List["Interview"] = jobs.interviews()
|
41
|
+
self.bucket_collection: "BucketCollection" = jobs.bucket_collection
|
42
|
+
self.total_interviews: List["Interview"] = []
|
43
|
+
|
44
|
+
async def run_async_generator(
|
45
|
+
self,
|
46
|
+
cache: Cache,
|
47
|
+
n: int = 1,
|
48
|
+
stop_on_exception: bool = False,
|
49
|
+
sidecar_model: Optional[LanguageModel] = None,
|
50
|
+
total_interviews: Optional[List["Interview"]] = None,
|
51
|
+
raise_validation_errors: bool = False,
|
52
|
+
) -> AsyncGenerator["Result", None]:
|
53
|
+
"""Creates the tasks, runs them asynchronously, and returns the results as a Results object.
|
54
|
+
|
55
|
+
Completed tasks are yielded as they are completed.
|
56
|
+
|
57
|
+
:param n: how many times to run each interview
|
58
|
+
:param stop_on_exception: Whether to stop the interview if an exception is raised
|
59
|
+
:param sidecar_model: a language model to use in addition to the interview's model
|
60
|
+
:param total_interviews: A list of interviews to run can be provided instead.
|
61
|
+
:param raise_validation_errors: Whether to raise validation errors
|
62
|
+
"""
|
63
|
+
tasks = []
|
64
|
+
if total_interviews: # was already passed in total interviews
|
65
|
+
self.total_interviews = total_interviews
|
66
|
+
else:
|
67
|
+
self.total_interviews = list(
|
68
|
+
self._populate_total_interviews(n=n)
|
69
|
+
) # Populate self.total_interviews before creating tasks
|
70
|
+
|
71
|
+
for interview in self.total_interviews:
|
72
|
+
interviewing_task = self._build_interview_task(
|
73
|
+
interview=interview,
|
74
|
+
stop_on_exception=stop_on_exception,
|
75
|
+
sidecar_model=sidecar_model,
|
76
|
+
raise_validation_errors=raise_validation_errors,
|
77
|
+
)
|
78
|
+
tasks.append(asyncio.create_task(interviewing_task))
|
79
|
+
|
80
|
+
for task in asyncio.as_completed(tasks):
|
81
|
+
result = await task
|
82
|
+
self.jobs_runner_status.add_completed_interview(result)
|
83
|
+
yield result
|
84
|
+
|
85
|
+
def _populate_total_interviews(
|
86
|
+
self, n: int = 1
|
87
|
+
) -> Generator["Interview", None, None]:
|
88
|
+
"""Populates self.total_interviews with n copies of each interview.
|
89
|
+
|
90
|
+
:param n: how many times to run each interview.
|
91
|
+
"""
|
92
|
+
for interview in self.interviews:
|
93
|
+
for iteration in range(n):
|
94
|
+
if iteration > 0:
|
95
|
+
yield interview.duplicate(iteration=iteration, cache=self.cache)
|
96
|
+
else:
|
97
|
+
interview.cache = self.cache
|
98
|
+
yield interview
|
99
|
+
|
100
|
+
async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
|
101
|
+
"""Used for some other modules that have a non-standard way of running interviews."""
|
102
|
+
self.jobs_runner_status = JobsRunnerStatus(self, n=n)
|
103
|
+
self.cache = Cache() if cache is None else cache
|
104
|
+
data = []
|
105
|
+
async for result in self.run_async_generator(cache=self.cache, n=n):
|
106
|
+
data.append(result)
|
107
|
+
return Results(survey=self.jobs.survey, data=data)
|
108
|
+
|
109
|
+
def simple_run(self):
|
110
|
+
data = asyncio.run(self.run_async())
|
111
|
+
return Results(survey=self.jobs.survey, data=data)
|
112
|
+
|
113
|
+
    async def _build_interview_task(
        self,
        *,
        interview: Interview,
        stop_on_exception: bool = False,
        sidecar_model: Optional["LanguageModel"] = None,
        raise_validation_errors: bool = False,
    ) -> "Result":
        """Conducts an interview and returns the result.

        :param interview: the interview to conduct
        :param stop_on_exception: stops the interview if an exception is raised
        :param sidecar_model: a language model to use in addition to the interview's model
        :param raise_validation_errors: forwarded to the interview's answer validation
        :return: a ``Result`` assembling answers, prompts, token counts, costs,
            and cache usage for every question in the interview
        """
        # the model buckets are used to track usage rates
        model_buckets = self.bucket_collection[interview.model]

        # get the results of the interview
        # `answer` maps question names to answers; `valid_results` holds the
        # per-question result objects that passed validation.
        answer, valid_results = await interview.async_conduct_interview(
            model_buckets=model_buckets,
            stop_on_exception=stop_on_exception,
            sidecar_model=sidecar_model,
            raise_validation_errors=raise_validation_errors,
        )

        # index the validated results by question name for the lookups below
        question_results = {}
        for result in valid_results:
            question_results[result.question_name] = result

        answer_key_names = list(question_results.keys())

        # per-question generated-token counts, keyed "<question>_generated_tokens"
        generated_tokens_dict = {
            k + "_generated_tokens": question_results[k].generated_tokens
            for k in answer_key_names
        }
        # per-question free-text comments, keyed "<question>_comment"
        comments_dict = {
            k + "_comment": question_results[k].comment for k in answer_key_names
        }

        # we should have a valid result for each question
        answer_dict = {k: answer[k] for k in answer_key_names}
        assert len(valid_results) == len(answer_key_names)

        # TODO: move this down into Interview
        # collect the exact prompts that were sent for each question
        question_name_to_prompts = dict({})
        for result in valid_results:
            question_name = result.question_name
            question_name_to_prompts[question_name] = {
                "user_prompt": result.prompts["user_prompt"],
                "system_prompt": result.prompts["system_prompt"],
            }

        # flatten the prompts into "<question>_user_prompt"/"<question>_system_prompt" keys
        prompt_dictionary = {}
        for answer_key_name in answer_key_names:
            prompt_dictionary[answer_key_name + "_user_prompt"] = (
                question_name_to_prompts[answer_key_name]["user_prompt"]
            )
            prompt_dictionary[answer_key_name + "_system_prompt"] = (
                question_name_to_prompts[answer_key_name]["system_prompt"]
            )

        # raw model responses, costs, and a "how many calls does $1 buy" figure
        raw_model_results_dictionary = {}
        cache_used_dictionary = {}
        for result in valid_results:
            question_name = result.question_name
            raw_model_results_dictionary[question_name + "_raw_model_response"] = (
                result.raw_model_response
            )
            raw_model_results_dictionary[question_name + "_cost"] = result.cost
            # "NA" when the cost is unknown/unparseable (string), zero, or missing;
            # otherwise the reciprocal of the per-call cost
            one_use_buys = (
                "NA"
                if isinstance(result.cost, str)
                or result.cost == 0
                or result.cost is None
                else 1.0 / result.cost
            )
            raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
            cache_used_dictionary[question_name] = result.cache_used

        # assemble everything into the final Result object
        result = Result(
            agent=interview.agent,
            scenario=interview.scenario,
            model=interview.model,
            iteration=interview.iteration,
            answer=answer_dict,
            prompt=prompt_dictionary,
            raw_model_response=raw_model_results_dictionary,
            survey=interview.survey,
            generated_tokens=generated_tokens_dict,
            comments_dict=comments_dict,
            cache_used_dict=cache_used_dictionary,
        )
        # tag the result with its interview so process_results can re-order
        result.interview_hash = hash(interview)

        return result
@property
|
210
|
+
def elapsed_time(self):
|
211
|
+
return time.monotonic() - self.start_time
|
212
|
+
|
213
|
+
def process_results(
|
214
|
+
self, raw_results: Results, cache: Cache, print_exceptions: bool
|
215
|
+
):
|
216
|
+
interview_lookup = {
|
217
|
+
hash(interview): index
|
218
|
+
for index, interview in enumerate(self.total_interviews)
|
219
|
+
}
|
220
|
+
interview_hashes = list(interview_lookup.keys())
|
221
|
+
|
222
|
+
task_history = TaskHistory(self.total_interviews, include_traceback=False)
|
223
|
+
|
224
|
+
results = Results(
|
225
|
+
survey=self.jobs.survey,
|
226
|
+
data=sorted(
|
227
|
+
raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
|
228
|
+
),
|
229
|
+
task_history=task_history,
|
230
|
+
cache=cache,
|
231
|
+
)
|
232
|
+
results.bucket_collection = self.bucket_collection
|
233
|
+
|
234
|
+
if results.has_unfixed_exceptions and print_exceptions:
|
235
|
+
from edsl.scenarios.FileStore import HTMLFileStore
|
236
|
+
from edsl.config import CONFIG
|
237
|
+
from edsl.coop.coop import Coop
|
238
|
+
|
239
|
+
msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
|
240
|
+
|
241
|
+
if len(results.task_history.indices) > 5:
|
242
|
+
msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
|
243
|
+
|
244
|
+
print(msg)
|
245
|
+
# this is where exceptions are opening up
|
246
|
+
filepath = results.task_history.html(
|
247
|
+
cta="Open report to see details.",
|
248
|
+
open_in_browser=True,
|
249
|
+
return_link=True,
|
250
|
+
)
|
251
|
+
|
252
|
+
try:
|
253
|
+
coop = Coop()
|
254
|
+
user_edsl_settings = coop.edsl_settings
|
255
|
+
remote_logging = user_edsl_settings["remote_logging"]
|
256
|
+
except Exception as e:
|
257
|
+
print(e)
|
258
|
+
remote_logging = False
|
259
|
+
|
260
|
+
if remote_logging:
|
261
|
+
filestore = HTMLFileStore(filepath)
|
262
|
+
coop_details = filestore.push(description="Error report")
|
263
|
+
print(coop_details)
|
264
|
+
|
265
|
+
print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
|
266
|
+
|
267
|
+
return results
|
268
|
+
|
269
|
+
@jupyter_nb_handler
|
270
|
+
async def run(
|
271
|
+
self,
|
272
|
+
cache: Union[Cache, False, None],
|
273
|
+
n: int = 1,
|
274
|
+
stop_on_exception: bool = False,
|
275
|
+
progress_bar: bool = False,
|
276
|
+
sidecar_model: Optional[LanguageModel] = None,
|
277
|
+
print_exceptions: bool = True,
|
278
|
+
raise_validation_errors: bool = False,
|
279
|
+
) -> "Coroutine":
|
280
|
+
"""Runs a collection of interviews, handling both async and sync contexts."""
|
281
|
+
|
282
|
+
self.results = []
|
283
|
+
self.start_time = time.monotonic()
|
284
|
+
self.completed = False
|
285
|
+
self.cache = cache
|
286
|
+
self.sidecar_model = sidecar_model
|
287
|
+
|
288
|
+
self.jobs_runner_status = JobsRunnerStatus(self, n=n)
|
289
|
+
|
290
|
+
stop_event = threading.Event()
|
291
|
+
|
292
|
+
async def process_results(cache):
|
293
|
+
"""Processes results from interviews."""
|
294
|
+
async for result in self.run_async_generator(
|
295
|
+
n=n,
|
296
|
+
stop_on_exception=stop_on_exception,
|
297
|
+
cache=cache,
|
298
|
+
sidecar_model=sidecar_model,
|
299
|
+
raise_validation_errors=raise_validation_errors,
|
300
|
+
):
|
301
|
+
self.results.append(result)
|
302
|
+
self.completed = True
|
303
|
+
|
304
|
+
def run_progress_bar(stop_event):
|
305
|
+
"""Runs the progress bar in a separate thread."""
|
306
|
+
self.jobs_runner_status.update_progress(stop_event)
|
307
|
+
|
308
|
+
if progress_bar:
|
309
|
+
progress_thread = threading.Thread(
|
310
|
+
target=run_progress_bar, args=(stop_event,)
|
311
|
+
)
|
312
|
+
progress_thread.start()
|
313
|
+
|
314
|
+
exception_to_raise = None
|
315
|
+
try:
|
316
|
+
with cache as c:
|
317
|
+
await process_results(cache=c)
|
318
|
+
except KeyboardInterrupt:
|
319
|
+
print("Keyboard interrupt received. Stopping gracefully...")
|
320
|
+
stop_event.set()
|
321
|
+
except Exception as e:
|
322
|
+
if stop_on_exception:
|
323
|
+
exception_to_raise = e
|
324
|
+
stop_event.set()
|
325
|
+
finally:
|
326
|
+
stop_event.set()
|
327
|
+
if progress_bar:
|
328
|
+
# self.jobs_runner_status.stop_event.set()
|
329
|
+
if progress_thread:
|
330
|
+
progress_thread.join()
|
331
|
+
|
332
|
+
if exception_to_raise:
|
333
|
+
raise exception_to_raise
|
334
|
+
|
335
|
+
return self.process_results(
|
336
|
+
raw_results=self.results, cache=cache, print_exceptions=print_exceptions
|
337
|
+
)
|