edsl 0.1.36.dev6__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +48 -47
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +855 -804
- edsl/agents/AgentList.py +350 -337
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +284 -294
- edsl/agents/PromptConstructor.py +353 -312
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +99 -86
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +289 -289
- edsl/config.py +149 -149
- edsl/conjure/AgentConstructionMixin.py +160 -152
- edsl/conjure/Conjure.py +62 -62
- edsl/conjure/InputData.py +659 -659
- edsl/conjure/InputDataCSV.py +48 -48
- edsl/conjure/InputDataMixinQuestionStats.py +182 -182
- edsl/conjure/InputDataPyRead.py +91 -91
- edsl/conjure/InputDataSPSS.py +8 -8
- edsl/conjure/InputDataStata.py +8 -8
- edsl/conjure/QuestionOptionMixin.py +76 -76
- edsl/conjure/QuestionTypeMixin.py +23 -23
- edsl/conjure/RawQuestion.py +65 -65
- edsl/conjure/SurveyResponses.py +7 -7
- edsl/conjure/__init__.py +9 -9
- edsl/conjure/naming_utilities.py +263 -263
- edsl/conjure/utilities.py +201 -201
- edsl/conversation/Conversation.py +290 -238
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +958 -849
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +527 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +97 -84
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +54 -50
- edsl/exceptions/agents.py +38 -40
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +29 -26
- edsl/exceptions/scenarios.py +22 -0
- edsl/exceptions/surveys.py +37 -34
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +120 -115
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +97 -72
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1347 -1112
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +248 -248
- edsl/jobs/interviews/Interview.py +661 -651
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -182
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +338 -337
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +442 -441
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +706 -718
- edsl/language_models/ModelList.py +102 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +3 -2
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +259 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +357 -358
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +656 -616
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +234 -266
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +114 -113
- edsl/questions/QuestionFunctional.py +159 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +91 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +413 -418
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +717 -693
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +450 -433
- edsl/results/Results.py +1071 -1158
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +135 -118
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -443
- edsl/scenarios/Scenario.py +546 -507
- edsl/scenarios/ScenarioHtmlMixin.py +64 -59
- edsl/scenarios/ScenarioList.py +1112 -1101
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -2
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +330 -324
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1795 -1772
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +47 -47
- edsl/surveys/instructions/Instruction.py +51 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +409 -391
- {edsl-0.1.36.dev6.dist-info → edsl-0.1.37.dist-info}/LICENSE +21 -21
- {edsl-0.1.36.dev6.dist-info → edsl-0.1.37.dist-info}/METADATA +1 -1
- edsl-0.1.37.dist-info/RECORD +283 -0
- edsl-0.1.36.dev6.dist-info/RECORD +0 -279
- {edsl-0.1.36.dev6.dist-info → edsl-0.1.37.dist-info}/WHEEL +0 -0
@@ -1,102 +1,102 @@
|
|
1
|
-
from typing import Optional
|
2
|
-
from collections import UserList
|
3
|
-
from edsl import Model
|
4
|
-
|
5
|
-
from edsl.language_models import LanguageModel
|
6
|
-
from edsl.Base import Base
|
7
|
-
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
8
|
-
from edsl.utilities.utilities import is_valid_variable_name
|
9
|
-
from edsl.utilities.utilities import dict_hash
|
10
|
-
|
11
|
-
|
12
|
-
class ModelList(Base, UserList):
    """A ``UserList`` of language models with EDSL serialization helpers."""

    def __init__(self, data: Optional[list] = None):
        """Initialize the ModelList class.

        :param data: Initial models; ``None`` (the default) gives an empty list.

        >>> from edsl import Model
        >>> m = ModelList(Model.available())

        """
        if data is not None:
            super().__init__(data)
        else:
            super().__init__([])

    @property
    def names(self) -> set:
        """Return the set of distinct ``model`` attribute values of the elements.

        >>> ModelList.example().names
        {'...'}
        """
        return {model.model for model in self}

    def rich_print(self):
        """Placeholder: currently does nothing and returns ``None``."""
        pass

    def __repr__(self):
        return f"ModelList({super().__repr__()})"

    def __hash__(self) -> int:
        """Return a hash of the ModelList. This is used for comparison of ModelLists.

        Elements are sorted by their own hash before serialization, so the
        result does not depend on list order (assuming each model's hash
        reflects its content).

        >>> isinstance(hash(ModelList.example()), int)
        True

        """
        # dict_hash is imported at module level; no local re-import needed.
        return dict_hash(self._to_dict(sort=True))

    def _to_dict(self, sort: bool = False) -> dict:
        """Serialize the models without EDSL version metadata.

        :param sort: If True, order the models by ``hash`` first so the output
            is independent of list order (used by ``__hash__``).
        :return: ``{"models": [<each model's _to_dict()>]}``.
        """
        models = sorted(self, key=hash) if sort else self
        return {"models": [model._to_dict() for model in models]}

    @classmethod
    def from_names(cls, *args, **kwargs):
        """Create a model list from model names.

        Names may be passed positionally (``from_names("a", "b")``) or as one
        list (``from_names(["a", "b"])``). Keyword arguments are forwarded to
        every ``Model(...)`` call.
        """
        if len(args) == 1 and isinstance(args[0], list):
            args = args[0]
        # Use cls so subclasses of ModelList get instances of themselves.
        return cls([Model(model_name, **kwargs) for model_name in args])

    @add_edsl_version
    def to_dict(self):
        """
        Convert the ModelList to a dictionary.
        >>> ModelList.example().to_dict()
        {'models': [...], 'edsl_version': '...', 'edsl_class_name': 'ModelList'}
        """
        return self._to_dict()

    @classmethod
    @remove_edsl_version
    def from_dict(cls, data):
        """
        Create a ModelList from a dictionary.

        >>> newm = ModelList.from_dict(ModelList.example().to_dict())
        >>> assert ModelList.example() == newm
        """
        return cls(data=[LanguageModel.from_dict(model) for model in data["models"]])

    def code(self):
        """Placeholder: currently does nothing and returns ``None``."""
        pass

    @classmethod
    def example(cls, randomize: bool = False) -> "ModelList":
        """
        Returns an example ModelList instance.

        :param randomize: If True, uses Model's randomize method.
        """

        return cls([Model.example(randomize) for _ in range(3)])
