edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. edsl/Base.py +116 -197
  2. edsl/__init__.py +7 -15
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +147 -351
  5. edsl/agents/AgentList.py +73 -211
  6. edsl/agents/Invigilator.py +50 -101
  7. edsl/agents/InvigilatorBase.py +70 -62
  8. edsl/agents/PromptConstructor.py +225 -143
  9. edsl/agents/__init__.py +1 -0
  10. edsl/agents/prompt_helpers.py +3 -3
  11. edsl/auto/AutoStudy.py +5 -18
  12. edsl/auto/StageBase.py +40 -53
  13. edsl/auto/StageQuestions.py +1 -2
  14. edsl/auto/utilities.py +6 -0
  15. edsl/config.py +2 -22
  16. edsl/conversation/car_buying.py +1 -2
  17. edsl/coop/PriceFetcher.py +1 -1
  18. edsl/coop/coop.py +47 -125
  19. edsl/coop/utils.py +14 -14
  20. edsl/data/Cache.py +27 -45
  21. edsl/data/CacheEntry.py +15 -12
  22. edsl/data/CacheHandler.py +12 -31
  23. edsl/data/RemoteCacheSync.py +46 -154
  24. edsl/data/__init__.py +3 -4
  25. edsl/data_transfer_models.py +1 -2
  26. edsl/enums.py +0 -27
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +0 -12
  29. edsl/exceptions/questions.py +6 -24
  30. edsl/exceptions/scenarios.py +0 -7
  31. edsl/inference_services/AnthropicService.py +19 -38
  32. edsl/inference_services/AwsBedrock.py +2 -0
  33. edsl/inference_services/AzureAI.py +2 -0
  34. edsl/inference_services/GoogleService.py +12 -7
  35. edsl/inference_services/InferenceServiceABC.py +85 -18
  36. edsl/inference_services/InferenceServicesCollection.py +79 -120
  37. edsl/inference_services/MistralAIService.py +3 -0
  38. edsl/inference_services/OpenAIService.py +35 -47
  39. edsl/inference_services/PerplexityService.py +3 -0
  40. edsl/inference_services/TestService.py +10 -11
  41. edsl/inference_services/TogetherAIService.py +3 -5
  42. edsl/jobs/Answers.py +14 -1
  43. edsl/jobs/Jobs.py +431 -356
  44. edsl/jobs/JobsChecks.py +10 -35
  45. edsl/jobs/JobsPrompts.py +4 -6
  46. edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
  47. edsl/jobs/buckets/BucketCollection.py +3 -44
  48. edsl/jobs/buckets/TokenBucket.py +21 -53
  49. edsl/jobs/interviews/Interview.py +408 -143
  50. edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
  51. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  52. edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
  53. edsl/jobs/tasks/TaskHistory.py +18 -38
  54. edsl/jobs/tasks/task_status_enum.py +2 -0
  55. edsl/language_models/KeyLookup.py +30 -0
  56. edsl/language_models/LanguageModel.py +236 -194
  57. edsl/language_models/ModelList.py +19 -28
  58. edsl/language_models/__init__.py +2 -1
  59. edsl/language_models/registry.py +190 -0
  60. edsl/language_models/repair.py +2 -2
  61. edsl/language_models/unused/ReplicateBase.py +83 -0
  62. edsl/language_models/utilities.py +4 -5
  63. edsl/notebooks/Notebook.py +14 -19
  64. edsl/prompts/Prompt.py +39 -29
  65. edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
  66. edsl/questions/QuestionBase.py +214 -68
  67. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
  68. edsl/questions/QuestionBasePromptsMixin.py +3 -7
  69. edsl/questions/QuestionBudget.py +1 -1
  70. edsl/questions/QuestionCheckBox.py +3 -3
  71. edsl/questions/QuestionExtract.py +7 -5
  72. edsl/questions/QuestionFreeText.py +3 -2
  73. edsl/questions/QuestionList.py +18 -10
  74. edsl/questions/QuestionMultipleChoice.py +23 -67
  75. edsl/questions/QuestionNumerical.py +4 -2
  76. edsl/questions/QuestionRank.py +17 -7
  77. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
  78. edsl/questions/SimpleAskMixin.py +3 -4
  79. edsl/questions/__init__.py +1 -2
  80. edsl/questions/derived/QuestionLinearScale.py +3 -6
  81. edsl/questions/derived/QuestionTopK.py +1 -1
  82. edsl/questions/descriptors.py +3 -17
  83. edsl/questions/question_registry.py +1 -1
  84. edsl/results/CSSParameterizer.py +1 -1
  85. edsl/results/Dataset.py +7 -170
  86. edsl/results/DatasetExportMixin.py +305 -168
  87. edsl/results/DatasetTree.py +8 -28
  88. edsl/results/Result.py +206 -298
  89. edsl/results/Results.py +131 -149
  90. edsl/results/ResultsDBMixin.py +238 -0
  91. edsl/results/ResultsExportMixin.py +0 -2
  92. edsl/results/{results_selector.py → Selector.py} +13 -23
  93. edsl/results/TableDisplay.py +171 -98
  94. edsl/results/__init__.py +1 -1
  95. edsl/scenarios/FileStore.py +239 -150
  96. edsl/scenarios/Scenario.py +193 -90
  97. edsl/scenarios/ScenarioHtmlMixin.py +3 -4
  98. edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
  99. edsl/scenarios/ScenarioList.py +244 -415
  100. edsl/scenarios/ScenarioListExportMixin.py +7 -0
  101. edsl/scenarios/ScenarioListPdfMixin.py +37 -15
  102. edsl/scenarios/__init__.py +2 -1
  103. edsl/study/ObjectEntry.py +1 -1
  104. edsl/study/SnapShot.py +1 -1
  105. edsl/study/Study.py +12 -5
  106. edsl/surveys/Rule.py +4 -5
  107. edsl/surveys/RuleCollection.py +27 -25
  108. edsl/surveys/Survey.py +791 -270
  109. edsl/surveys/SurveyCSS.py +8 -20
  110. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
  111. edsl/surveys/__init__.py +2 -4
  112. edsl/surveys/descriptors.py +2 -6
  113. edsl/surveys/instructions/ChangeInstruction.py +2 -1
  114. edsl/surveys/instructions/Instruction.py +13 -4
  115. edsl/surveys/instructions/InstructionCollection.py +6 -11
  116. edsl/templates/error_reporting/interview_details.html +1 -1
  117. edsl/templates/error_reporting/report.html +1 -1
  118. edsl/tools/plotting.py +1 -1
  119. edsl/utilities/utilities.py +23 -35
  120. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
  121. edsl-0.1.39.dev1.dist-info/RECORD +277 -0
  122. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
  123. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  124. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  125. edsl/agents/question_option_processor.py +0 -172
  126. edsl/coop/CoopFunctionsMixin.py +0 -15
  127. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  128. edsl/exceptions/inference_services.py +0 -5
  129. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  130. edsl/inference_services/AvailableModelFetcher.py +0 -215
  131. edsl/inference_services/ServiceAvailability.py +0 -135
  132. edsl/inference_services/data_structures.py +0 -134
  133. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
  134. edsl/jobs/FetchInvigilator.py +0 -47
  135. edsl/jobs/InterviewTaskManager.py +0 -98
  136. edsl/jobs/InterviewsConstructor.py +0 -50
  137. edsl/jobs/JobsComponentConstructor.py +0 -189
  138. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  139. edsl/jobs/RequestTokenEstimator.py +0 -30
  140. edsl/jobs/async_interview_runner.py +0 -138
  141. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  142. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  143. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  144. edsl/jobs/data_structures.py +0 -120
  145. edsl/jobs/decorators.py +0 -35
  146. edsl/jobs/jobs_status_enums.py +0 -9
  147. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  148. edsl/jobs/results_exceptions_handler.py +0 -98
  149. edsl/language_models/ComputeCost.py +0 -63
  150. edsl/language_models/PriceManager.py +0 -127
  151. edsl/language_models/RawResponseHandler.py +0 -106
  152. edsl/language_models/ServiceDataSources.py +0 -0
  153. edsl/language_models/key_management/KeyLookup.py +0 -63
  154. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  155. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  156. edsl/language_models/key_management/__init__.py +0 -0
  157. edsl/language_models/key_management/models.py +0 -131
  158. edsl/language_models/model.py +0 -256
  159. edsl/notebooks/NotebookToLaTeX.py +0 -142
  160. edsl/questions/ExceptionExplainer.py +0 -77
  161. edsl/questions/HTMLQuestion.py +0 -103
  162. edsl/questions/QuestionMatrix.py +0 -265
  163. edsl/questions/data_structures.py +0 -20
  164. edsl/questions/loop_processor.py +0 -149
  165. edsl/questions/response_validator_factory.py +0 -34
  166. edsl/questions/templates/matrix/__init__.py +0 -1
  167. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  168. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  169. edsl/results/MarkdownToDocx.py +0 -122
  170. edsl/results/MarkdownToPDF.py +0 -111
  171. edsl/results/TextEditor.py +0 -50
  172. edsl/results/file_exports.py +0 -252
  173. edsl/results/smart_objects.py +0 -96
  174. edsl/results/table_data_class.py +0 -12
  175. edsl/results/table_renderers.py +0 -118
  176. edsl/scenarios/ConstructDownloadLink.py +0 -109
  177. edsl/scenarios/DocumentChunker.py +0 -102
  178. edsl/scenarios/DocxScenario.py +0 -16
  179. edsl/scenarios/PdfExtractor.py +0 -40
  180. edsl/scenarios/directory_scanner.py +0 -96
  181. edsl/scenarios/file_methods.py +0 -85
  182. edsl/scenarios/handlers/__init__.py +0 -13
  183. edsl/scenarios/handlers/csv.py +0 -49
  184. edsl/scenarios/handlers/docx.py +0 -76
  185. edsl/scenarios/handlers/html.py +0 -37
  186. edsl/scenarios/handlers/json.py +0 -111
  187. edsl/scenarios/handlers/latex.py +0 -5
  188. edsl/scenarios/handlers/md.py +0 -51
  189. edsl/scenarios/handlers/pdf.py +0 -68
  190. edsl/scenarios/handlers/png.py +0 -39
  191. edsl/scenarios/handlers/pptx.py +0 -105
  192. edsl/scenarios/handlers/py.py +0 -294
  193. edsl/scenarios/handlers/sql.py +0 -313
  194. edsl/scenarios/handlers/sqlite.py +0 -149
  195. edsl/scenarios/handlers/txt.py +0 -33
  196. edsl/scenarios/scenario_selector.py +0 -156
  197. edsl/surveys/ConstructDAG.py +0 -92
  198. edsl/surveys/EditSurvey.py +0 -221
  199. edsl/surveys/InstructionHandler.py +0 -100
  200. edsl/surveys/MemoryManagement.py +0 -72
  201. edsl/surveys/RuleManager.py +0 -172
  202. edsl/surveys/Simulator.py +0 -75
  203. edsl/surveys/SurveyToApp.py +0 -141
  204. edsl/utilities/PrettyList.py +0 -56
  205. edsl/utilities/is_notebook.py +0 -18
  206. edsl/utilities/is_valid_variable_name.py +0 -11
  207. edsl/utilities/remove_edsl_version.py +0 -24
  208. edsl-0.1.39.dist-info/RECORD +0 -358
  209. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  210. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  211. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  212. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
