edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +116 -197
- edsl/__init__.py +7 -15
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +147 -351
- edsl/agents/AgentList.py +73 -211
- edsl/agents/Invigilator.py +50 -101
- edsl/agents/InvigilatorBase.py +70 -62
- edsl/agents/PromptConstructor.py +225 -143
- edsl/agents/__init__.py +1 -0
- edsl/agents/prompt_helpers.py +3 -3
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/config.py +2 -22
- edsl/conversation/car_buying.py +1 -2
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +47 -125
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +27 -45
- edsl/data/CacheEntry.py +15 -12
- edsl/data/CacheHandler.py +12 -31
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/data/__init__.py +3 -4
- edsl/data_transfer_models.py +1 -2
- edsl/enums.py +0 -27
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +0 -12
- edsl/exceptions/questions.py +6 -24
- edsl/exceptions/scenarios.py +0 -7
- edsl/inference_services/AnthropicService.py +19 -38
- edsl/inference_services/AwsBedrock.py +2 -0
- edsl/inference_services/AzureAI.py +2 -0
- edsl/inference_services/GoogleService.py +12 -7
- edsl/inference_services/InferenceServiceABC.py +85 -18
- edsl/inference_services/InferenceServicesCollection.py +79 -120
- edsl/inference_services/MistralAIService.py +3 -0
- edsl/inference_services/OpenAIService.py +35 -47
- edsl/inference_services/PerplexityService.py +3 -0
- edsl/inference_services/TestService.py +10 -11
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/jobs/Answers.py +14 -1
- edsl/jobs/Jobs.py +431 -356
- edsl/jobs/JobsChecks.py +10 -35
- edsl/jobs/JobsPrompts.py +4 -6
- edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
- edsl/jobs/buckets/BucketCollection.py +3 -44
- edsl/jobs/buckets/TokenBucket.py +21 -53
- edsl/jobs/interviews/Interview.py +408 -143
- edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
- edsl/jobs/tasks/TaskHistory.py +18 -38
- edsl/jobs/tasks/task_status_enum.py +2 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +236 -194
- edsl/language_models/ModelList.py +19 -28
- edsl/language_models/__init__.py +2 -1
- edsl/language_models/registry.py +190 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/unused/ReplicateBase.py +83 -0
- edsl/language_models/utilities.py +4 -5
- edsl/notebooks/Notebook.py +14 -19
- edsl/prompts/Prompt.py +39 -29
- edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
- edsl/questions/QuestionBase.py +214 -68
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
- edsl/questions/QuestionBasePromptsMixin.py +3 -7
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +3 -2
- edsl/questions/QuestionList.py +18 -10
- edsl/questions/QuestionMultipleChoice.py +23 -67
- edsl/questions/QuestionNumerical.py +4 -2
- edsl/questions/QuestionRank.py +17 -7
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
- edsl/questions/SimpleAskMixin.py +3 -4
- edsl/questions/__init__.py +1 -2
- edsl/questions/derived/QuestionLinearScale.py +3 -6
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +3 -17
- edsl/questions/question_registry.py +1 -1
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +7 -170
- edsl/results/DatasetExportMixin.py +305 -168
- edsl/results/DatasetTree.py +8 -28
- edsl/results/Result.py +206 -298
- edsl/results/Results.py +131 -149
- edsl/results/ResultsDBMixin.py +238 -0
- edsl/results/ResultsExportMixin.py +0 -2
- edsl/results/{results_selector.py → Selector.py} +13 -23
- edsl/results/TableDisplay.py +171 -98
- edsl/results/__init__.py +1 -1
- edsl/scenarios/FileStore.py +239 -150
- edsl/scenarios/Scenario.py +193 -90
- edsl/scenarios/ScenarioHtmlMixin.py +3 -4
- edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
- edsl/scenarios/ScenarioList.py +244 -415
- edsl/scenarios/ScenarioListExportMixin.py +7 -0
- edsl/scenarios/ScenarioListPdfMixin.py +37 -15
- edsl/scenarios/__init__.py +2 -1
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +12 -5
- edsl/surveys/Rule.py +4 -5
- edsl/surveys/RuleCollection.py +27 -25
- edsl/surveys/Survey.py +791 -270
- edsl/surveys/SurveyCSS.py +8 -20
- edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
- edsl/surveys/__init__.py +2 -4
- edsl/surveys/descriptors.py +2 -6
- edsl/surveys/instructions/ChangeInstruction.py +2 -1
- edsl/surveys/instructions/Instruction.py +13 -4
- edsl/surveys/instructions/InstructionCollection.py +6 -11
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/utilities.py +23 -35
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
- edsl-0.1.39.dev1.dist-info/RECORD +277 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
- edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
- edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
- edsl/agents/question_option_processor.py +0 -172
- edsl/coop/CoopFunctionsMixin.py +0 -15
- edsl/coop/ExpectedParrotKeyHandler.py +0 -125
- edsl/exceptions/inference_services.py +0 -5
- edsl/inference_services/AvailableModelCacheHandler.py +0 -184
- edsl/inference_services/AvailableModelFetcher.py +0 -215
- edsl/inference_services/ServiceAvailability.py +0 -135
- edsl/inference_services/data_structures.py +0 -134
- edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
- edsl/jobs/FetchInvigilator.py +0 -47
- edsl/jobs/InterviewTaskManager.py +0 -98
- edsl/jobs/InterviewsConstructor.py +0 -50
- edsl/jobs/JobsComponentConstructor.py +0 -189
- edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
- edsl/jobs/RequestTokenEstimator.py +0 -30
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/buckets/TokenBucketAPI.py +0 -211
- edsl/jobs/buckets/TokenBucketClient.py +0 -191
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/decorators.py +0 -35
- edsl/jobs/jobs_status_enums.py +0 -9
- edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/ComputeCost.py +0 -63
- edsl/language_models/PriceManager.py +0 -127
- edsl/language_models/RawResponseHandler.py +0 -106
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/key_management/KeyLookup.py +0 -63
- edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
- edsl/language_models/key_management/KeyLookupCollection.py +0 -38
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +0 -131
- edsl/language_models/model.py +0 -256
- edsl/notebooks/NotebookToLaTeX.py +0 -142
- edsl/questions/ExceptionExplainer.py +0 -77
- edsl/questions/HTMLQuestion.py +0 -103
- edsl/questions/QuestionMatrix.py +0 -265
- edsl/questions/data_structures.py +0 -20
- edsl/questions/loop_processor.py +0 -149
- edsl/questions/response_validator_factory.py +0 -34
- edsl/questions/templates/matrix/__init__.py +0 -1
- edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
- edsl/questions/templates/matrix/question_presentation.jinja +0 -20
- edsl/results/MarkdownToDocx.py +0 -122
- edsl/results/MarkdownToPDF.py +0 -111
- edsl/results/TextEditor.py +0 -50
- edsl/results/file_exports.py +0 -252
- edsl/results/smart_objects.py +0 -96
- edsl/results/table_data_class.py +0 -12
- edsl/results/table_renderers.py +0 -118
- edsl/scenarios/ConstructDownloadLink.py +0 -109
- edsl/scenarios/DocumentChunker.py +0 -102
- edsl/scenarios/DocxScenario.py +0 -16
- edsl/scenarios/PdfExtractor.py +0 -40
- edsl/scenarios/directory_scanner.py +0 -96
- edsl/scenarios/file_methods.py +0 -85
- edsl/scenarios/handlers/__init__.py +0 -13
- edsl/scenarios/handlers/csv.py +0 -49
- edsl/scenarios/handlers/docx.py +0 -76
- edsl/scenarios/handlers/html.py +0 -37
- edsl/scenarios/handlers/json.py +0 -111
- edsl/scenarios/handlers/latex.py +0 -5
- edsl/scenarios/handlers/md.py +0 -51
- edsl/scenarios/handlers/pdf.py +0 -68
- edsl/scenarios/handlers/png.py +0 -39
- edsl/scenarios/handlers/pptx.py +0 -105
- edsl/scenarios/handlers/py.py +0 -294
- edsl/scenarios/handlers/sql.py +0 -313
- edsl/scenarios/handlers/sqlite.py +0 -149
- edsl/scenarios/handlers/txt.py +0 -33
- edsl/scenarios/scenario_selector.py +0 -156
- edsl/surveys/ConstructDAG.py +0 -92
- edsl/surveys/EditSurvey.py +0 -221
- edsl/surveys/InstructionHandler.py +0 -100
- edsl/surveys/MemoryManagement.py +0 -72
- edsl/surveys/RuleManager.py +0 -172
- edsl/surveys/Simulator.py +0 -75
- edsl/surveys/SurveyToApp.py +0 -141
- edsl/utilities/PrettyList.py +0 -56
- edsl/utilities/is_notebook.py +0 -18
- edsl/utilities/is_valid_variable_name.py +0 -11
- edsl/utilities/remove_edsl_version.py +0 -24
- edsl-0.1.39.dist-info/RECORD +0 -358
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
@@ -1,172 +0,0 @@
|
|
1
|
-
from jinja2 import Environment, meta
|
2
|
-
from typing import List, Optional, Union
|
3
|
-
|
4
|
-
|
5
|
-
class QuestionOptionProcessor:
|
6
|
-
"""
|
7
|
-
Class that manages the processing of question options.
|
8
|
-
These can be provided directly, as a template string, or fetched from prior answers or the scenario.
|
9
|
-
"""
|
10
|
-
|
11
|
-
def __init__(self, prompt_constructor):
|
12
|
-
self.prompt_constructor = prompt_constructor
|
13
|
-
|
14
|
-
@staticmethod
|
15
|
-
def _get_default_options() -> list:
|
16
|
-
"""Return default placeholder options."""
|
17
|
-
return [f"<< Option {i} - Placeholder >>" for i in range(1, 4)]
|
18
|
-
|
19
|
-
@staticmethod
|
20
|
-
def _parse_template_variable(template_str: str) -> str:
|
21
|
-
"""
|
22
|
-
Extract the variable name from a template string.
|
23
|
-
|
24
|
-
Args:
|
25
|
-
template_str (str): Jinja template string
|
26
|
-
|
27
|
-
Returns:
|
28
|
-
str: Name of the first undefined variable in the template
|
29
|
-
|
30
|
-
>>> QuestionOptionProcessor._parse_template_variable("Here are some {{ options }}")
|
31
|
-
'options'
|
32
|
-
>>> QuestionOptionProcessor._parse_template_variable("Here are some {{ options }} and {{ other }}")
|
33
|
-
Traceback (most recent call last):
|
34
|
-
...
|
35
|
-
ValueError: Multiple variables found in template string
|
36
|
-
>>> QuestionOptionProcessor._parse_template_variable("Here are some")
|
37
|
-
Traceback (most recent call last):
|
38
|
-
...
|
39
|
-
ValueError: No variables found in template string
|
40
|
-
"""
|
41
|
-
env = Environment()
|
42
|
-
parsed_content = env.parse(template_str)
|
43
|
-
undeclared_variables = list(meta.find_undeclared_variables(parsed_content))
|
44
|
-
if not undeclared_variables:
|
45
|
-
raise ValueError("No variables found in template string")
|
46
|
-
if len(undeclared_variables) > 1:
|
47
|
-
raise ValueError("Multiple variables found in template string")
|
48
|
-
return undeclared_variables[0]
|
49
|
-
|
50
|
-
@staticmethod
|
51
|
-
def _get_options_from_scenario(
|
52
|
-
scenario: dict, option_key: str
|
53
|
-
) -> Union[list, None]:
|
54
|
-
"""
|
55
|
-
Try to get options from scenario data.
|
56
|
-
|
57
|
-
>>> from edsl import Scenario
|
58
|
-
>>> scenario = Scenario({"options": ["Option 1", "Option 2"]})
|
59
|
-
>>> QuestionOptionProcessor._get_options_from_scenario(scenario, "options")
|
60
|
-
['Option 1', 'Option 2']
|
61
|
-
|
62
|
-
|
63
|
-
Returns:
|
64
|
-
list | None: List of options if found in scenario, None otherwise
|
65
|
-
"""
|
66
|
-
scenario_options = scenario.get(option_key)
|
67
|
-
return scenario_options if isinstance(scenario_options, list) else None
|
68
|
-
|
69
|
-
@staticmethod
|
70
|
-
def _get_options_from_prior_answers(
|
71
|
-
prior_answers: dict, option_key: str
|
72
|
-
) -> Union[list, None]:
|
73
|
-
"""
|
74
|
-
Try to get options from prior answers.
|
75
|
-
|
76
|
-
prior_answers (dict): Dictionary of prior answers
|
77
|
-
option_key (str): Key to look up in prior answers
|
78
|
-
|
79
|
-
>>> from edsl import QuestionList as Q
|
80
|
-
>>> q = Q.example()
|
81
|
-
>>> q.answer = ["Option 1", "Option 2"]
|
82
|
-
>>> prior_answers = {"options": q}
|
83
|
-
>>> QuestionOptionProcessor._get_options_from_prior_answers(prior_answers, "options")
|
84
|
-
['Option 1', 'Option 2']
|
85
|
-
>>> QuestionOptionProcessor._get_options_from_prior_answers(prior_answers, "wrong_key") is None
|
86
|
-
True
|
87
|
-
|
88
|
-
Returns:
|
89
|
-
list | None: List of options if found in prior answers, None otherwise
|
90
|
-
"""
|
91
|
-
prior_answer = prior_answers.get(option_key)
|
92
|
-
if prior_answer and hasattr(prior_answer, "answer"):
|
93
|
-
if isinstance(prior_answer.answer, list):
|
94
|
-
return prior_answer.answer
|
95
|
-
return None
|
96
|
-
|
97
|
-
def get_question_options(self, question_data: dict) -> list:
|
98
|
-
"""
|
99
|
-
Extract and process question options from question data.
|
100
|
-
|
101
|
-
Args:
|
102
|
-
question_data (dict): Dictionary containing question configuration
|
103
|
-
|
104
|
-
Returns:
|
105
|
-
list: List of question options. Returns default placeholders if no valid options found.
|
106
|
-
|
107
|
-
>>> class MockPromptConstructor:
|
108
|
-
... pass
|
109
|
-
>>> mpc = MockPromptConstructor()
|
110
|
-
>>> from edsl import Scenario
|
111
|
-
>>> mpc.scenario = Scenario({"options": ["Option 1", "Option 2"]})
|
112
|
-
>>> processor = QuestionOptionProcessor(mpc)
|
113
|
-
|
114
|
-
The basic case where options are directly provided:
|
115
|
-
|
116
|
-
>>> question_data = {"question_options": ["Option 1", "Option 2"]}
|
117
|
-
>>> processor.get_question_options(question_data)
|
118
|
-
['Option 1', 'Option 2']
|
119
|
-
|
120
|
-
The case where options are provided as a template string:
|
121
|
-
|
122
|
-
>>> question_data = {"question_options": "{{ options }}"}
|
123
|
-
>>> processor.get_question_options(question_data)
|
124
|
-
['Option 1', 'Option 2']
|
125
|
-
|
126
|
-
The case where there is a templace string but it's in the prior answers:
|
127
|
-
|
128
|
-
>>> class MockQuestion:
|
129
|
-
... pass
|
130
|
-
>>> q0 = MockQuestion()
|
131
|
-
>>> q0.answer = ["Option 1", "Option 2"]
|
132
|
-
>>> mpc.prior_answers_dict = lambda: {'q0': q0}
|
133
|
-
>>> processor = QuestionOptionProcessor(mpc)
|
134
|
-
>>> question_data = {"question_options": "{{ q0 }}"}
|
135
|
-
>>> processor.get_question_options(question_data)
|
136
|
-
['Option 1', 'Option 2']
|
137
|
-
|
138
|
-
The case we're no options are found:
|
139
|
-
>>> processor.get_question_options({"question_options": "{{ poop }}"})
|
140
|
-
['<< Option 1 - Placeholder >>', '<< Option 2 - Placeholder >>', '<< Option 3 - Placeholder >>']
|
141
|
-
|
142
|
-
"""
|
143
|
-
options_entry = question_data.get("question_options")
|
144
|
-
|
145
|
-
# If not a template string, return as is or default
|
146
|
-
if not isinstance(options_entry, str):
|
147
|
-
return options_entry if options_entry else self._get_default_options()
|
148
|
-
|
149
|
-
# Parse template to get variable name
|
150
|
-
option_key = self._parse_template_variable(options_entry)
|
151
|
-
|
152
|
-
# Try getting options from scenario
|
153
|
-
scenario_options = self._get_options_from_scenario(
|
154
|
-
self.prompt_constructor.scenario, option_key
|
155
|
-
)
|
156
|
-
if scenario_options:
|
157
|
-
return scenario_options
|
158
|
-
|
159
|
-
# Try getting options from prior answers
|
160
|
-
prior_answer_options = self._get_options_from_prior_answers(
|
161
|
-
self.prompt_constructor.prior_answers_dict(), option_key
|
162
|
-
)
|
163
|
-
if prior_answer_options:
|
164
|
-
return prior_answer_options
|
165
|
-
|
166
|
-
return self._get_default_options()
|
167
|
-
|
168
|
-
|
169
|
-
if __name__ == "__main__":
|
170
|
-
import doctest
|
171
|
-
|
172
|
-
doctest.testmod()
|
edsl/coop/CoopFunctionsMixin.py
DELETED
@@ -1,15 +0,0 @@
|
|
1
|
-
class CoopFunctionsMixin:
|
2
|
-
def better_names(self, existing_names):
|
3
|
-
from edsl import QuestionList, Scenario
|
4
|
-
|
5
|
-
s = Scenario({"existing_names": existing_names})
|
6
|
-
q = QuestionList(
|
7
|
-
question_text="""The following colum names are already in use: {{ existing_names }}
|
8
|
-
Please provide new names for the columns.
|
9
|
-
They should be short, one or two words, and unique. They should be valid Python idenifiers.
|
10
|
-
No spaces - use underscores instead.
|
11
|
-
""",
|
12
|
-
question_name="better_names",
|
13
|
-
)
|
14
|
-
results = q.by(s).run(verbose=False)
|
15
|
-
return results.select("answer.better_names").first()
|
@@ -1,125 +0,0 @@
|
|
1
|
-
from pathlib import Path
|
2
|
-
import os
|
3
|
-
import platformdirs
|
4
|
-
|
5
|
-
|
6
|
-
import sys
|
7
|
-
import select
|
8
|
-
|
9
|
-
|
10
|
-
def get_input_with_timeout(prompt, timeout=5, default="y"):
|
11
|
-
print(prompt, end="", flush=True)
|
12
|
-
ready, _, _ = select.select([sys.stdin], [], [], timeout)
|
13
|
-
if ready:
|
14
|
-
return sys.stdin.readline().strip()
|
15
|
-
print(f"\nNo input received within {timeout} seconds. Using default: {default}")
|
16
|
-
return default
|
17
|
-
|
18
|
-
|
19
|
-
class ExpectedParrotKeyHandler:
|
20
|
-
asked_to_store_file_name = "asked_to_store.txt"
|
21
|
-
ep_key_file_name = "ep_api_key.txt"
|
22
|
-
application_name = "edsl"
|
23
|
-
|
24
|
-
@property
|
25
|
-
def config_dir(self):
|
26
|
-
return platformdirs.user_config_dir(self.application_name)
|
27
|
-
|
28
|
-
def _ep_key_file_exists(self) -> bool:
|
29
|
-
"""Check if the Expected Parrot key file exists."""
|
30
|
-
return Path(self.config_dir).joinpath(self.ep_key_file_name).exists()
|
31
|
-
|
32
|
-
def ok_to_ask_to_store(self):
|
33
|
-
"""Check if it's okay to ask the user to store the key."""
|
34
|
-
from edsl.config import CONFIG
|
35
|
-
|
36
|
-
if CONFIG.get("EDSL_RUN_MODE") != "production":
|
37
|
-
return False
|
38
|
-
|
39
|
-
return (
|
40
|
-
not Path(self.config_dir).joinpath(self.asked_to_store_file_name).exists()
|
41
|
-
)
|
42
|
-
|
43
|
-
def reset_asked_to_store(self):
|
44
|
-
"""Reset the flag that indicates whether the user has been asked to store the key."""
|
45
|
-
asked_to_store_path = Path(self.config_dir).joinpath(
|
46
|
-
self.asked_to_store_file_name
|
47
|
-
)
|
48
|
-
if asked_to_store_path.exists():
|
49
|
-
os.remove(asked_to_store_path)
|
50
|
-
print(
|
51
|
-
"Deleted the file that indicates whether the user has been asked to store the key."
|
52
|
-
)
|
53
|
-
|
54
|
-
def ask_to_store(self, api_key) -> bool:
|
55
|
-
"""Ask the user if they want to store the Expected Parrot key. If they say "yes", store it."""
|
56
|
-
if self.ok_to_ask_to_store():
|
57
|
-
# can_we_store = get_input_with_timeout(
|
58
|
-
# "Would you like to store your Expected Parrot key for future use? (y/n): ",
|
59
|
-
# timeout=5,
|
60
|
-
# default="y",
|
61
|
-
# )
|
62
|
-
can_we_store = "y"
|
63
|
-
if can_we_store.lower() == "y":
|
64
|
-
Path(self.config_dir).mkdir(parents=True, exist_ok=True)
|
65
|
-
self.store_ep_api_key(api_key)
|
66
|
-
# print("Stored Expected Parrot API key at ", self.config_dir)
|
67
|
-
return True
|
68
|
-
else:
|
69
|
-
Path(self.config_dir).mkdir(parents=True, exist_ok=True)
|
70
|
-
with open(
|
71
|
-
Path(self.config_dir).joinpath(self.asked_to_store_file_name), "w"
|
72
|
-
) as f:
|
73
|
-
f.write("Yes")
|
74
|
-
return False
|
75
|
-
|
76
|
-
def get_ep_api_key(self):
|
77
|
-
# check if the key is stored in the config_dir
|
78
|
-
api_key = None
|
79
|
-
api_key_from_cache = None
|
80
|
-
api_key_from_os = None
|
81
|
-
|
82
|
-
if self._ep_key_file_exists():
|
83
|
-
with open(Path(self.config_dir).joinpath(self.ep_key_file_name), "r") as f:
|
84
|
-
api_key_from_cache = f.read().strip()
|
85
|
-
|
86
|
-
api_key_from_os = os.getenv("EXPECTED_PARROT_API_KEY")
|
87
|
-
|
88
|
-
if api_key_from_os and api_key_from_cache:
|
89
|
-
if api_key_from_os != api_key_from_cache:
|
90
|
-
import warnings
|
91
|
-
|
92
|
-
warnings.warn(
|
93
|
-
"WARNING: The Expected Parrot API key from the environment variable "
|
94
|
-
"differs from the one stored in the config directory. Using the one "
|
95
|
-
"from the environment variable."
|
96
|
-
)
|
97
|
-
api_key = api_key_from_os
|
98
|
-
|
99
|
-
if api_key_from_os and not api_key_from_cache:
|
100
|
-
api_key = api_key_from_os
|
101
|
-
|
102
|
-
if not api_key_from_os and api_key_from_cache:
|
103
|
-
api_key = api_key_from_cache
|
104
|
-
|
105
|
-
if api_key is not None:
|
106
|
-
_ = self.ask_to_store(api_key)
|
107
|
-
return api_key
|
108
|
-
|
109
|
-
def delete_ep_api_key(self):
|
110
|
-
key_path = Path(self.config_dir) / self.ep_key_file_name
|
111
|
-
if key_path.exists():
|
112
|
-
os.remove(key_path)
|
113
|
-
print("Deleted Expected Parrot API key at ", key_path)
|
114
|
-
|
115
|
-
def store_ep_api_key(self, api_key):
|
116
|
-
# Create the directory if it doesn't exist
|
117
|
-
os.makedirs(self.config_dir, exist_ok=True)
|
118
|
-
|
119
|
-
# Create the path for the key file
|
120
|
-
key_path = Path(self.config_dir) / self.ep_key_file_name
|
121
|
-
|
122
|
-
# Save the key
|
123
|
-
with open(key_path, "w") as f:
|
124
|
-
f.write(api_key)
|
125
|
-
# print("Stored Expected Parrot API key at ", key_path)
|
@@ -1,184 +0,0 @@
|
|
1
|
-
from typing import List, Optional, get_args, Union
|
2
|
-
from pathlib import Path
|
3
|
-
import sqlite3
|
4
|
-
from datetime import datetime
|
5
|
-
import tempfile
|
6
|
-
from platformdirs import user_cache_dir
|
7
|
-
from dataclasses import dataclass
|
8
|
-
import os
|
9
|
-
|
10
|
-
from edsl.inference_services.data_structures import LanguageModelInfo, AvailableModels
|
11
|
-
from edsl.enums import InferenceServiceLiteral
|
12
|
-
|
13
|
-
|
14
|
-
class AvailableModelCacheHandler:
|
15
|
-
MAX_ROWS = 1000
|
16
|
-
CACHE_VALIDITY_HOURS = 48
|
17
|
-
|
18
|
-
def __init__(
|
19
|
-
self,
|
20
|
-
cache_validity_hours: int = 48,
|
21
|
-
verbose: bool = False,
|
22
|
-
testing_db_name: str = None,
|
23
|
-
):
|
24
|
-
self.cache_validity_hours = cache_validity_hours
|
25
|
-
self.verbose = verbose
|
26
|
-
|
27
|
-
if testing_db_name:
|
28
|
-
self.cache_dir = Path(tempfile.mkdtemp())
|
29
|
-
self.db_path = self.cache_dir / testing_db_name
|
30
|
-
else:
|
31
|
-
self.cache_dir = Path(user_cache_dir("edsl", "model_availability"))
|
32
|
-
self.db_path = self.cache_dir / "available_models.db"
|
33
|
-
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
34
|
-
|
35
|
-
if os.path.exists(self.db_path):
|
36
|
-
if self.verbose:
|
37
|
-
print(f"Using existing cache DB: {self.db_path}")
|
38
|
-
else:
|
39
|
-
self._initialize_db()
|
40
|
-
|
41
|
-
@property
|
42
|
-
def path_to_db(self):
|
43
|
-
return self.db_path
|
44
|
-
|
45
|
-
def _initialize_db(self):
|
46
|
-
"""Initialize the SQLite database with the required schema."""
|
47
|
-
with sqlite3.connect(self.db_path) as conn:
|
48
|
-
cursor = conn.cursor()
|
49
|
-
# Drop the old table if it exists (for migration)
|
50
|
-
cursor.execute("DROP TABLE IF EXISTS model_cache")
|
51
|
-
cursor.execute(
|
52
|
-
"""
|
53
|
-
CREATE TABLE IF NOT EXISTS model_cache (
|
54
|
-
timestamp DATETIME NOT NULL,
|
55
|
-
model_name TEXT NOT NULL,
|
56
|
-
service_name TEXT NOT NULL,
|
57
|
-
UNIQUE(model_name, service_name)
|
58
|
-
)
|
59
|
-
"""
|
60
|
-
)
|
61
|
-
conn.commit()
|
62
|
-
|
63
|
-
def _prune_old_entries(self, conn: sqlite3.Connection):
|
64
|
-
"""Delete oldest entries when MAX_ROWS is exceeded."""
|
65
|
-
cursor = conn.cursor()
|
66
|
-
cursor.execute("SELECT COUNT(*) FROM model_cache")
|
67
|
-
count = cursor.fetchone()[0]
|
68
|
-
|
69
|
-
if count > self.MAX_ROWS:
|
70
|
-
cursor.execute(
|
71
|
-
"""
|
72
|
-
DELETE FROM model_cache
|
73
|
-
WHERE rowid IN (
|
74
|
-
SELECT rowid
|
75
|
-
FROM model_cache
|
76
|
-
ORDER BY timestamp ASC
|
77
|
-
LIMIT ?
|
78
|
-
)
|
79
|
-
""",
|
80
|
-
(count - self.MAX_ROWS,),
|
81
|
-
)
|
82
|
-
conn.commit()
|
83
|
-
|
84
|
-
@classmethod
|
85
|
-
def example_models(cls) -> List[LanguageModelInfo]:
|
86
|
-
return [
|
87
|
-
LanguageModelInfo(
|
88
|
-
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "deep_infra"
|
89
|
-
),
|
90
|
-
LanguageModelInfo("openai/gpt-4", "openai"),
|
91
|
-
]
|
92
|
-
|
93
|
-
def add_models_to_cache(self, models_data: List[LanguageModelInfo]):
|
94
|
-
"""Add new models to the cache, updating timestamps for existing entries."""
|
95
|
-
current_time = datetime.now()
|
96
|
-
|
97
|
-
with sqlite3.connect(self.db_path) as conn:
|
98
|
-
cursor = conn.cursor()
|
99
|
-
for model in models_data:
|
100
|
-
cursor.execute(
|
101
|
-
"""
|
102
|
-
INSERT INTO model_cache (timestamp, model_name, service_name)
|
103
|
-
VALUES (?, ?, ?)
|
104
|
-
ON CONFLICT(model_name, service_name)
|
105
|
-
DO UPDATE SET timestamp = excluded.timestamp
|
106
|
-
""",
|
107
|
-
(current_time, model.model_name, model.service_name),
|
108
|
-
)
|
109
|
-
|
110
|
-
# self._prune_old_entries(conn)
|
111
|
-
conn.commit()
|
112
|
-
|
113
|
-
def reset_cache(self):
|
114
|
-
"""Clear all entries from the cache."""
|
115
|
-
with sqlite3.connect(self.db_path) as conn:
|
116
|
-
cursor = conn.cursor()
|
117
|
-
cursor.execute("DELETE FROM model_cache")
|
118
|
-
conn.commit()
|
119
|
-
|
120
|
-
@property
|
121
|
-
def num_cache_entries(self):
|
122
|
-
"""Return the number of entries in the cache."""
|
123
|
-
with sqlite3.connect(self.db_path) as conn:
|
124
|
-
cursor = conn.cursor()
|
125
|
-
cursor.execute("SELECT COUNT(*) FROM model_cache")
|
126
|
-
count = cursor.fetchone()[0]
|
127
|
-
return count
|
128
|
-
|
129
|
-
def models(
|
130
|
-
self,
|
131
|
-
service: Optional[InferenceServiceLiteral],
|
132
|
-
) -> Union[None, AvailableModels]:
|
133
|
-
"""Return the available models within the cache validity period."""
|
134
|
-
# if service is not None:
|
135
|
-
# assert service in get_args(InferenceServiceLiteral)
|
136
|
-
|
137
|
-
with sqlite3.connect(self.db_path) as conn:
|
138
|
-
cursor = conn.cursor()
|
139
|
-
valid_time = datetime.now().timestamp() - (self.cache_validity_hours * 3600)
|
140
|
-
|
141
|
-
if self.verbose:
|
142
|
-
print(f"Fetching all with timestamp greater than {valid_time}")
|
143
|
-
|
144
|
-
cursor.execute(
|
145
|
-
"""
|
146
|
-
SELECT DISTINCT model_name, service_name
|
147
|
-
FROM model_cache
|
148
|
-
WHERE timestamp > ?
|
149
|
-
ORDER BY timestamp DESC
|
150
|
-
""",
|
151
|
-
(valid_time,),
|
152
|
-
)
|
153
|
-
|
154
|
-
results = cursor.fetchall()
|
155
|
-
if not results:
|
156
|
-
if self.verbose:
|
157
|
-
print("No results found in cache DB.")
|
158
|
-
return None
|
159
|
-
|
160
|
-
matching_models = [
|
161
|
-
LanguageModelInfo(model_name=row[0], service_name=row[1])
|
162
|
-
for row in results
|
163
|
-
]
|
164
|
-
|
165
|
-
if self.verbose:
|
166
|
-
print(f"Found {len(matching_models)} models in cache DB.")
|
167
|
-
if service:
|
168
|
-
matching_models = [
|
169
|
-
model for model in matching_models if model.service_name == service
|
170
|
-
]
|
171
|
-
|
172
|
-
return AvailableModels(matching_models)
|
173
|
-
|
174
|
-
|
175
|
-
if __name__ == "__main__":
|
176
|
-
import doctest
|
177
|
-
|
178
|
-
doctest.testmod()
|
179
|
-
# cache_handler = AvailableModelCacheHandler(verbose=True)
|
180
|
-
# models_data = cache_handler.example_models()
|
181
|
-
# cache_handler.add_models_to_cache(models_data)
|
182
|
-
# print(cache_handler.models())
|
183
|
-
# cache_handler.clear_cache()
|
184
|
-
# print(cache_handler.models())
|