|
97
|
-
|
98
|
-
|
99
|
-
if __name__ == "__main__":
    # Run this module's doctests; ELLIPSIS lets '...' match arbitrary output.
    from doctest import ELLIPSIS, testmod

    testmod(optionflags=ELLIPSIS)
|
1
|
+
from typing import Optional
|
2
|
+
from collections import UserList
|
3
|
+
from edsl import Model
|
4
|
+
|
5
|
+
from edsl.language_models import LanguageModel
|
6
|
+
from edsl.Base import Base
|
7
|
+
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
8
|
+
from edsl.utilities.utilities import is_valid_variable_name
|
9
|
+
from edsl.utilities.utilities import dict_hash
|
10
|
+
|
11
|
+
|
12
|
+
class ModelList(Base, UserList):
    """A ``UserList`` of language models with EDSL serialization helpers."""

    def __init__(self, data: Optional[list] = None):
        """Initialize the ModelList class.

        :param data: Initial models; ``None`` (the default) gives an empty list.

        >>> from edsl import Model
        >>> m = ModelList(Model.available())

        """
        if data is not None:
            super().__init__(data)
        else:
            super().__init__([])

    @property
    def names(self) -> set:
        """Return the set of distinct ``model`` attribute values of the elements.

        >>> ModelList.example().names
        {'...'}
        """
        return {model.model for model in self}

    def rich_print(self):
        """Placeholder: currently does nothing and returns ``None``."""
        pass

    def __repr__(self):
        return f"ModelList({super().__repr__()})"

    def __hash__(self) -> int:
        """Return a hash of the ModelList. This is used for comparison of ModelLists.

        Elements are sorted by their own hash before serialization, so the
        result does not depend on list order (assuming each model's hash
        reflects its content).

        >>> isinstance(hash(ModelList.example()), int)
        True

        """
        # dict_hash is imported at module level; no local re-import needed.
        return dict_hash(self._to_dict(sort=True))

    def _to_dict(self, sort: bool = False) -> dict:
        """Serialize the models without EDSL version metadata.

        :param sort: If True, order the models by ``hash`` first so the output
            is independent of list order (used by ``__hash__``).
        :return: ``{"models": [<each model's _to_dict()>]}``.
        """
        models = sorted(self, key=hash) if sort else self
        return {"models": [model._to_dict() for model in models]}

    @classmethod
    def from_names(cls, *args, **kwargs):
        """Create a model list from model names.

        Names may be passed positionally (``from_names("a", "b")``) or as one
        list (``from_names(["a", "b"])``). Keyword arguments are forwarded to
        every ``Model(...)`` call.
        """
        if len(args) == 1 and isinstance(args[0], list):
            args = args[0]
        # Use cls so subclasses of ModelList get instances of themselves.
        return cls([Model(model_name, **kwargs) for model_name in args])

    @add_edsl_version
    def to_dict(self):
        """
        Convert the ModelList to a dictionary.
        >>> ModelList.example().to_dict()
        {'models': [...], 'edsl_version': '...', 'edsl_class_name': 'ModelList'}
        """
        return self._to_dict()

    @classmethod
    @remove_edsl_version
    def from_dict(cls, data):
        """
        Create a ModelList from a dictionary.

        >>> newm = ModelList.from_dict(ModelList.example().to_dict())
        >>> assert ModelList.example() == newm
        """
        return cls(data=[LanguageModel.from_dict(model) for model in data["models"]])

    def code(self):
        """Placeholder: currently does nothing and returns ``None``."""
        pass

    @classmethod
    def example(cls, randomize: bool = False) -> "ModelList":
        """
        Returns an example ModelList instance.

        :param randomize: If True, uses Model's randomize method.
        """

        return cls([Model.example(randomize) for _ in range(3)])