@@ -1,172 +0,0 @@
1
- from jinja2 import Environment, meta
2
- from typing import List, Optional, Union
3
-
4
-
5
class QuestionOptionProcessor:
    """
    Resolve the options of a question.

    Options may be given directly as a list, or as a Jinja template string
    naming a single variable that is looked up first in the scenario and then
    in the prior answers of the owning prompt constructor.
    """

    def __init__(self, prompt_constructor):
        # Object exposing ``scenario`` and ``prior_answers_dict()``.
        self.prompt_constructor = prompt_constructor

    @staticmethod
    def _get_default_options() -> list:
        """Return the three placeholder options used when nothing else is found."""
        return [f"<< Option {i} - Placeholder >>" for i in range(1, 4)]

    @staticmethod
    def _parse_template_variable(template_str: str) -> str:
        """
        Return the name of the single undeclared variable in a Jinja template.

        Args:
            template_str (str): Jinja template string

        Returns:
            str: Name of the only undeclared variable in the template

        Raises:
            ValueError: If the template has no variable, or more than one.

        >>> QuestionOptionProcessor._parse_template_variable("Here are some {{ options }}")
        'options'
        >>> QuestionOptionProcessor._parse_template_variable("Here are some {{ options }} and {{ other }}")
        Traceback (most recent call last):
        ...
        ValueError: Multiple variables found in template string
        >>> QuestionOptionProcessor._parse_template_variable("Here are some")
        Traceback (most recent call last):
        ...
        ValueError: No variables found in template string
        """
        syntax_tree = Environment().parse(template_str)
        names = list(meta.find_undeclared_variables(syntax_tree))
        if len(names) == 1:
            return names[0]
        if names:
            raise ValueError("Multiple variables found in template string")
        raise ValueError("No variables found in template string")

    @staticmethod
    def _get_options_from_scenario(
        scenario: dict, option_key: str
    ) -> Union[list, None]:
        """
        Look up ``option_key`` in the scenario and return it only if it is a list.

        >>> from edsl import Scenario
        >>> scenario = Scenario({"options": ["Option 1", "Option 2"]})
        >>> QuestionOptionProcessor._get_options_from_scenario(scenario, "options")
        ['Option 1', 'Option 2']

        Returns:
            list | None: The scenario value when it is a list, None otherwise
        """
        candidate = scenario.get(option_key)
        if isinstance(candidate, list):
            return candidate
        return None

    @staticmethod
    def _get_options_from_prior_answers(
        prior_answers: dict, option_key: str
    ) -> Union[list, None]:
        """
        Look up ``option_key`` in the prior answers and return its list answer.

        Args:
            prior_answers (dict): Mapping from name to an object carrying ``answer``
            option_key (str): Key to look up in prior answers

        >>> from edsl import QuestionList as Q
        >>> q = Q.example()
        >>> q.answer = ["Option 1", "Option 2"]
        >>> prior_answers = {"options": q}
        >>> QuestionOptionProcessor._get_options_from_prior_answers(prior_answers, "options")
        ['Option 1', 'Option 2']
        >>> QuestionOptionProcessor._get_options_from_prior_answers(prior_answers, "wrong_key") is None
        True

        Returns:
            list | None: The ``answer`` attribute when it is a list, None otherwise
        """
        entry = prior_answers.get(option_key)
        # A missing/falsy entry, or one without an ``answer``, yields nothing.
        if not entry or not hasattr(entry, "answer"):
            return None
        answer = entry.answer
        return answer if isinstance(answer, list) else None

    def get_question_options(self, question_data: dict) -> list:
        """
        Extract and process question options from question data.

        Args:
            question_data (dict): Dictionary containing question configuration

        Returns:
            list: List of question options. Returns default placeholders if no valid options found.

        >>> class MockPromptConstructor:
        ...     pass
        >>> mpc = MockPromptConstructor()
        >>> from edsl import Scenario
        >>> mpc.scenario = Scenario({"options": ["Option 1", "Option 2"]})
        >>> processor = QuestionOptionProcessor(mpc)

        The basic case where options are directly provided:

        >>> question_data = {"question_options": ["Option 1", "Option 2"]}
        >>> processor.get_question_options(question_data)
        ['Option 1', 'Option 2']

        The case where options are provided as a template string:

        >>> question_data = {"question_options": "{{ options }}"}
        >>> processor.get_question_options(question_data)
        ['Option 1', 'Option 2']

        The case where there is a template string but its variable is in the prior answers:

        >>> class MockQuestion:
        ...     pass
        >>> q0 = MockQuestion()
        >>> q0.answer = ["Option 1", "Option 2"]
        >>> mpc.prior_answers_dict = lambda: {'q0': q0}
        >>> processor = QuestionOptionProcessor(mpc)
        >>> question_data = {"question_options": "{{ q0 }}"}
        >>> processor.get_question_options(question_data)
        ['Option 1', 'Option 2']

        The case where no options are found:

        >>> processor.get_question_options({"question_options": "{{ missing }}"})
        ['<< Option 1 - Placeholder >>', '<< Option 2 - Placeholder >>', '<< Option 3 - Placeholder >>']

        """
        raw_options = question_data.get("question_options")

        # Anything that is not a template string is used as given; a missing
        # or empty value falls back to the placeholders.
        if not isinstance(raw_options, str):
            return raw_options or self._get_default_options()

        variable_name = self._parse_template_variable(raw_options)

        # The scenario takes precedence over prior answers.
        from_scenario = self._get_options_from_scenario(
            self.prompt_constructor.scenario, variable_name
        )
        if from_scenario:
            return from_scenario

        from_prior_answers = self._get_options_from_prior_answers(
            self.prompt_constructor.prior_answers_dict(), variable_name
        )
        return from_prior_answers or self._get_default_options()
