edsl-0.1.39.dev3-py3-none-any.whl → edsl-0.1.39.dev4-py3-none-any.whl
This diff compares the contents of two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- edsl/Base.py +413 -332
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +57 -49
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +1071 -867
- edsl/agents/AgentList.py +551 -413
- edsl/agents/Invigilator.py +284 -233
- edsl/agents/InvigilatorBase.py +257 -270
- edsl/agents/PromptConstructor.py +272 -354
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +2 -3
- edsl/agents/descriptors.py +99 -99
- edsl/agents/prompt_helpers.py +129 -129
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -117
- edsl/auto/StageBase.py +243 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +74 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +218 -224
- edsl/base/Base.py +279 -279
- edsl/config.py +177 -157
- edsl/conversation/Conversation.py +290 -290
- edsl/conversation/car_buying.py +59 -58
- edsl/conversation/chips.py +95 -95
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +1106 -1028
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +573 -555
- edsl/data/CacheEntry.py +230 -233
- edsl/data/CacheHandler.py +168 -149
- edsl/data/RemoteCacheSync.py +186 -78
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +5 -4
- edsl/data/hack.py +10 -0
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +74 -73
- edsl/enums.py +202 -175
- edsl/exceptions/BaseException.py +21 -21
- edsl/exceptions/__init__.py +54 -54
- edsl/exceptions/agents.py +54 -42
- edsl/exceptions/cache.py +5 -5
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +109 -91
- edsl/exceptions/results.py +29 -29
- edsl/exceptions/scenarios.py +29 -22
- edsl/exceptions/surveys.py +37 -37
- edsl/inference_services/AnthropicService.py +106 -87
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -120
- edsl/inference_services/AzureAI.py +215 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +143 -148
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +80 -147
- edsl/inference_services/InferenceServicesCollection.py +138 -97
- edsl/inference_services/MistralAIService.py +120 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +236 -224
- edsl/inference_services/PerplexityService.py +160 -163
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -89
- edsl/inference_services/TogetherAIService.py +172 -170
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +41 -41
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +43 -56
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +823 -898
- edsl/jobs/JobsChecks.py +172 -147
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -268
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -239
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +283 -251
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +396 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -466
- edsl/jobs/runners/JobsRunnerStatus.py +297 -330
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +470 -450
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +161 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +626 -668
- edsl/language_models/ModelList.py +164 -155
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +2 -3
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +156 -156
- edsl/language_models/utilities.py +65 -64
- edsl/notebooks/Notebook.py +263 -258
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +352 -362
- edsl/prompts/__init__.py +2 -2
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -664
- edsl/questions/QuestionBasePromptsMixin.py +221 -217
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +180 -182
- edsl/questions/QuestionFreeText.py +113 -114
- edsl/questions/QuestionFunctional.py +166 -166
- edsl/questions/QuestionList.py +223 -231
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +330 -286
- edsl/questions/QuestionNumerical.py +151 -153
- edsl/questions/QuestionRank.py +314 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/SimpleAskMixin.py +74 -73
- edsl/questions/__init__.py +27 -26
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +90 -87
- edsl/questions/derived/QuestionTopK.py +93 -93
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +427 -413
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
- edsl/questions/question_registry.py +177 -177
- edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/CSSParameterizer.py +108 -108
- edsl/results/Dataset.py +587 -424
- edsl/results/DatasetExportMixin.py +594 -731
- edsl/results/DatasetTree.py +295 -275
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +557 -465
- edsl/results/Results.py +1183 -1165
- edsl/results/ResultsExportMixin.py +45 -43
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/TableDisplay.py +125 -198
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +2 -2
- edsl/results/file_exports.py +252 -0
- edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
- edsl/results/{Selector.py → results_selector.py} +145 -135
- edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +77 -77
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +511 -632
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +498 -601
- edsl/scenarios/ScenarioHtmlMixin.py +65 -64
- edsl/scenarios/ScenarioList.py +1458 -1287
- edsl/scenarios/ScenarioListExportMixin.py +45 -52
- edsl/scenarios/ScenarioListPdfMixin.py +239 -261
- edsl/scenarios/__init__.py +3 -4
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +38 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +131 -127
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +521 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +327 -326
- edsl/surveys/RuleCollection.py +385 -387
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1280 -1801
- edsl/surveys/SurveyCSS.py +273 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -179
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +5 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +60 -56
- edsl/surveys/instructions/ChangeInstruction.py +48 -49
- edsl/surveys/instructions/Instruction.py +56 -65
- edsl/surveys/instructions/InstructionCollection.py +82 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +19 -19
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/test_h +1 -0
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/gcp_bucket/example.py +50 -0
- edsl/utilities/interface.py +627 -627
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -263
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +436 -424
- {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +21 -21
- {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +13 -11
- edsl-0.1.39.dev4.dist-info/RECORD +361 -0
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/registry.py +0 -190
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.39.dev3.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -1,290 +1,290 @@
(The expanded file is edsl/conversation/Conversation.py. All 290 lines are marked as removed and re-added, but the old and new contents are identical, so the file is shown once below.)

from collections import UserList
import asyncio
import inspect
from typing import Optional, Callable
from edsl import Agent, QuestionFreeText, Results, AgentList, ScenarioList, Scenario
from edsl.questions import QuestionBase
from edsl.results.Result import Result
from jinja2 import Template
from edsl.data import Cache

from edsl.conversation.next_speaker_utilities import (
    default_turn_taking_generator,
    speaker_closure,
)


class AgentStatement:
    def __init__(self, statement: Result):
        self.statement = statement

    @property
    def agent_name(self):
        return self.statement["agent"]["name"]

    def to_dict(self):
        return self.statement.to_dict()

    @classmethod
    def from_dict(cls, data):
        return cls(Result.from_dict(data))

    @property
    def text(self):
        return self.statement["answer"]["dialogue"]


class AgentStatements(UserList):
    def __init__(self, data=None):
        super().__init__(data)

    @property
    def transcript(self):
        return [{s.agent_name: s.text} for s in self.data]

    def to_dict(self):
        return [d.to_dict() for d in self.data]

    @classmethod
    def from_dict(cls, data):
        return cls([AgentStatement.from_dict(d) for d in data])


class Conversation:
    """A conversation between a list of agents. The first agent in the list is the first speaker.
    After that, order is determined by the next_speaker function.
    The question asked to each agent is determined by the next_statement_question.

    If the user has passed in a "per_round_message_template", this will be displayed at the beginning of each round.
    {{ round_message }} must be in the question_text.
    """

    def __init__(
        self,
        agent_list: AgentList,
        max_turns: int = 20,
        stopping_function: Optional[Callable] = None,
        next_statement_question: Optional[QuestionBase] = None,
        next_speaker_generator: Optional[Callable] = None,
        verbose: bool = False,
        per_round_message_template: Optional[str] = None,
        conversation_index: Optional[int] = None,
        cache=None,
        disable_remote_inference=False,
        default_model: Optional["LanguageModel"] = None,
    ):
        self.disable_remote_inference = disable_remote_inference
        self.per_round_message_template = per_round_message_template

        if cache is None:
            self.cache = Cache()
        else:
            self.cache = cache

        self.agent_list = agent_list

        from edsl import Model

        for agent in self.agent_list:
            if not hasattr(agent, "model"):
                if default_model is not None:
                    agent.model = default_model
                else:
                    agent.model = Model()

        self.verbose = verbose
        self.agent_statements = []
        self._conversation_index = conversation_index
        self.agent_statements = AgentStatements()

        self.max_turns = max_turns

        if next_statement_question is None:
            import textwrap

            base_question = textwrap.dedent(
                """\
            You are {{ agent_name }}. This is the conversation so far: {{ conversation }}
            {% if round_message is not none %}
            {{ round_message }}
            {% endif %}
            What do you say next?"""
            )
            self.next_statement_question = QuestionFreeText(
                question_text=base_question,
                question_name="dialogue",
            )
        else:
            self.next_statement_question = next_statement_question
            if (
                per_round_message_template
                and "{{ round_message }}" not in next_statement_question.question_text
            ):
                raise ValueError(
                    "If you pass in a per_round_message_template, you must include {{ round_message }} in the question_text."
                )

        # Determine how the next speaker is chosen
        if next_speaker_generator is None:
            func = default_turn_taking_generator
        else:
            func = next_speaker_generator

        # Choose the next speaker
        self.next_speaker = speaker_closure(
            agent_list=self.agent_list, generator_function=func
        )

        # Determine when the conversation ends
        if stopping_function is None:
            self.stopping_function = lambda agent_statements: False
        else:
            self.stopping_function = stopping_function

    async def continue_conversation(self, **kwargs) -> bool:
        if len(self.agent_statements) >= self.max_turns:
            return False

        if inspect.iscoroutinefunction(self.stopping_function):
            should_stop = await self.stopping_function(self.agent_statements, **kwargs)
        else:
            should_stop = self.stopping_function(self.agent_statements, **kwargs)

        return not should_stop

    def add_index(self, index) -> None:
        self._conversation_index = index

    @property
    def conversation_index(self):
        return self._conversation_index

    def to_dict(self):
        return {
            "agent_list": self.agent_list.to_dict(),
            "max_turns": self.max_turns,
            "verbose": self.verbose,
            "agent_statements": [d.to_dict() for d in self.agent_statements],
            "conversation_index": self.conversation_index,
        }

    @classmethod
    def from_dict(cls, data):
        agent_list = AgentList.from_dict(data["agent_list"])
        max_turns = data["max_turns"]
        verbose = data["verbose"]
        agent_statements = (AgentStatements.from_dict(data["agent_statements"]),)
        conversation_index = data["conversation_index"]
        return cls(
            agent_list=agent_list,
            max_turns=max_turns,
            verbose=verbose,
            results_data=agent_statements,
            conversation_index=conversation_index,
        )

    def to_results(self):
        return Results(data=[s.statement for s in self.agent_statements])

    def summarize(self):
        d = {
            "num_agents": len(self.agent_list),
            "max_turns": self.max_turns,
            "conversation_index": self.conversation_index,
            "transcript": self.to_results().select("agent_name", "dialogue").to_list(),
            "number_of_agent_statements": len(self.agent_statements),
        }
        return Scenario(d)

    async def get_next_statement(self, *, index, speaker, conversation) -> "Result":
        """Get the next statement from the speaker."""
        q = self.next_statement_question
        # assert q.parameters == {"agent_name", "conversation"}, q.parameters
        from edsl import Scenario

        if self.per_round_message_template is None:
            round_message = None
        else:
            round_message = Template(self.per_round_message_template).render(
                {"max_turns": self.max_turns, "current_turn": index}
            )

        s = Scenario(
            {
                "agent_name": speaker.name,
                "conversation": conversation,
                "conversation_index": self.conversation_index,
                "index": index,
                "round_message": round_message,
            }
        )
        jobs = q.by(s).by(speaker).by(speaker.model)
        jobs.show_prompts()
        results = await jobs.run_async(
            cache=self.cache, disable_remote_inference=self.disable_remote_inference
        )
        return results[0]

    def converse(self):
        return asyncio.run(self._converse())

    async def _converse(self):
        i = 0
        while await self.continue_conversation():
            speaker = self.next_speaker()

            next_statement = AgentStatement(
                statement=await self.get_next_statement(
                    index=i,
                    speaker=speaker,
                    conversation=self.agent_statements.transcript,
                )
            )
            self.agent_statements.append(next_statement)
            if self.verbose:
                print(f"'{speaker.name}':{next_statement.text}")
                print("\n")
            i += 1


class ConversationList:
    """A collection of conversations to be run in parallel."""

    def __init__(self, conversations: list[Conversation], cache=None):
        self.conversations = conversations
        for i, conversation in enumerate(self.conversations):
            conversation.add_index(i)

        if cache is None:
            self.cache = Cache()
        else:
            self.cache = cache

        for c in self.conversations:
            c.cache = self.cache

    async def run_conversations(self):
        await asyncio.gather(*[c._converse() for c in self.conversations])

    def run(self) -> None:
        """Run all conversations in parallel"""
        asyncio.run(self.run_conversations())

    def to_dict(self) -> dict:
        return {"conversations": c.to_dict() for c in self.conversations}

    @classmethod
    def from_dict(cls, data):
        conversations = [Conversation.from_dict(d) for d in data["conversations"]]
        return cls(conversations)

    def to_results(self) -> Results:
        """Return the results of all conversations as a single Results"""
        first_convo = self.conversations[0]
        results = first_convo.to_results()
        for conv in self.conversations[1:]:
            results += conv.to_results()
        return results

    def summarize(self) -> ScenarioList:
        return ScenarioList([c.summarize() for c in self.conversations])
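For orientation, a minimal usage sketch of the Conversation API shown in the diff above. This is not part of the diff: the Agent constructor arguments and the choice to run locally with remote inference disabled are assumptions, so adapt them to your setup.

# Hypothetical example, not from the diff: Agent(name=...) and local execution
# are assumptions about the surrounding edsl API.
from edsl import Agent, AgentList
from edsl.conversation.Conversation import Conversation, ConversationList

agents = AgentList([Agent(name="Alice"), Agent(name="Bob")])

# One conversation: agents take turns via the default speaker generator
# until max_turns statements have been collected.
convo = Conversation(
    agent_list=agents, max_turns=4, verbose=True, disable_remote_inference=True
)
convo.converse()
print(convo.summarize())      # Scenario with transcript and turn counts
results = convo.to_results()  # edsl Results built from each AgentStatement

# Several conversations sharing one cache, run concurrently.
batch = ConversationList(
    [Conversation(agent_list=agents, max_turns=2) for _ in range(3)]
)
batch.run()
combined = batch.to_results()

Note that ConversationList injects a single shared Cache into every conversation it manages, so identical model calls made by concurrent conversations can reuse cached responses.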