|
97
|
+
|
98
|
+
|
99
|
+
if __name__ == "__main__":
    # Run this module's doctests; ELLIPSIS lets '...' match arbitrary output.
    from doctest import ELLIPSIS, testmod

    testmod(optionflags=ELLIPSIS)
|
@@ -1,184 +1,184 @@
|
|
1
|
-
from abc import ABC, ABCMeta
|
2
|
-
from typing import Any, List, Callable
|
3
|
-
import inspect
|
4
|
-
from typing import get_type_hints
|
5
|
-
from edsl.exceptions.language_models import LanguageModelAttributeTypeError
|
6
|
-
from edsl.enums import InferenceServiceType
|
7
|
-
|
8
|
-
|
9
|
-
class RegisterLanguageModelsMeta(ABCMeta):
    """Metaclass that registers LanguageModel classes by their ``_model_`` name.

    Every class created with this metaclass whose ``_model_`` attribute is not
    ``None`` is validated (required class attributes, a valid
    ``_inference_service_``, an async ``async_execute_model_call``) and stored
    in the shared ``_registry`` dict under its ``_model_`` value.
    """

    _registry = {}  # model name -> class; shared by all classes using this metaclass
    REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]

    def __init__(cls, name, bases, dct):
        """Register the class in the registry if it has a _model_ attribute.

        :raises Exception: if a required class attribute is missing.
        :raises LanguageModelAttributeTypeError: if ``_inference_service_`` is
            not a valid ``InferenceServiceType`` value.
        :raises NotImplementedError: if ``async_execute_model_call`` is missing.
        :raises TypeError: if ``async_execute_model_call`` is not a coroutine function.
        """
        super(RegisterLanguageModelsMeta, cls).__init__(name, bases, dct)
        # Classes without a concrete model name (e.g. base classes) are not registered.
        if (model_name := getattr(cls, "_model_", None)) is not None:
            RegisterLanguageModelsMeta.check_required_class_variables(
                cls, RegisterLanguageModelsMeta.REQUIRED_CLASS_ATTRIBUTES
            )

            if not InferenceServiceType.is_value_valid(
                inference_service := getattr(cls, "_inference_service_", None)
            ):
                acceptable_values = [item.value for item in InferenceServiceType]
                raise LanguageModelAttributeTypeError(
                    f"""A LanguageModel's model must have an _inference_service_ value from
                {acceptable_values}. You passed {inference_service}."""
                )

            # Concrete LanguageModel classes must implement
            # async_execute_model_call as a coroutine function.
            RegisterLanguageModelsMeta.verify_method(
                candidate_class=cls,
                method_name="async_execute_model_call",
                expected_return_type=dict[str, Any],
                required_parameters=[("user_prompt", str), ("system_prompt", str)],
                must_be_async=True,
            )
            RegisterLanguageModelsMeta._registry[model_name] = cls

    @classmethod
    def get_registered_classes(cls):
        """Return the registry (a dict mapping model name to class)."""
        return cls._registry

    @staticmethod
    def check_required_class_variables(
        candidate_class: "LanguageModel", required_attributes: List[str] = None
    ):
        """Check if a class has the required attributes.

        >>> class M:
        ...     _model_ = "m"
        ...     _parameters_ = {}
        >>> RegisterLanguageModelsMeta.check_required_class_variables(M, ["_model_", "_parameters_"])
        >>> class M2:
        ...     _model_ = "m"
        >>> RegisterLanguageModelsMeta.check_required_class_variables(M2, ["_model_", "_parameters_"])
        Traceback (most recent call last):
        ...
        Exception: Class M2 does not have required attribute _parameters_
        """
        required_attributes = required_attributes or []
        for attribute in required_attributes:
            if not hasattr(candidate_class, attribute):
                raise Exception(
                    f"Class {candidate_class.__name__} does not have required attribute {attribute}"
                )

    @staticmethod
    def verify_method(
        candidate_class: "LanguageModel",
        method_name: str,
        expected_return_type: Any,
        required_parameters: List[tuple[str, Any]] = None,
        must_be_async: bool = False,
    ):
        """Verify that a method is defined in a class, has the correct return type, and has the correct parameters.

        NOTE: the return-type check (``_check_return_type``) is currently a
        no-op and ``required_parameters`` is accepted but not yet checked, so in
        practice this verifies that the method exists and, when
        ``must_be_async`` is True, that it is a coroutine function.

        :raises NotImplementedError: if the method is missing.
        :raises TypeError: if ``must_be_async`` and the method is not async.
        """
        RegisterLanguageModelsMeta._check_method_defined(candidate_class, method_name)

        method = getattr(candidate_class, method_name)

        RegisterLanguageModelsMeta._check_return_type(method, expected_return_type)

        if must_be_async:
            RegisterLanguageModelsMeta._check_is_coroutine(method)

    @staticmethod
    def _check_method_defined(cls, method_name):
        """Check if a method is defined in a class.

        Example:
        >>> class M:
        ...     def f(self): pass
        >>> RegisterLanguageModelsMeta._check_method_defined(M, "f")
        >>> RegisterLanguageModelsMeta._check_method_defined(M, "g")
        Traceback (most recent call last):
        ...
        NotImplementedError: g method must be implemented.
        """
        if not hasattr(cls, method_name):
            raise NotImplementedError(f"{method_name} method must be implemented.")

    @staticmethod
    def _check_is_coroutine(func: Callable):
        """Check to make sure it's a coroutine function.

        Example:

        >>> def f(): pass
        >>> RegisterLanguageModelsMeta._check_is_coroutine(f)
        Traceback (most recent call last):
        ...
        TypeError: A LanguageModel class with method f must be an asynchronous method.
        """
        if not inspect.iscoroutinefunction(func):
            raise TypeError(
                f"A LanguageModel class with method {func.__name__} must be an asynchronous method."
            )

    @staticmethod
    def _verify_parameter(params, param_name, param_type, method_name):
        """Verify that a parameter is defined in a method and has the correct type.

        Currently a placeholder: it performs no check and returns ``None``.
        """
        pass

    @staticmethod
    def _check_return_type(method, expected_return_type):
        """Check if the return type of a method is as expected.

        Currently a placeholder: it performs no check and returns ``None``.
        """
        pass

    @classmethod
    def model_names_to_classes(cls):
        """Return a dictionary of model names to classes.

        :raises Exception: if a registered class lacks a ``_model_`` attribute.
        """
        mapping = {}
        # Use a distinct loop name instead of rebinding ``cls``.
        for model_class in cls._registry.values():
            if not hasattr(model_class, "_model_"):
                # Registry keys are model names, so report the class's own name.
                raise Exception(
                    f"Class {model_class.__name__} does not have a _model_ class attribute."
                )
            mapping[model_class._model_] = model_class
        return mapping
