edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +49 -49
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +858 -858
- edsl/agents/AgentList.py +362 -362
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +284 -284
- edsl/agents/PromptConstructor.py +353 -353
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +99 -99
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +279 -279
- edsl/config.py +149 -149
- edsl/conversation/Conversation.py +290 -290
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/chips.py +95 -95
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +961 -961
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +530 -530
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +97 -97
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/BaseException.py +21 -21
- edsl/exceptions/__init__.py +54 -54
- edsl/exceptions/agents.py +42 -42
- edsl/exceptions/cache.py +5 -5
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +29 -29
- edsl/exceptions/scenarios.py +22 -22
- edsl/exceptions/surveys.py +37 -37
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +120 -120
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +97 -97
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1358 -1358
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +251 -251
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +361 -361
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +451 -451
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/KeyLookup.py +30 -30
- edsl/language_models/LanguageModel.py +708 -708
- edsl/language_models/ModelList.py +109 -109
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +3 -3
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +258 -258
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +357 -357
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +660 -660
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +217 -217
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +114 -114
- edsl/questions/QuestionFunctional.py +166 -166
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +93 -93
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +413 -413
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +717 -717
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +456 -456
- edsl/results/Results.py +1071 -1071
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +135 -135
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -458
- edsl/scenarios/Scenario.py +544 -544
- edsl/scenarios/ScenarioHtmlMixin.py +64 -64
- edsl/scenarios/ScenarioList.py +1112 -1112
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +326 -326
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1787 -1787
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +49 -49
- edsl/surveys/instructions/Instruction.py +53 -53
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/naming_utilities.py +263 -263
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +409 -409
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev3.dist-info}/LICENSE +21 -21
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev3.dist-info}/METADATA +1 -1
- edsl-0.1.38.dev3.dist-info/RECORD +269 -0
- edsl-0.1.38.dev2.dist-info/RECORD +0 -269
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev3.dist-info}/WHEEL +0 -0
edsl/data_transfer_models.py
CHANGED
@@ -1,73 +1,73 @@
|
|
1
|
-
from typing import NamedTuple, Dict, List, Optional, Any
|
2
|
-
from dataclasses import dataclass, fields
|
3
|
-
import reprlib
|
4
|
-
|
5
|
-
|
6
|
-
class ModelInputs(NamedTuple):
|
7
|
-
"This is what was send by the agent to the model"
|
8
|
-
user_prompt: str
|
9
|
-
system_prompt: str
|
10
|
-
encoded_image: Optional[str] = None
|
11
|
-
|
12
|
-
|
13
|
-
class EDSLOutput(NamedTuple):
|
14
|
-
"This is the edsl dictionary that is returned by the model"
|
15
|
-
answer: Any
|
16
|
-
generated_tokens: str
|
17
|
-
comment: Optional[str] = None
|
18
|
-
|
19
|
-
|
20
|
-
class ModelResponse(NamedTuple):
|
21
|
-
"This is the metadata that is returned by the model and includes info about the cache"
|
22
|
-
response: dict
|
23
|
-
cache_used: bool
|
24
|
-
cache_key: str
|
25
|
-
cached_response: Optional[Dict[str, Any]] = None
|
26
|
-
cost: Optional[float] = None
|
27
|
-
|
28
|
-
|
29
|
-
class AgentResponseDict(NamedTuple):
|
30
|
-
edsl_dict: EDSLOutput
|
31
|
-
model_inputs: ModelInputs
|
32
|
-
model_outputs: ModelResponse
|
33
|
-
|
34
|
-
|
35
|
-
class EDSLResultObjectInput(NamedTuple):
|
36
|
-
generated_tokens: str
|
37
|
-
question_name: str
|
38
|
-
prompts: dict
|
39
|
-
cached_response: str
|
40
|
-
raw_model_response: str
|
41
|
-
cache_used: bool
|
42
|
-
cache_key: str
|
43
|
-
answer: Any
|
44
|
-
comment: str
|
45
|
-
validated: bool = False
|
46
|
-
exception_occurred: Exception = None
|
47
|
-
cost: Optional[float] = None
|
48
|
-
|
49
|
-
|
50
|
-
@dataclass
|
51
|
-
class ImageInfo:
|
52
|
-
file_path: str
|
53
|
-
file_name: str
|
54
|
-
image_format: str
|
55
|
-
file_size: int
|
56
|
-
encoded_image: str
|
57
|
-
|
58
|
-
def __repr__(self):
|
59
|
-
reprlib_instance = reprlib.Repr()
|
60
|
-
reprlib_instance.maxstring = 30 # Limit the string length for the encoded image
|
61
|
-
|
62
|
-
# Get all fields except encoded_image
|
63
|
-
field_reprs = [
|
64
|
-
f"{f.name}={getattr(self, f.name)!r}"
|
65
|
-
for f in fields(self)
|
66
|
-
if f.name != "encoded_image"
|
67
|
-
]
|
68
|
-
|
69
|
-
# Add the reprlib-restricted encoded_image field
|
70
|
-
field_reprs.append(f"encoded_image={reprlib_instance.repr(self.encoded_image)}")
|
71
|
-
|
72
|
-
# Join everything to create the repr
|
73
|
-
return f"{self.__class__.__name__}({', '.join(field_reprs)})"
|
1
|
+
from typing import NamedTuple, Dict, List, Optional, Any
|
2
|
+
from dataclasses import dataclass, fields
|
3
|
+
import reprlib
|
4
|
+
|
5
|
+
|
6
|
+
class ModelInputs(NamedTuple):
|
7
|
+
"This is what was send by the agent to the model"
|
8
|
+
user_prompt: str
|
9
|
+
system_prompt: str
|
10
|
+
encoded_image: Optional[str] = None
|
11
|
+
|
12
|
+
|
13
|
+
class EDSLOutput(NamedTuple):
|
14
|
+
"This is the edsl dictionary that is returned by the model"
|
15
|
+
answer: Any
|
16
|
+
generated_tokens: str
|
17
|
+
comment: Optional[str] = None
|
18
|
+
|
19
|
+
|
20
|
+
class ModelResponse(NamedTuple):
|
21
|
+
"This is the metadata that is returned by the model and includes info about the cache"
|
22
|
+
response: dict
|
23
|
+
cache_used: bool
|
24
|
+
cache_key: str
|
25
|
+
cached_response: Optional[Dict[str, Any]] = None
|
26
|
+
cost: Optional[float] = None
|
27
|
+
|
28
|
+
|
29
|
+
class AgentResponseDict(NamedTuple):
|
30
|
+
edsl_dict: EDSLOutput
|
31
|
+
model_inputs: ModelInputs
|
32
|
+
model_outputs: ModelResponse
|
33
|
+
|
34
|
+
|
35
|
+
class EDSLResultObjectInput(NamedTuple):
|
36
|
+
generated_tokens: str
|
37
|
+
question_name: str
|
38
|
+
prompts: dict
|
39
|
+
cached_response: str
|
40
|
+
raw_model_response: str
|
41
|
+
cache_used: bool
|
42
|
+
cache_key: str
|
43
|
+
answer: Any
|
44
|
+
comment: str
|
45
|
+
validated: bool = False
|
46
|
+
exception_occurred: Exception = None
|
47
|
+
cost: Optional[float] = None
|
48
|
+
|
49
|
+
|
50
|
+
@dataclass
|
51
|
+
class ImageInfo:
|
52
|
+
file_path: str
|
53
|
+
file_name: str
|
54
|
+
image_format: str
|
55
|
+
file_size: int
|
56
|
+
encoded_image: str
|
57
|
+
|
58
|
+
def __repr__(self):
|
59
|
+
reprlib_instance = reprlib.Repr()
|
60
|
+
reprlib_instance.maxstring = 30 # Limit the string length for the encoded image
|
61
|
+
|
62
|
+
# Get all fields except encoded_image
|
63
|
+
field_reprs = [
|
64
|
+
f"{f.name}={getattr(self, f.name)!r}"
|
65
|
+
for f in fields(self)
|
66
|
+
if f.name != "encoded_image"
|
67
|
+
]
|
68
|
+
|
69
|
+
# Add the reprlib-restricted encoded_image field
|
70
|
+
field_reprs.append(f"encoded_image={reprlib_instance.repr(self.encoded_image)}")
|
71
|
+
|
72
|
+
# Join everything to create the repr
|
73
|
+
return f"{self.__class__.__name__}({', '.join(field_reprs)})"
|
edsl/enums.py
CHANGED
@@ -1,173 +1,173 @@
|
|
1
|
-
"""Enums for the different types of questions, language models, and inference services."""
|
2
|
-
|
3
|
-
from enum import Enum
|
4
|
-
|
5
|
-
|
6
|
-
class EnumWithChecks(Enum):
|
7
|
-
"""Base class for all enums with checks."""
|
8
|
-
|
9
|
-
@classmethod
|
10
|
-
def is_value_valid(cls, value):
|
11
|
-
"""Check if the value is valid."""
|
12
|
-
return any(value == item.value for item in cls)
|
13
|
-
|
14
|
-
|
15
|
-
class QuestionType(EnumWithChecks):
|
16
|
-
"""Enum for the question types."""
|
17
|
-
|
18
|
-
MULTIPLE_CHOICE = "multiple_choice"
|
19
|
-
YES_NO = "yes_no"
|
20
|
-
FREE_TEXT = "free_text"
|
21
|
-
RANK = "rank"
|
22
|
-
BUDGET = "budget"
|
23
|
-
CHECKBOX = "checkbox"
|
24
|
-
EXTRACT = "extract"
|
25
|
-
FUNCTIONAL = "functional"
|
26
|
-
LIST = "list"
|
27
|
-
NUMERICAL = "numerical"
|
28
|
-
TOP_K = "top_k"
|
29
|
-
LIKERT_FIVE = "likert_five"
|
30
|
-
LINEAR_SCALE = "linear_scale"
|
31
|
-
|
32
|
-
|
33
|
-
# https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
|
34
|
-
|
35
|
-
|
36
|
-
# class LanguageModelType(EnumWithChecks):
|
37
|
-
# """Enum for the language model types."""
|
38
|
-
|
39
|
-
# GPT_4 = "gpt-4-1106-preview"
|
40
|
-
# GPT_3_5_Turbo = "gpt-3.5-turbo"
|
41
|
-
# LLAMA_2_70B_CHAT_HF = "llama-2-70b-chat-hf"
|
42
|
-
# LLAMA_2_13B_CHAT_HF = "llama-2-13b-chat-hf"
|
43
|
-
# GEMINI_PRO = "gemini_pro"
|
44
|
-
# MIXTRAL_8x7B_INSTRUCT = "mixtral-8x7B-instruct-v0.1"
|
45
|
-
# TEST = "test"
|
46
|
-
# ANTHROPIC_3_OPUS = "claude-3-opus-20240229"
|
47
|
-
# ANTHROPIC_3_SONNET = "claude-3-sonnet-20240229"
|
48
|
-
# ANTHROPIC_3_HAIKU = "claude-3-haiku-20240307"
|
49
|
-
# DBRX_INSTRUCT = "dbrx-instruct"
|
50
|
-
|
51
|
-
|
52
|
-
class InferenceServiceType(EnumWithChecks):
|
53
|
-
"""Enum for the inference service types."""
|
54
|
-
|
55
|
-
BEDROCK = "bedrock"
|
56
|
-
DEEP_INFRA = "deep_infra"
|
57
|
-
REPLICATE = "replicate"
|
58
|
-
OPENAI = "openai"
|
59
|
-
GOOGLE = "google"
|
60
|
-
TEST = "test"
|
61
|
-
ANTHROPIC = "anthropic"
|
62
|
-
GROQ = "groq"
|
63
|
-
AZURE = "azure"
|
64
|
-
OLLAMA = "ollama"
|
65
|
-
MISTRAL = "mistral"
|
66
|
-
TOGETHER = "together"
|
67
|
-
|
68
|
-
|
69
|
-
service_to_api_keyname = {
|
70
|
-
InferenceServiceType.BEDROCK.value: "TBD",
|
71
|
-
InferenceServiceType.DEEP_INFRA.value: "DEEP_INFRA_API_KEY",
|
72
|
-
InferenceServiceType.REPLICATE.value: "TBD",
|
73
|
-
InferenceServiceType.OPENAI.value: "OPENAI_API_KEY",
|
74
|
-
InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
|
75
|
-
InferenceServiceType.TEST.value: "TBD",
|
76
|
-
InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
|
77
|
-
InferenceServiceType.GROQ.value: "GROQ_API_KEY",
|
78
|
-
InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
|
79
|
-
InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
|
80
|
-
InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
|
81
|
-
}
|
82
|
-
|
83
|
-
|
84
|
-
class TokenPricing:
|
85
|
-
def __init__(
|
86
|
-
self,
|
87
|
-
*,
|
88
|
-
model_name,
|
89
|
-
prompt_token_price_per_k: float,
|
90
|
-
completion_token_price_per_k: float,
|
91
|
-
):
|
92
|
-
self.model_name = model_name
|
93
|
-
self.prompt_token_price = prompt_token_price_per_k / 1_000.0
|
94
|
-
self.completion_token_price = completion_token_price_per_k / 1_000.0
|
95
|
-
|
96
|
-
def __eq__(self, other):
|
97
|
-
if not isinstance(other, TokenPricing):
|
98
|
-
return False
|
99
|
-
return (
|
100
|
-
self.model_name == other.model_name
|
101
|
-
and self.prompt_token_price == other.prompt_token_price
|
102
|
-
and self.completion_token_price == other.completion_token_price
|
103
|
-
)
|
104
|
-
|
105
|
-
|
106
|
-
pricing = {
|
107
|
-
"dbrx-instruct": TokenPricing(
|
108
|
-
model_name="dbrx-instruct",
|
109
|
-
prompt_token_price_per_k=0.0,
|
110
|
-
completion_token_price_per_k=0.0,
|
111
|
-
),
|
112
|
-
"claude-3-opus-20240229": TokenPricing(
|
113
|
-
model_name="claude-3-opus-20240229",
|
114
|
-
prompt_token_price_per_k=0.0,
|
115
|
-
completion_token_price_per_k=0.0,
|
116
|
-
),
|
117
|
-
"claude-3-haiku-20240307": TokenPricing(
|
118
|
-
model_name="claude-3-haiku-20240307",
|
119
|
-
prompt_token_price_per_k=0.0,
|
120
|
-
completion_token_price_per_k=0.0,
|
121
|
-
),
|
122
|
-
"claude-3-sonnet-20240229": TokenPricing(
|
123
|
-
model_name="claude-3-sonnet-20240229",
|
124
|
-
prompt_token_price_per_k=0.0,
|
125
|
-
completion_token_price_per_k=0.0,
|
126
|
-
),
|
127
|
-
"gpt-3.5-turbo": TokenPricing(
|
128
|
-
model_name="gpt-3.5-turbo",
|
129
|
-
prompt_token_price_per_k=0.0005,
|
130
|
-
completion_token_price_per_k=0.0015,
|
131
|
-
),
|
132
|
-
"gpt-4-1106-preview": TokenPricing(
|
133
|
-
model_name="gpt-4",
|
134
|
-
prompt_token_price_per_k=0.01,
|
135
|
-
completion_token_price_per_k=0.03,
|
136
|
-
),
|
137
|
-
"test": TokenPricing(
|
138
|
-
model_name="test",
|
139
|
-
prompt_token_price_per_k=0.0,
|
140
|
-
completion_token_price_per_k=0.0,
|
141
|
-
),
|
142
|
-
"gemini_pro": TokenPricing(
|
143
|
-
model_name="gemini_pro",
|
144
|
-
prompt_token_price_per_k=0.0,
|
145
|
-
completion_token_price_per_k=0.0,
|
146
|
-
),
|
147
|
-
"llama-2-13b-chat-hf": TokenPricing(
|
148
|
-
model_name="llama-2-13b-chat-hf",
|
149
|
-
prompt_token_price_per_k=0.0,
|
150
|
-
completion_token_price_per_k=0.0,
|
151
|
-
),
|
152
|
-
"llama-2-70b-chat-hf": TokenPricing(
|
153
|
-
model_name="llama-2-70b-chat-hf",
|
154
|
-
prompt_token_price_per_k=0.0,
|
155
|
-
completion_token_price_per_k=0.0,
|
156
|
-
),
|
157
|
-
"mixtral-8x7B-instruct-v0.1": TokenPricing(
|
158
|
-
model_name="mixtral-8x7B-instruct-v0.1",
|
159
|
-
prompt_token_price_per_k=0.0,
|
160
|
-
completion_token_price_per_k=0.0,
|
161
|
-
),
|
162
|
-
}
|
163
|
-
|
164
|
-
|
165
|
-
def get_token_pricing(model_name):
|
166
|
-
if model_name in pricing:
|
167
|
-
return pricing[model_name]
|
168
|
-
else:
|
169
|
-
return TokenPricing(
|
170
|
-
model_name=model_name,
|
171
|
-
prompt_token_price_per_k=0.0,
|
172
|
-
completion_token_price_per_k=0.0,
|
173
|
-
)
|
1
|
+
"""Enums for the different types of questions, language models, and inference services."""
|
2
|
+
|
3
|
+
from enum import Enum
|
4
|
+
|
5
|
+
|
6
|
+
class EnumWithChecks(Enum):
|
7
|
+
"""Base class for all enums with checks."""
|
8
|
+
|
9
|
+
@classmethod
|
10
|
+
def is_value_valid(cls, value):
|
11
|
+
"""Check if the value is valid."""
|
12
|
+
return any(value == item.value for item in cls)
|
13
|
+
|
14
|
+
|
15
|
+
class QuestionType(EnumWithChecks):
|
16
|
+
"""Enum for the question types."""
|
17
|
+
|
18
|
+
MULTIPLE_CHOICE = "multiple_choice"
|
19
|
+
YES_NO = "yes_no"
|
20
|
+
FREE_TEXT = "free_text"
|
21
|
+
RANK = "rank"
|
22
|
+
BUDGET = "budget"
|
23
|
+
CHECKBOX = "checkbox"
|
24
|
+
EXTRACT = "extract"
|
25
|
+
FUNCTIONAL = "functional"
|
26
|
+
LIST = "list"
|
27
|
+
NUMERICAL = "numerical"
|
28
|
+
TOP_K = "top_k"
|
29
|
+
LIKERT_FIVE = "likert_five"
|
30
|
+
LINEAR_SCALE = "linear_scale"
|
31
|
+
|
32
|
+
|
33
|
+
# https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
|
34
|
+
|
35
|
+
|
36
|
+
# class LanguageModelType(EnumWithChecks):
|
37
|
+
# """Enum for the language model types."""
|
38
|
+
|
39
|
+
# GPT_4 = "gpt-4-1106-preview"
|
40
|
+
# GPT_3_5_Turbo = "gpt-3.5-turbo"
|
41
|
+
# LLAMA_2_70B_CHAT_HF = "llama-2-70b-chat-hf"
|
42
|
+
# LLAMA_2_13B_CHAT_HF = "llama-2-13b-chat-hf"
|
43
|
+
# GEMINI_PRO = "gemini_pro"
|
44
|
+
# MIXTRAL_8x7B_INSTRUCT = "mixtral-8x7B-instruct-v0.1"
|
45
|
+
# TEST = "test"
|
46
|
+
# ANTHROPIC_3_OPUS = "claude-3-opus-20240229"
|
47
|
+
# ANTHROPIC_3_SONNET = "claude-3-sonnet-20240229"
|
48
|
+
# ANTHROPIC_3_HAIKU = "claude-3-haiku-20240307"
|
49
|
+
# DBRX_INSTRUCT = "dbrx-instruct"
|
50
|
+
|
51
|
+
|
52
|
+
class InferenceServiceType(EnumWithChecks):
|
53
|
+
"""Enum for the inference service types."""
|
54
|
+
|
55
|
+
BEDROCK = "bedrock"
|
56
|
+
DEEP_INFRA = "deep_infra"
|
57
|
+
REPLICATE = "replicate"
|
58
|
+
OPENAI = "openai"
|
59
|
+
GOOGLE = "google"
|
60
|
+
TEST = "test"
|
61
|
+
ANTHROPIC = "anthropic"
|
62
|
+
GROQ = "groq"
|
63
|
+
AZURE = "azure"
|
64
|
+
OLLAMA = "ollama"
|
65
|
+
MISTRAL = "mistral"
|
66
|
+
TOGETHER = "together"
|
67
|
+
|
68
|
+
|
69
|
+
service_to_api_keyname = {
|
70
|
+
InferenceServiceType.BEDROCK.value: "TBD",
|
71
|
+
InferenceServiceType.DEEP_INFRA.value: "DEEP_INFRA_API_KEY",
|
72
|
+
InferenceServiceType.REPLICATE.value: "TBD",
|
73
|
+
InferenceServiceType.OPENAI.value: "OPENAI_API_KEY",
|
74
|
+
InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
|
75
|
+
InferenceServiceType.TEST.value: "TBD",
|
76
|
+
InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
|
77
|
+
InferenceServiceType.GROQ.value: "GROQ_API_KEY",
|
78
|
+
InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
|
79
|
+
InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
|
80
|
+
InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
|
81
|
+
}
|
82
|
+
|
83
|
+
|
84
|
+
class TokenPricing:
|
85
|
+
def __init__(
|
86
|
+
self,
|
87
|
+
*,
|
88
|
+
model_name,
|
89
|
+
prompt_token_price_per_k: float,
|
90
|
+
completion_token_price_per_k: float,
|
91
|
+
):
|
92
|
+
self.model_name = model_name
|
93
|
+
self.prompt_token_price = prompt_token_price_per_k / 1_000.0
|
94
|
+
self.completion_token_price = completion_token_price_per_k / 1_000.0
|
95
|
+
|
96
|
+
def __eq__(self, other):
|
97
|
+
if not isinstance(other, TokenPricing):
|
98
|
+
return False
|
99
|
+
return (
|
100
|
+
self.model_name == other.model_name
|
101
|
+
and self.prompt_token_price == other.prompt_token_price
|
102
|
+
and self.completion_token_price == other.completion_token_price
|
103
|
+
)
|
104
|
+
|
105
|
+
|
106
|
+
pricing = {
|
107
|
+
"dbrx-instruct": TokenPricing(
|
108
|
+
model_name="dbrx-instruct",
|
109
|
+
prompt_token_price_per_k=0.0,
|
110
|
+
completion_token_price_per_k=0.0,
|
111
|
+
),
|
112
|
+
"claude-3-opus-20240229": TokenPricing(
|
113
|
+
model_name="claude-3-opus-20240229",
|
114
|
+
prompt_token_price_per_k=0.0,
|
115
|
+
completion_token_price_per_k=0.0,
|
116
|
+
),
|
117
|
+
"claude-3-haiku-20240307": TokenPricing(
|
118
|
+
model_name="claude-3-haiku-20240307",
|
119
|
+
prompt_token_price_per_k=0.0,
|
120
|
+
completion_token_price_per_k=0.0,
|
121
|
+
),
|
122
|
+
"claude-3-sonnet-20240229": TokenPricing(
|
123
|
+
model_name="claude-3-sonnet-20240229",
|
124
|
+
prompt_token_price_per_k=0.0,
|
125
|
+
completion_token_price_per_k=0.0,
|
126
|
+
),
|
127
|
+
"gpt-3.5-turbo": TokenPricing(
|
128
|
+
model_name="gpt-3.5-turbo",
|
129
|
+
prompt_token_price_per_k=0.0005,
|
130
|
+
completion_token_price_per_k=0.0015,
|
131
|
+
),
|
132
|
+
"gpt-4-1106-preview": TokenPricing(
|
133
|
+
model_name="gpt-4",
|
134
|
+
prompt_token_price_per_k=0.01,
|
135
|
+
completion_token_price_per_k=0.03,
|
136
|
+
),
|
137
|
+
"test": TokenPricing(
|
138
|
+
model_name="test",
|
139
|
+
prompt_token_price_per_k=0.0,
|
140
|
+
completion_token_price_per_k=0.0,
|
141
|
+
),
|
142
|
+
"gemini_pro": TokenPricing(
|
143
|
+
model_name="gemini_pro",
|
144
|
+
prompt_token_price_per_k=0.0,
|
145
|
+
completion_token_price_per_k=0.0,
|
146
|
+
),
|
147
|
+
"llama-2-13b-chat-hf": TokenPricing(
|
148
|
+
model_name="llama-2-13b-chat-hf",
|
149
|
+
prompt_token_price_per_k=0.0,
|
150
|
+
completion_token_price_per_k=0.0,
|
151
|
+
),
|
152
|
+
"llama-2-70b-chat-hf": TokenPricing(
|
153
|
+
model_name="llama-2-70b-chat-hf",
|
154
|
+
prompt_token_price_per_k=0.0,
|
155
|
+
completion_token_price_per_k=0.0,
|
156
|
+
),
|
157
|
+
"mixtral-8x7B-instruct-v0.1": TokenPricing(
|
158
|
+
model_name="mixtral-8x7B-instruct-v0.1",
|
159
|
+
prompt_token_price_per_k=0.0,
|
160
|
+
completion_token_price_per_k=0.0,
|
161
|
+
),
|
162
|
+
}
|
163
|
+
|
164
|
+
|
165
|
+
def get_token_pricing(model_name):
|
166
|
+
if model_name in pricing:
|
167
|
+
return pricing[model_name]
|
168
|
+
else:
|
169
|
+
return TokenPricing(
|
170
|
+
model_name=model_name,
|
171
|
+
prompt_token_price_per_k=0.0,
|
172
|
+
completion_token_price_per_k=0.0,
|
173
|
+
)
|
edsl/exceptions/BaseException.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1
|
-
class BaseException(Exception):
|
2
|
-
relevant_doc = "https://docs.expectedparrot.com/"
|
3
|
-
|
4
|
-
def __init__(self, message, *, show_docs=True):
|
5
|
-
# Format main error message
|
6
|
-
formatted_message = [message.strip()]
|
7
|
-
|
8
|
-
# Add documentation links if requested
|
9
|
-
if show_docs:
|
10
|
-
if hasattr(self, "relevant_doc"):
|
11
|
-
formatted_message.append(
|
12
|
-
f"\nFor more information, see:\n{self.relevant_doc}"
|
13
|
-
)
|
14
|
-
if hasattr(self, "relevant_notebook"):
|
15
|
-
formatted_message.append(
|
16
|
-
f"\nFor a usage example, see:\n{self.relevant_notebook}"
|
17
|
-
)
|
18
|
-
|
19
|
-
# Join with double newlines for clear separation
|
20
|
-
final_message = "\n\n".join(formatted_message)
|
21
|
-
super().__init__(final_message)
|
1
|
+
class BaseException(Exception):
|
2
|
+
relevant_doc = "https://docs.expectedparrot.com/"
|
3
|
+
|
4
|
+
def __init__(self, message, *, show_docs=True):
|
5
|
+
# Format main error message
|
6
|
+
formatted_message = [message.strip()]
|
7
|
+
|
8
|
+
# Add documentation links if requested
|
9
|
+
if show_docs:
|
10
|
+
if hasattr(self, "relevant_doc"):
|
11
|
+
formatted_message.append(
|
12
|
+
f"\nFor more information, see:\n{self.relevant_doc}"
|
13
|
+
)
|
14
|
+
if hasattr(self, "relevant_notebook"):
|
15
|
+
formatted_message.append(
|
16
|
+
f"\nFor a usage example, see:\n{self.relevant_notebook}"
|
17
|
+
)
|
18
|
+
|
19
|
+
# Join with double newlines for clear separation
|
20
|
+
final_message = "\n\n".join(formatted_message)
|
21
|
+
super().__init__(final_message)
|
edsl/exceptions/__init__.py
CHANGED
@@ -1,54 +1,54 @@
|
|
1
|
-
from .agents import (
|
2
|
-
# AgentAttributeLookupCallbackError,
|
3
|
-
AgentCombinationError,
|
4
|
-
# AgentLacksLLMError,
|
5
|
-
# AgentRespondedWithBadJSONError,
|
6
|
-
)
|
7
|
-
from .configuration import (
|
8
|
-
InvalidEnvironmentVariableError,
|
9
|
-
MissingEnvironmentVariableError,
|
10
|
-
)
|
11
|
-
from .data import (
|
12
|
-
DatabaseConnectionError,
|
13
|
-
DatabaseCRUDError,
|
14
|
-
DatabaseIntegrityError,
|
15
|
-
)
|
16
|
-
|
17
|
-
from .scenarios import (
|
18
|
-
ScenarioError,
|
19
|
-
)
|
20
|
-
|
21
|
-
from .general import MissingAPIKeyError
|
22
|
-
|
23
|
-
from .jobs import JobsRunError, InterviewErrorPriorTaskCanceled, InterviewTimeoutError
|
24
|
-
|
25
|
-
from .language_models import (
|
26
|
-
LanguageModelResponseNotJSONError,
|
27
|
-
LanguageModelMissingAttributeError,
|
28
|
-
LanguageModelAttributeTypeError,
|
29
|
-
LanguageModelDoNotAddError,
|
30
|
-
)
|
31
|
-
from .questions import (
|
32
|
-
QuestionAnswerValidationError,
|
33
|
-
QuestionAttributeMissing,
|
34
|
-
QuestionCreationValidationError,
|
35
|
-
QuestionResponseValidationError,
|
36
|
-
QuestionSerializationError,
|
37
|
-
QuestionScenarioRenderError,
|
38
|
-
)
|
39
|
-
from .results import (
|
40
|
-
ResultsBadMutationstringError,
|
41
|
-
ResultsColumnNotFoundError,
|
42
|
-
ResultsInvalidNameError,
|
43
|
-
ResultsMutateError,
|
44
|
-
)
|
45
|
-
from .surveys import (
|
46
|
-
SurveyCreationError,
|
47
|
-
SurveyHasNoRulesError,
|
48
|
-
SurveyRuleCannotEvaluateError,
|
49
|
-
SurveyRuleCollectionHasNoRulesAtNodeError,
|
50
|
-
SurveyRuleReferenceInRuleToUnknownQuestionError,
|
51
|
-
SurveyRuleRefersToFutureStateError,
|
52
|
-
SurveyRuleSendsYouBackwardsError,
|
53
|
-
SurveyRuleSkipLogicSyntaxError,
|
54
|
-
)
|
1
|
+
from .agents import (
|
2
|
+
# AgentAttributeLookupCallbackError,
|
3
|
+
AgentCombinationError,
|
4
|
+
# AgentLacksLLMError,
|
5
|
+
# AgentRespondedWithBadJSONError,
|
6
|
+
)
|
7
|
+
from .configuration import (
|
8
|
+
InvalidEnvironmentVariableError,
|
9
|
+
MissingEnvironmentVariableError,
|
10
|
+
)
|
11
|
+
from .data import (
|
12
|
+
DatabaseConnectionError,
|
13
|
+
DatabaseCRUDError,
|
14
|
+
DatabaseIntegrityError,
|
15
|
+
)
|
16
|
+
|
17
|
+
from .scenarios import (
|
18
|
+
ScenarioError,
|
19
|
+
)
|
20
|
+
|
21
|
+
from .general import MissingAPIKeyError
|
22
|
+
|
23
|
+
from .jobs import JobsRunError, InterviewErrorPriorTaskCanceled, InterviewTimeoutError
|
24
|
+
|
25
|
+
from .language_models import (
|
26
|
+
LanguageModelResponseNotJSONError,
|
27
|
+
LanguageModelMissingAttributeError,
|
28
|
+
LanguageModelAttributeTypeError,
|
29
|
+
LanguageModelDoNotAddError,
|
30
|
+
)
|
31
|
+
from .questions import (
|
32
|
+
QuestionAnswerValidationError,
|
33
|
+
QuestionAttributeMissing,
|
34
|
+
QuestionCreationValidationError,
|
35
|
+
QuestionResponseValidationError,
|
36
|
+
QuestionSerializationError,
|
37
|
+
QuestionScenarioRenderError,
|
38
|
+
)
|
39
|
+
from .results import (
|
40
|
+
ResultsBadMutationstringError,
|
41
|
+
ResultsColumnNotFoundError,
|
42
|
+
ResultsInvalidNameError,
|
43
|
+
ResultsMutateError,
|
44
|
+
)
|
45
|
+
from .surveys import (
|
46
|
+
SurveyCreationError,
|
47
|
+
SurveyHasNoRulesError,
|
48
|
+
SurveyRuleCannotEvaluateError,
|
49
|
+
SurveyRuleCollectionHasNoRulesAtNodeError,
|
50
|
+
SurveyRuleReferenceInRuleToUnknownQuestionError,
|
51
|
+
SurveyRuleRefersToFutureStateError,
|
52
|
+
SurveyRuleSendsYouBackwardsError,
|
53
|
+
SurveyRuleSkipLogicSyntaxError,
|
54
|
+
)
|