167
-
168
-
169
if __name__ == "__main__":
    # Executed as a script: run the doctests embedded in this module.
    from doctest import testmod

    testmod()
@@ -1,15 +0,0 @@
1
class CoopFunctionsMixin:
    """Mixin providing a language-model-backed helper for renaming columns."""

    def better_names(self, existing_names):
        """Ask a model for replacement column names and return its answer.

        ``existing_names`` is placed in a scenario and rendered into the
        question text; the question is run with ``verbose=False`` and the
        first ``answer.better_names`` value of the results is returned.
        """
        from edsl import QuestionList, Scenario

        # The prompt wording (including its spelling) is sent to the model
        # verbatim, so it is reproduced exactly as it appears upstream.
        prompt_text = (
            "The following colum names are already in use: {{ existing_names }}\n"
            "Please provide new names for the columns.\n"
            "They should be short, one or two words, and unique. They should be valid Python idenifiers.\n"
            "No spaces - use underscores instead.\n"
        )
        scenario = Scenario({"existing_names": existing_names})
        question = QuestionList(
            question_text=prompt_text,
            question_name="better_names",
        )
        return (
            question.by(scenario)
            .run(verbose=False)
            .select("answer.better_names")
            .first()
        )
@@ -1,125 +0,0 @@
1
- from pathlib import Path
2
- import os
3
- import platformdirs
4
-
5
-
6
- import sys
7
- import select
8
-
9
-
10
def get_input_with_timeout(prompt, timeout=5, default="y"):
    """Prompt on stdout and wait up to ``timeout`` seconds for a line on stdin.

    Args:
        prompt: Text printed (without a trailing newline) before waiting.
        timeout: Maximum number of seconds to wait for input.
        default: Value returned when no input arrives in time.

    Returns:
        The stripped line that was typed, or ``default`` when nothing was
        received or when stdin cannot be waited on at all.
    """
    print(prompt, end="", flush=True)
    try:
        ready, _, _ = select.select([sys.stdin], [], [], timeout)
    except (OSError, ValueError, TypeError):
        # select() only accepts objects backed by a real, selectable file
        # descriptor. A Windows console, a replaced/redirected stdin without
        # fileno(), or sys.stdin being None all raise here; treat that the
        # same as "no answer" instead of crashing the caller.
        ready = []
    if ready:
        return sys.stdin.readline().strip()
    print(f"\nNo input received within {timeout} seconds. Using default: {default}")
    return default