|
1
|
+
from abc import ABC, ABCMeta
|
2
|
+
from typing import Any, List, Callable
|
3
|
+
import inspect
|
4
|
+
from typing import get_type_hints
|
5
|
+
from edsl.exceptions.language_models import LanguageModelAttributeTypeError
|
6
|
+
from edsl.enums import InferenceServiceType
|
7
|
+
|
8
|
+
|
9
|
+
class RegisterLanguageModelsMeta(ABCMeta):
    """Metaclass that registers LanguageModel classes by their ``_model_`` name.

    Every class created with this metaclass whose ``_model_`` attribute is not
    ``None`` is validated (required class attributes, a valid
    ``_inference_service_``, an async ``async_execute_model_call``) and stored
    in the shared ``_registry`` dict under its ``_model_`` value.
    """

    _registry = {}  # model name -> class; shared by all classes using this metaclass
    REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]

    def __init__(cls, name, bases, dct):
        """Register the class in the registry if it has a _model_ attribute.

        :raises Exception: if a required class attribute is missing.
        :raises LanguageModelAttributeTypeError: if ``_inference_service_`` is
            not a valid ``InferenceServiceType`` value.
        :raises NotImplementedError: if ``async_execute_model_call`` is missing.
        :raises TypeError: if ``async_execute_model_call`` is not a coroutine function.
        """
        super(RegisterLanguageModelsMeta, cls).__init__(name, bases, dct)
        # Classes without a concrete model name (e.g. base classes) are not registered.
        if (model_name := getattr(cls, "_model_", None)) is not None:
            RegisterLanguageModelsMeta.check_required_class_variables(
                cls, RegisterLanguageModelsMeta.REQUIRED_CLASS_ATTRIBUTES
            )

            if not InferenceServiceType.is_value_valid(
                inference_service := getattr(cls, "_inference_service_", None)
            ):
                acceptable_values = [item.value for item in InferenceServiceType]
                raise LanguageModelAttributeTypeError(
                    f"""A LanguageModel's model must have an _inference_service_ value from
                {acceptable_values}. You passed {inference_service}."""
                )

            # Concrete LanguageModel classes must implement
            # async_execute_model_call as a coroutine function.
            RegisterLanguageModelsMeta.verify_method(
                candidate_class=cls,
                method_name="async_execute_model_call",
                expected_return_type=dict[str, Any],
                required_parameters=[("user_prompt", str), ("system_prompt", str)],
                must_be_async=True,
            )
            RegisterLanguageModelsMeta._registry[model_name] = cls

    @classmethod
    def get_registered_classes(cls):
        """Return the registry (a dict mapping model name to class)."""
        return cls._registry

    @staticmethod
    def check_required_class_variables(
        candidate_class: "LanguageModel", required_attributes: List[str] = None
    ):
        """Check if a class has the required attributes.

        >>> class M:
        ...     _model_ = "m"
        ...     _parameters_ = {}
        >>> RegisterLanguageModelsMeta.check_required_class_variables(M, ["_model_", "_parameters_"])
        >>> class M2:
        ...     _model_ = "m"
        >>> RegisterLanguageModelsMeta.check_required_class_variables(M2, ["_model_", "_parameters_"])
        Traceback (most recent call last):
        ...
        Exception: Class M2 does not have required attribute _parameters_
        """
        required_attributes = required_attributes or []
        for attribute in required_attributes:
            if not hasattr(candidate_class, attribute):
                raise Exception(
                    f"Class {candidate_class.__name__} does not have required attribute {attribute}"
                )

    @staticmethod
    def verify_method(
        candidate_class: "LanguageModel",
        method_name: str,
        expected_return_type: Any,
        required_parameters: List[tuple[str, Any]] = None,
        must_be_async: bool = False,
    ):
        """Verify that a method is defined in a class, has the correct return type, and has the correct parameters.

        NOTE: the return-type check (``_check_return_type``) is currently a
        no-op and ``required_parameters`` is accepted but not yet checked, so in
        practice this verifies that the method exists and, when
        ``must_be_async`` is True, that it is a coroutine function.

        :raises NotImplementedError: if the method is missing.
        :raises TypeError: if ``must_be_async`` and the method is not async.
        """
        RegisterLanguageModelsMeta._check_method_defined(candidate_class, method_name)

        method = getattr(candidate_class, method_name)

        RegisterLanguageModelsMeta._check_return_type(method, expected_return_type)

        if must_be_async:
            RegisterLanguageModelsMeta._check_is_coroutine(method)

    @staticmethod
    def _check_method_defined(cls, method_name):
        """Check if a method is defined in a class.

        Example:
        >>> class M:
        ...     def f(self): pass
        >>> RegisterLanguageModelsMeta._check_method_defined(M, "f")
        >>> RegisterLanguageModelsMeta._check_method_defined(M, "g")
        Traceback (most recent call last):
        ...
        NotImplementedError: g method must be implemented.
        """
        if not hasattr(cls, method_name):
            raise NotImplementedError(f"{method_name} method must be implemented.")

    @staticmethod
    def _check_is_coroutine(func: Callable):
        """Check to make sure it's a coroutine function.

        Example:

        >>> def f(): pass
        >>> RegisterLanguageModelsMeta._check_is_coroutine(f)
        Traceback (most recent call last):
        ...
        TypeError: A LanguageModel class with method f must be an asynchronous method.
        """
        if not inspect.iscoroutinefunction(func):
            raise TypeError(
                f"A LanguageModel class with method {func.__name__} must be an asynchronous method."
            )

    @staticmethod
    def _verify_parameter(params, param_name, param_type, method_name):
        """Verify that a parameter is defined in a method and has the correct type.

        Currently a placeholder: it performs no check and returns ``None``.
        """
        pass

    @staticmethod
    def _check_return_type(method, expected_return_type):
        """Check if the return type of a method is as expected.

        Currently a placeholder: it performs no check and returns ``None``.
        """
        pass

    @classmethod
    def model_names_to_classes(cls):
        """Return a dictionary of model names to classes.

        :raises Exception: if a registered class lacks a ``_model_`` attribute.
        """
        mapping = {}
        # Use a distinct loop name instead of rebinding ``cls``.
        for model_class in cls._registry.values():
            if not hasattr(model_class, "_model_"):
                # Registry keys are model names, so report the class's own name.
                raise Exception(
                    f"Class {model_class.__name__} does not have a _model_ class attribute."
                )
            mapping[model_class._model_] = model_class
        return mapping