17
-
18
-
19
class ExpectedParrotKeyHandler:
    """Read, store and delete the Expected Parrot API key on disk.

    The key is kept in ``ep_api_key.txt`` inside the per-user configuration
    directory reported by ``platformdirs`` for the application ``edsl``. A
    second file, ``asked_to_store.txt``, acts as a marker that suppresses
    further storing when it is present.
    """

    asked_to_store_file_name = "asked_to_store.txt"
    ep_key_file_name = "ep_api_key.txt"
    application_name = "edsl"

    @property
    def config_dir(self):
        """Per-user configuration directory for this application."""
        return platformdirs.user_config_dir(self.application_name)

    def _config_file(self, file_name) -> Path:
        """Return the path of ``file_name`` inside the configuration directory."""
        return Path(self.config_dir).joinpath(file_name)

    def _ep_key_file_exists(self) -> bool:
        """Check if the Expected Parrot key file exists."""
        return self._config_file(self.ep_key_file_name).exists()

    def ok_to_ask_to_store(self):
        """Check if it's okay to ask the user to store the key.

        Only true when ``EDSL_RUN_MODE`` is ``"production"`` and the
        asked-to-store marker file does not exist.
        """
        from edsl.config import CONFIG

        if CONFIG.get("EDSL_RUN_MODE") != "production":
            return False

        return not self._config_file(self.asked_to_store_file_name).exists()

    def reset_asked_to_store(self):
        """Reset the flag that indicates whether the user has been asked to store the key."""
        marker = self._config_file(self.asked_to_store_file_name)
        if marker.exists():
            os.remove(marker)
            print(
                "Deleted the file that indicates whether the user has been asked to store the key."
            )

    def ask_to_store(self, api_key) -> bool:
        """Store the Expected Parrot key when storing is currently allowed.

        The interactive confirmation prompt is disabled, so consent is
        assumed whenever ``ok_to_ask_to_store()`` is true.

        Returns:
            True if the key was written, False otherwise (always a bool).
        """
        if not self.ok_to_ask_to_store():
            return False

        Path(self.config_dir).mkdir(parents=True, exist_ok=True)
        self.store_ep_api_key(api_key)
        return True

    def get_ep_api_key(self):
        """Return the API key from the environment or the key file, if any.

        The ``EXPECTED_PARROT_API_KEY`` environment variable wins over the
        stored key; a warning is emitted when both exist and differ. When a
        key is found, ``ask_to_store`` is given the chance to persist it.

        Returns:
            The key as a string, or None when neither source provides one.
        """
        api_key_from_cache = None
        if self._ep_key_file_exists():
            with open(self._config_file(self.ep_key_file_name), "r") as f:
                api_key_from_cache = f.read().strip()

        api_key_from_os = os.getenv("EXPECTED_PARROT_API_KEY")

        if (
            api_key_from_os
            and api_key_from_cache
            and api_key_from_os != api_key_from_cache
        ):
            import warnings

            warnings.warn(
                "WARNING: The Expected Parrot API key from the environment variable "
                "differs from the one stored in the config directory. Using the one "
                "from the environment variable."
            )

        # Empty strings count as "not set", exactly like a missing value.
        api_key = api_key_from_os or api_key_from_cache or None

        if api_key is not None:
            self.ask_to_store(api_key)
        return api_key

    def delete_ep_api_key(self):
        """Delete the stored key file, if it exists."""
        key_path = self._config_file(self.ep_key_file_name)
        if key_path.exists():
            os.remove(key_path)
            print("Deleted Expected Parrot API key at ", key_path)

    def store_ep_api_key(self, api_key):
        """Write ``api_key`` to the key file, readable by the owner only."""
        # Create the directory if it doesn't exist
        os.makedirs(self.config_dir, exist_ok=True)

        key_path = self._config_file(self.ep_key_file_name)

        # The file holds a secret: create it with mode 0600 rather than the
        # process default so other local users cannot read it.
        fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            f.write(api_key)
        # os.open() leaves the mode of a pre-existing file untouched, so
        # tighten it explicitly as well.
        os.chmod(key_path, 0o600)
@@ -1,5 +0,0 @@
1
- from edsl.exceptions.BaseException import BaseException
2
-
3
-
4
class InferenceServiceError(BaseException):
    """Error raised for inference-service problems.

    Derives from the ``BaseException`` imported from
    ``edsl.exceptions.BaseException`` at the top of this module (which shadows
    the builtin of the same name) and points ``relevant_doc`` at the
    documentation home page.
    """

    relevant_doc = "https://docs.expectedparrot.com/"
@@ -1,184 +0,0 @@
1
- from typing import List, Optional, get_args, Union
2
- from pathlib import Path
3
- import sqlite3
4
- from datetime import datetime
5
- import tempfile
6
- from platformdirs import user_cache_dir
7
- from dataclasses import dataclass
8
- import os
9
-
10
- from edsl.inference_services.data_structures import LanguageModelInfo, AvailableModels
11
- from edsl.enums import InferenceServiceLiteral
12
-
13
-
14
class AvailableModelCacheHandler:
    """Persist known (model name, service name) pairs in a SQLite database.

    Every row records the local time at which it was last written, stored as
    ISO-8601 text (``YYYY-MM-DD HH:MM:SS[.ffffff]``). ``models`` only returns
    rows written within the last ``cache_validity_hours`` hours.
    """

    MAX_ROWS = 1000
    CACHE_VALIDITY_HOURS = 48

    def __init__(
        self,
        cache_validity_hours: int = 48,
        verbose: bool = False,
        testing_db_name: Optional[str] = None,
    ):
        """Open the cache database, creating it on first use.

        Args:
            cache_validity_hours: Age in hours after which a row is ignored.
            verbose: Print progress information to stdout.
            testing_db_name: When given, the database is created under this
                name inside a fresh temporary directory instead of the
                per-user cache directory.
        """
        self.cache_validity_hours = cache_validity_hours
        self.verbose = verbose

        if testing_db_name:
            self.cache_dir = Path(tempfile.mkdtemp())
            self.db_path = self.cache_dir / testing_db_name
        else:
            # NOTE(review): the second positional argument of user_cache_dir
            # is the app *author*, not a sub-directory - confirm this is the
            # intended location before changing it (doing so moves the cache).
            self.cache_dir = Path(user_cache_dir("edsl", "model_availability"))
            self.db_path = self.cache_dir / "available_models.db"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if os.path.exists(self.db_path):
            if self.verbose:
                print(f"Using existing cache DB: {self.db_path}")
        else:
            self._initialize_db()

    @property
    def path_to_db(self):
        """Path of the SQLite database file."""
        return self.db_path

    def _connect(self):
        """Return a context manager yielding a connection that is closed on exit.

        ``with sqlite3.connect(...)`` alone only ends the transaction; it
        leaves the connection open, so every call would leak a handle.
        """
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def _initialize_db(self):
        """Initialize the SQLite database with the required schema."""
        with self._connect() as conn:
            cursor = conn.cursor()
            # Drop the old table if it exists (for migration)
            cursor.execute("DROP TABLE IF EXISTS model_cache")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS model_cache (
                    timestamp DATETIME NOT NULL,
                    model_name TEXT NOT NULL,
                    service_name TEXT NOT NULL,
                    UNIQUE(model_name, service_name)
                )
                """
            )
            conn.commit()

    def _prune_old_entries(self, conn: sqlite3.Connection):
        """Delete oldest entries when MAX_ROWS is exceeded."""
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM model_cache")
        count = cursor.fetchone()[0]

        if count > self.MAX_ROWS:
            cursor.execute(
                """
                DELETE FROM model_cache
                WHERE rowid IN (
                    SELECT rowid
                    FROM model_cache
                    ORDER BY timestamp ASC
                    LIMIT ?
                )
                """,
                (count - self.MAX_ROWS,),
            )
            conn.commit()

    @classmethod
    def example_models(cls) -> "List[LanguageModelInfo]":
        """Return two sample model descriptions."""
        return [
            LanguageModelInfo(
                "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "deep_infra"
            ),
            LanguageModelInfo("openai/gpt-4", "openai"),
        ]

    def add_models_to_cache(self, models_data: "List[LanguageModelInfo]"):
        """Add new models to the cache, updating timestamps for existing entries.

        Each item only needs ``model_name`` and ``service_name`` attributes.
        """
        # Written explicitly as ISO text: this is the representation the
        # default sqlite3 datetime adapter (deprecated since Python 3.12)
        # produced, so existing rows and new rows stay comparable.
        written_at = datetime.now().isoformat(" ")
        rows = [
            (written_at, model.model_name, model.service_name)
            for model in models_data
        ]

        with self._connect() as conn:
            conn.executemany(
                """
                INSERT INTO model_cache (timestamp, model_name, service_name)
                VALUES (?, ?, ?)
                ON CONFLICT(model_name, service_name)
                DO UPDATE SET timestamp = excluded.timestamp
                """,
                rows,
            )
            # NOTE: pruning via self._prune_old_entries(conn) is currently
            # disabled, so MAX_ROWS is not enforced here.
            conn.commit()

    def reset_cache(self):
        """Clear all entries from the cache."""
        with self._connect() as conn:
            conn.execute("DELETE FROM model_cache")
            conn.commit()

    @property
    def num_cache_entries(self):
        """Return the number of entries in the cache."""
        with self._connect() as conn:
            return conn.execute("SELECT COUNT(*) FROM model_cache").fetchone()[0]

    def models(
        self,
        service: "Optional[InferenceServiceLiteral]",
    ) -> "Union[None, AvailableModels]":
        """Return the available models within the cache validity period.

        Args:
            service: When truthy, keep only models of this service.

        Returns:
            None when no row is recent enough; otherwise an
            ``AvailableModels`` built from the matching rows (possibly empty
            after the service filter).
        """
        from datetime import timedelta

        # The cutoff must be ISO text like the stored values. Comparing the
        # stored text with a numeric epoch value never expires anything,
        # because SQLite orders every TEXT value above every number.
        cutoff = datetime.now() - timedelta(hours=self.cache_validity_hours)
        valid_time = cutoff.isoformat(" ")

        if self.verbose:
            print(f"Fetching all with timestamp greater than {valid_time}")

        with self._connect() as conn:
            results = conn.execute(
                """
                SELECT DISTINCT model_name, service_name
                FROM model_cache
                WHERE timestamp > ?
                ORDER BY timestamp DESC
                """,
                (valid_time,),
            ).fetchall()

        if not results:
            if self.verbose:
                print("No results found in cache DB.")
            return None

        matching_models = [
            LanguageModelInfo(model_name=row[0], service_name=row[1])
            for row in results
        ]

        if self.verbose:
            print(f"Found {len(matching_models)} models in cache DB.")
        if service:
            matching_models = [
                model for model in matching_models if model.service_name == service
            ]

        return AvailableModels(matching_models)
173
-
174
-
175
if __name__ == "__main__":
    # Executed as a script: run the doctests embedded in this module.
    from doctest import testmod

    testmod()
    # Manual smoke test (kept disabled):
    # cache_handler = AvailableModelCacheHandler(verbose=True)
    # models_data = cache_handler.example_models()
    # cache_handler.add_models_to_cache(models_data)
    # print(cache_handler.models(None))
    # cache_handler.reset_cache()
    # print(cache_handler.models(None))