|
edsl/language_models/__init__.py
CHANGED
@@ -1,2 +1,3 @@
|
|
1
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
2
|
-
from edsl.language_models.registry import Model
|
1
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
2
|
+
from edsl.language_models.registry import Model
|
3
|
+
from edsl.language_models.KeyLookup import KeyLookup
|
@@ -1,15 +1,15 @@
|
|
1
|
-
from openai import AsyncOpenAI
|
2
|
-
import asyncio
|
3
|
-
|
4
|
-
# Async client pointed at a local fake OpenAI-compatible server.
client = AsyncOpenAI(base_url="http://127.0.0.1:8000/v1", api_key="fake_key")


async def main():
    """Send a single chat-completion request to the local server and print the reply."""
    chat_messages = [{"role": "user", "content": "Question XX42"}]
    completion = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=chat_messages,
    )
    print(completion)


if __name__ == "__main__":
    asyncio.run(main())
|
1
|
+
from openai import AsyncOpenAI
|
2
|
+
import asyncio
|
3
|
+
|
4
|
+
# Async client pointed at a local fake OpenAI-compatible server.
client = AsyncOpenAI(base_url="http://127.0.0.1:8000/v1", api_key="fake_key")


async def main():
    """Send a single chat-completion request to the local server and print the reply."""
    chat_messages = [{"role": "user", "content": "Question XX42"}]
    completion = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=chat_messages,
    )
    print(completion)


if __name__ == "__main__":
    asyncio.run(main())
|