edsl 0.1.39.dev3__py3-none-any.whl → 0.1.39.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (341)
  1. edsl/Base.py +413 -332
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +57 -49
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +1071 -867
  7. edsl/agents/AgentList.py +551 -413
  8. edsl/agents/Invigilator.py +284 -233
  9. edsl/agents/InvigilatorBase.py +257 -270
  10. edsl/agents/PromptConstructor.py +272 -354
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/__init__.py +2 -3
  14. edsl/agents/descriptors.py +99 -99
  15. edsl/agents/prompt_helpers.py +129 -129
  16. edsl/agents/question_option_processor.py +172 -0
  17. edsl/auto/AutoStudy.py +130 -117
  18. edsl/auto/StageBase.py +243 -230
  19. edsl/auto/StageGenerateSurvey.py +178 -178
  20. edsl/auto/StageLabelQuestions.py +125 -125
  21. edsl/auto/StagePersona.py +61 -61
  22. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  23. edsl/auto/StagePersonaDimensionValues.py +74 -74
  24. edsl/auto/StagePersonaDimensions.py +69 -69
  25. edsl/auto/StageQuestions.py +74 -73
  26. edsl/auto/SurveyCreatorPipeline.py +21 -21
  27. edsl/auto/utilities.py +218 -224
  28. edsl/base/Base.py +279 -279
  29. edsl/config.py +177 -157
  30. edsl/conversation/Conversation.py +290 -290
  31. edsl/conversation/car_buying.py +59 -58
  32. edsl/conversation/chips.py +95 -95
  33. edsl/conversation/mug_negotiation.py +81 -81
  34. edsl/conversation/next_speaker_utilities.py +93 -93
  35. edsl/coop/CoopFunctionsMixin.py +15 -0
  36. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  37. edsl/coop/PriceFetcher.py +54 -54
  38. edsl/coop/__init__.py +2 -2
  39. edsl/coop/coop.py +1106 -1028
  40. edsl/coop/utils.py +131 -131
  41. edsl/data/Cache.py +573 -555
  42. edsl/data/CacheEntry.py +230 -233
  43. edsl/data/CacheHandler.py +168 -149
  44. edsl/data/RemoteCacheSync.py +186 -78
  45. edsl/data/SQLiteDict.py +292 -292
  46. edsl/data/__init__.py +5 -4
  47. edsl/data/orm.py +10 -10
  48. edsl/data_transfer_models.py +74 -73
  49. edsl/enums.py +202 -175
  50. edsl/exceptions/BaseException.py +21 -21
  51. edsl/exceptions/__init__.py +54 -54
  52. edsl/exceptions/agents.py +54 -42
  53. edsl/exceptions/cache.py +5 -5
  54. edsl/exceptions/configuration.py +16 -16
  55. edsl/exceptions/coop.py +10 -10
  56. edsl/exceptions/data.py +14 -14
  57. edsl/exceptions/general.py +34 -34
  58. edsl/exceptions/inference_services.py +5 -0
  59. edsl/exceptions/jobs.py +33 -33
  60. edsl/exceptions/language_models.py +63 -63
  61. edsl/exceptions/prompts.py +15 -15
  62. edsl/exceptions/questions.py +109 -91
  63. edsl/exceptions/results.py +29 -29
  64. edsl/exceptions/scenarios.py +29 -22
  65. edsl/exceptions/surveys.py +37 -37
  66. edsl/inference_services/AnthropicService.py +106 -87
  67. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  68. edsl/inference_services/AvailableModelFetcher.py +215 -0
  69. edsl/inference_services/AwsBedrock.py +118 -120
  70. edsl/inference_services/AzureAI.py +215 -217
  71. edsl/inference_services/DeepInfraService.py +18 -18
  72. edsl/inference_services/GoogleService.py +143 -148
  73. edsl/inference_services/GroqService.py +20 -20
  74. edsl/inference_services/InferenceServiceABC.py +80 -147
  75. edsl/inference_services/InferenceServicesCollection.py +138 -97
  76. edsl/inference_services/MistralAIService.py +120 -123
  77. edsl/inference_services/OllamaService.py +18 -18
  78. edsl/inference_services/OpenAIService.py +236 -224
  79. edsl/inference_services/PerplexityService.py +160 -163
  80. edsl/inference_services/ServiceAvailability.py +135 -0
  81. edsl/inference_services/TestService.py +90 -89
  82. edsl/inference_services/TogetherAIService.py +172 -170
  83. edsl/inference_services/data_structures.py +134 -0
  84. edsl/inference_services/models_available_cache.py +118 -118
  85. edsl/inference_services/rate_limits_cache.py +25 -25
  86. edsl/inference_services/registry.py +41 -41
  87. edsl/inference_services/write_available.py +10 -10
  88. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  89. edsl/jobs/Answers.py +43 -56
  90. edsl/jobs/FetchInvigilator.py +47 -0
  91. edsl/jobs/InterviewTaskManager.py +98 -0
  92. edsl/jobs/InterviewsConstructor.py +50 -0
  93. edsl/jobs/Jobs.py +823 -898
  94. edsl/jobs/JobsChecks.py +172 -147
  95. edsl/jobs/JobsComponentConstructor.py +189 -0
  96. edsl/jobs/JobsPrompts.py +270 -268
  97. edsl/jobs/JobsRemoteInferenceHandler.py +311 -239
  98. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  99. edsl/jobs/RequestTokenEstimator.py +30 -0
  100. edsl/jobs/__init__.py +1 -1
  101. edsl/jobs/async_interview_runner.py +138 -0
  102. edsl/jobs/buckets/BucketCollection.py +104 -63
  103. edsl/jobs/buckets/ModelBuckets.py +65 -65
  104. edsl/jobs/buckets/TokenBucket.py +283 -251
  105. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  106. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  107. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  108. edsl/jobs/data_structures.py +120 -0
  109. edsl/jobs/decorators.py +35 -0
  110. edsl/jobs/interviews/Interview.py +396 -661
  111. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  112. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  113. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  114. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  115. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  116. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  117. edsl/jobs/interviews/ReportErrors.py +66 -66
  118. edsl/jobs/interviews/interview_status_enum.py +9 -9
  119. edsl/jobs/jobs_status_enums.py +9 -0
  120. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  121. edsl/jobs/results_exceptions_handler.py +98 -0
  122. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -466
  123. edsl/jobs/runners/JobsRunnerStatus.py +297 -330
  124. edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
  125. edsl/jobs/tasks/TaskCreators.py +64 -64
  126. edsl/jobs/tasks/TaskHistory.py +470 -450
  127. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  128. edsl/jobs/tasks/task_status_enum.py +161 -163
  129. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  130. edsl/jobs/tokens/TokenUsage.py +34 -34
  131. edsl/language_models/ComputeCost.py +63 -0
  132. edsl/language_models/LanguageModel.py +626 -668
  133. edsl/language_models/ModelList.py +164 -155
  134. edsl/language_models/PriceManager.py +127 -0
  135. edsl/language_models/RawResponseHandler.py +106 -0
  136. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  137. edsl/language_models/ServiceDataSources.py +0 -0
  138. edsl/language_models/__init__.py +2 -3
  139. edsl/language_models/fake_openai_call.py +15 -15
  140. edsl/language_models/fake_openai_service.py +61 -61
  141. edsl/language_models/key_management/KeyLookup.py +63 -0
  142. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  143. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  144. edsl/language_models/key_management/__init__.py +0 -0
  145. edsl/language_models/key_management/models.py +131 -0
  146. edsl/language_models/model.py +256 -0
  147. edsl/language_models/repair.py +156 -156
  148. edsl/language_models/utilities.py +65 -64
  149. edsl/notebooks/Notebook.py +263 -258
  150. edsl/notebooks/NotebookToLaTeX.py +142 -0
  151. edsl/notebooks/__init__.py +1 -1
  152. edsl/prompts/Prompt.py +352 -362
  153. edsl/prompts/__init__.py +2 -2
  154. edsl/questions/ExceptionExplainer.py +77 -0
  155. edsl/questions/HTMLQuestion.py +103 -0
  156. edsl/questions/QuestionBase.py +518 -664
  157. edsl/questions/QuestionBasePromptsMixin.py +221 -217
  158. edsl/questions/QuestionBudget.py +227 -227
  159. edsl/questions/QuestionCheckBox.py +359 -359
  160. edsl/questions/QuestionExtract.py +180 -182
  161. edsl/questions/QuestionFreeText.py +113 -114
  162. edsl/questions/QuestionFunctional.py +166 -166
  163. edsl/questions/QuestionList.py +223 -231
  164. edsl/questions/QuestionMatrix.py +265 -0
  165. edsl/questions/QuestionMultipleChoice.py +330 -286
  166. edsl/questions/QuestionNumerical.py +151 -153
  167. edsl/questions/QuestionRank.py +314 -324
  168. edsl/questions/Quick.py +41 -41
  169. edsl/questions/SimpleAskMixin.py +74 -73
  170. edsl/questions/__init__.py +27 -26
  171. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
  172. edsl/questions/compose_questions.py +98 -98
  173. edsl/questions/data_structures.py +20 -0
  174. edsl/questions/decorators.py +21 -21
  175. edsl/questions/derived/QuestionLikertFive.py +76 -76
  176. edsl/questions/derived/QuestionLinearScale.py +90 -87
  177. edsl/questions/derived/QuestionTopK.py +93 -93
  178. edsl/questions/derived/QuestionYesNo.py +82 -82
  179. edsl/questions/descriptors.py +427 -413
  180. edsl/questions/loop_processor.py +149 -0
  181. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  182. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  183. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  184. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  185. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  186. edsl/questions/prompt_templates/question_list.jinja +17 -17
  187. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  188. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  189. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
  190. edsl/questions/question_registry.py +177 -177
  191. edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
  192. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
  193. edsl/questions/response_validator_factory.py +34 -0
  194. edsl/questions/settings.py +12 -12
  195. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  196. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  197. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  198. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  199. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  200. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  201. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  202. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  203. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  204. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  205. edsl/questions/templates/list/question_presentation.jinja +5 -5
  206. edsl/questions/templates/matrix/__init__.py +1 -0
  207. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  208. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  209. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  210. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  211. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  212. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  213. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  214. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  215. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  216. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  217. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  218. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  219. edsl/results/CSSParameterizer.py +108 -108
  220. edsl/results/Dataset.py +587 -424
  221. edsl/results/DatasetExportMixin.py +594 -731
  222. edsl/results/DatasetTree.py +295 -275
  223. edsl/results/MarkdownToDocx.py +122 -0
  224. edsl/results/MarkdownToPDF.py +111 -0
  225. edsl/results/Result.py +557 -465
  226. edsl/results/Results.py +1183 -1165
  227. edsl/results/ResultsExportMixin.py +45 -43
  228. edsl/results/ResultsGGMixin.py +121 -121
  229. edsl/results/TableDisplay.py +125 -198
  230. edsl/results/TextEditor.py +50 -0
  231. edsl/results/__init__.py +2 -2
  232. edsl/results/file_exports.py +252 -0
  233. edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
  234. edsl/results/{Selector.py → results_selector.py} +145 -135
  235. edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
  236. edsl/results/smart_objects.py +96 -0
  237. edsl/results/table_data_class.py +12 -0
  238. edsl/results/table_display.css +77 -77
  239. edsl/results/table_renderers.py +118 -0
  240. edsl/results/tree_explore.py +115 -115
  241. edsl/scenarios/ConstructDownloadLink.py +109 -0
  242. edsl/scenarios/DocumentChunker.py +102 -0
  243. edsl/scenarios/DocxScenario.py +16 -0
  244. edsl/scenarios/FileStore.py +511 -632
  245. edsl/scenarios/PdfExtractor.py +40 -0
  246. edsl/scenarios/Scenario.py +498 -601
  247. edsl/scenarios/ScenarioHtmlMixin.py +65 -64
  248. edsl/scenarios/ScenarioList.py +1458 -1287
  249. edsl/scenarios/ScenarioListExportMixin.py +45 -52
  250. edsl/scenarios/ScenarioListPdfMixin.py +239 -261
  251. edsl/scenarios/__init__.py +3 -4
  252. edsl/scenarios/directory_scanner.py +96 -0
  253. edsl/scenarios/file_methods.py +85 -0
  254. edsl/scenarios/handlers/__init__.py +13 -0
  255. edsl/scenarios/handlers/csv.py +38 -0
  256. edsl/scenarios/handlers/docx.py +76 -0
  257. edsl/scenarios/handlers/html.py +37 -0
  258. edsl/scenarios/handlers/json.py +111 -0
  259. edsl/scenarios/handlers/latex.py +5 -0
  260. edsl/scenarios/handlers/md.py +51 -0
  261. edsl/scenarios/handlers/pdf.py +68 -0
  262. edsl/scenarios/handlers/png.py +39 -0
  263. edsl/scenarios/handlers/pptx.py +105 -0
  264. edsl/scenarios/handlers/py.py +294 -0
  265. edsl/scenarios/handlers/sql.py +313 -0
  266. edsl/scenarios/handlers/sqlite.py +149 -0
  267. edsl/scenarios/handlers/txt.py +33 -0
  268. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +131 -127
  269. edsl/scenarios/scenario_selector.py +156 -0
  270. edsl/shared.py +1 -1
  271. edsl/study/ObjectEntry.py +173 -173
  272. edsl/study/ProofOfWork.py +113 -113
  273. edsl/study/SnapShot.py +80 -80
  274. edsl/study/Study.py +521 -528
  275. edsl/study/__init__.py +4 -4
  276. edsl/surveys/ConstructDAG.py +92 -0
  277. edsl/surveys/DAG.py +148 -148
  278. edsl/surveys/EditSurvey.py +221 -0
  279. edsl/surveys/InstructionHandler.py +100 -0
  280. edsl/surveys/Memory.py +31 -31
  281. edsl/surveys/MemoryManagement.py +72 -0
  282. edsl/surveys/MemoryPlan.py +244 -244
  283. edsl/surveys/Rule.py +327 -326
  284. edsl/surveys/RuleCollection.py +385 -387
  285. edsl/surveys/RuleManager.py +172 -0
  286. edsl/surveys/Simulator.py +75 -0
  287. edsl/surveys/Survey.py +1280 -1801
  288. edsl/surveys/SurveyCSS.py +273 -261
  289. edsl/surveys/SurveyExportMixin.py +259 -259
  290. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -179
  291. edsl/surveys/SurveyQualtricsImport.py +284 -284
  292. edsl/surveys/SurveyToApp.py +141 -0
  293. edsl/surveys/__init__.py +5 -3
  294. edsl/surveys/base.py +53 -53
  295. edsl/surveys/descriptors.py +60 -56
  296. edsl/surveys/instructions/ChangeInstruction.py +48 -49
  297. edsl/surveys/instructions/Instruction.py +56 -65
  298. edsl/surveys/instructions/InstructionCollection.py +82 -77
  299. edsl/templates/error_reporting/base.html +23 -23
  300. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  301. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  302. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  303. edsl/templates/error_reporting/interview_details.html +115 -115
  304. edsl/templates/error_reporting/interviews.html +19 -19
  305. edsl/templates/error_reporting/overview.html +4 -4
  306. edsl/templates/error_reporting/performance_plot.html +1 -1
  307. edsl/templates/error_reporting/report.css +73 -73
  308. edsl/templates/error_reporting/report.html +117 -117
  309. edsl/templates/error_reporting/report.js +25 -25
  310. edsl/tools/__init__.py +1 -1
  311. edsl/tools/clusters.py +192 -192
  312. edsl/tools/embeddings.py +27 -27
  313. edsl/tools/embeddings_plotting.py +118 -118
  314. edsl/tools/plotting.py +112 -112
  315. edsl/tools/summarize.py +18 -18
  316. edsl/utilities/PrettyList.py +56 -0
  317. edsl/utilities/SystemInfo.py +28 -28
  318. edsl/utilities/__init__.py +22 -22
  319. edsl/utilities/ast_utilities.py +25 -25
  320. edsl/utilities/data/Registry.py +6 -6
  321. edsl/utilities/data/__init__.py +1 -1
  322. edsl/utilities/data/scooter_results.json +1 -1
  323. edsl/utilities/decorators.py +77 -77
  324. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  325. edsl/utilities/interface.py +627 -627
  326. edsl/utilities/is_notebook.py +18 -0
  327. edsl/utilities/is_valid_variable_name.py +11 -0
  328. edsl/utilities/naming_utilities.py +263 -263
  329. edsl/utilities/remove_edsl_version.py +24 -0
  330. edsl/utilities/repair_functions.py +28 -28
  331. edsl/utilities/restricted_python.py +70 -70
  332. edsl/utilities/utilities.py +436 -424
  333. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev5.dist-info}/LICENSE +21 -21
  334. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev5.dist-info}/METADATA +13 -11
  335. edsl-0.1.39.dev5.dist-info/RECORD +358 -0
  336. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev5.dist-info}/WHEEL +1 -1
  337. edsl/language_models/KeyLookup.py +0 -30
  338. edsl/language_models/registry.py +0 -190
  339. edsl/language_models/unused/ReplicateBase.py +0 -83
  340. edsl/results/ResultsDBMixin.py +0 -238
  341. edsl-0.1.39.dev3.dist-info/RECORD +0 -277
@@ -0,0 +1,184 @@
1
import os
import sqlite3
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import List, Optional, Union, get_args

from platformdirs import user_cache_dir

from edsl.enums import InferenceServiceLiteral
from edsl.inference_services.data_structures import LanguageModelInfo, AvailableModels
12
+
13
+
14
class AvailableModelCacheHandler:
    """SQLite-backed cache of which models each inference service offers.

    Each row stores (timestamp, model_name, service_name); the
    (model_name, service_name) pair is UNIQUE and its timestamp is refreshed
    on every insert, so :meth:`models` can skip entries older than the
    validity window.
    """

    MAX_ROWS = 1000  # soft row cap, enforced by _prune_old_entries
    CACHE_VALIDITY_HOURS = 48

    def __init__(
        self,
        cache_validity_hours: int = 48,
        verbose: bool = False,
        testing_db_name: Optional[str] = None,
    ):
        """Open (or create) the cache database.

        :param cache_validity_hours: how long a cached entry stays fresh.
        :param verbose: print diagnostic messages.
        :param testing_db_name: when given, put the DB in a throwaway
            temporary directory instead of the per-user cache directory.
        """
        self.cache_validity_hours = cache_validity_hours
        self.verbose = verbose

        if testing_db_name:
            self.cache_dir = Path(tempfile.mkdtemp())
            self.db_path = self.cache_dir / testing_db_name
        else:
            self.cache_dir = Path(user_cache_dir("edsl", "model_availability"))
            self.db_path = self.cache_dir / "available_models.db"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if os.path.exists(self.db_path):
            if self.verbose:
                print(f"Using existing cache DB: {self.db_path}")
        else:
            self._initialize_db()

    @property
    def path_to_db(self):
        """Filesystem path of the SQLite database."""
        return self.db_path

    @staticmethod
    def _now_stamp() -> str:
        """Current time as ISO-8601 text ("YYYY-MM-DD HH:MM:SS.ffffff").

        Timestamps are stored and compared as TEXT; this format sorts
        chronologically under plain string comparison, and matches what the
        (now deprecated, Python 3.12+) sqlite3 datetime adapter used to write.
        """
        return datetime.now().isoformat(sep=" ")

    def _initialize_db(self):
        """Initialize the SQLite database with the required schema."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Drop the old table if it exists (schema migration).
            cursor.execute("DROP TABLE IF EXISTS model_cache")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS model_cache (
                    timestamp DATETIME NOT NULL,
                    model_name TEXT NOT NULL,
                    service_name TEXT NOT NULL,
                    UNIQUE(model_name, service_name)
                )
                """
            )
            conn.commit()

    def _prune_old_entries(self, conn: sqlite3.Connection):
        """Delete oldest entries when MAX_ROWS is exceeded."""
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM model_cache")
        count = cursor.fetchone()[0]

        if count > self.MAX_ROWS:
            cursor.execute(
                """
                DELETE FROM model_cache
                WHERE rowid IN (
                    SELECT rowid
                    FROM model_cache
                    ORDER BY timestamp ASC
                    LIMIT ?
                )
                """,
                (count - self.MAX_ROWS,),
            )
            conn.commit()

    @classmethod
    def example_models(cls) -> List["LanguageModelInfo"]:
        """Two sample entries, handy for demos and tests."""
        return [
            LanguageModelInfo(
                "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "deep_infra"
            ),
            LanguageModelInfo("openai/gpt-4", "openai"),
        ]

    def add_models_to_cache(self, models_data: List["LanguageModelInfo"]):
        """Add new models to the cache, updating timestamps for existing entries."""
        # Write the timestamp as explicit ISO text; relying on sqlite3's
        # implicit datetime adapter is deprecated since Python 3.12.
        current_time = self._now_stamp()

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            for model in models_data:
                cursor.execute(
                    """
                    INSERT INTO model_cache (timestamp, model_name, service_name)
                    VALUES (?, ?, ?)
                    ON CONFLICT(model_name, service_name)
                    DO UPDATE SET timestamp = excluded.timestamp
                    """,
                    (current_time, model.model_name, model.service_name),
                )

            # self._prune_old_entries(conn)  # pruning currently disabled
            conn.commit()

    def reset_cache(self):
        """Clear all entries from the cache."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM model_cache")
            conn.commit()

    @property
    def num_cache_entries(self):
        """Return the number of entries in the cache."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM model_cache")
            count = cursor.fetchone()[0]
        return count

    def models(
        self,
        service: Optional["InferenceServiceLiteral"] = None,
    ) -> Union[None, "AvailableModels"]:
        """Return the available models within the cache validity period.

        :param service: if given, restrict results to that service.
        :return: AvailableModels of fresh entries, or None when none exist.
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # BUG FIX: the cutoff used to be a Unix-epoch float, but the
            # stored timestamps are TEXT; SQLite orders every REAL below
            # every TEXT, so "timestamp > ?" was always true and entries
            # never expired. Compare ISO text against ISO text instead.
            cutoff = (
                datetime.now() - timedelta(hours=self.cache_validity_hours)
            ).isoformat(sep=" ")

            if self.verbose:
                print(f"Fetching all with timestamp greater than {cutoff}")

            # BUG FIX: DISTINCT dropped — SQLite rejects ORDER BY on a column
            # that is not in a DISTINCT result set, and the UNIQUE constraint
            # on (model_name, service_name) already guarantees distinct rows.
            cursor.execute(
                """
                SELECT model_name, service_name
                FROM model_cache
                WHERE timestamp > ?
                ORDER BY timestamp DESC
                """,
                (cutoff,),
            )

            results = cursor.fetchall()
            if not results:
                if self.verbose:
                    print("No results found in cache DB.")
                return None

            matching_models = [
                LanguageModelInfo(model_name=row[0], service_name=row[1])
                for row in results
            ]

            if self.verbose:
                print(f"Found {len(matching_models)} models in cache DB.")
            if service:
                matching_models = [
                    model for model in matching_models if model.service_name == service
                ]

            return AvailableModels(matching_models)
173
+
174
+
175
+ if __name__ == "__main__":
176
+ import doctest
177
+
178
+ doctest.testmod()
179
+ # cache_handler = AvailableModelCacheHandler(verbose=True)
180
+ # models_data = cache_handler.example_models()
181
+ # cache_handler.add_models_to_cache(models_data)
182
+ # print(cache_handler.models())
183
+ # cache_handler.clear_cache()
184
+ # print(cache_handler.models())
@@ -0,0 +1,215 @@
1
+ from typing import Any, List, Tuple, Optional, Dict, TYPE_CHECKING, Union, Generator
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ from collections import UserList
4
+
5
+ from edsl.inference_services.ServiceAvailability import ServiceAvailability
6
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
+ from edsl.inference_services.data_structures import ModelNamesList
8
+ from edsl.enums import InferenceServiceLiteral
9
+
10
+ from edsl.inference_services.data_structures import LanguageModelInfo
11
+ from edsl.inference_services.AvailableModelCacheHandler import (
12
+ AvailableModelCacheHandler,
13
+ )
14
+
15
+
16
+ from edsl.inference_services.data_structures import AvailableModels
17
+
18
+
19
class AvailableModelFetcher:
    """Fetches available models from the various services.

    Results are cached on disk (SQLite, via AvailableModelCacheHandler)
    unless ``use_cache=False``.
    """

    service_availability = ServiceAvailability()
    CACHE_VALIDITY_HOURS = 48  # Cache validity period in hours

    def __init__(
        self,
        services: List["InferenceServiceABC"],
        added_models: Dict[str, List[str]],
        verbose: bool = False,
        use_cache: bool = True,
    ):
        """
        :param services: service instances to query.
        :param added_models: extra model names to append, keyed by service name.
        :param verbose: print diagnostic messages.
        :param use_cache: when False, every lookup hits the services directly.
        """
        self.services = services
        self.added_models = added_models
        # Map each service's string identifier to its instance for lookups.
        self._service_map = {
            service._inference_service_: service for service in services
        }
        self.verbose = verbose
        self.cache_handler = AvailableModelCacheHandler() if use_cache else None

    @property
    def num_cache_entries(self):
        """Row count of the underlying cache (requires use_cache=True)."""
        return self.cache_handler.num_cache_entries

    @property
    def path_to_db(self):
        """Path to the cache database (requires use_cache=True)."""
        return self.cache_handler.path_to_db

    def reset_cache(self):
        """Empty the cache, if one is in use."""
        if self.cache_handler:
            self.cache_handler.reset_cache()

    def available(
        self,
        service: Optional["InferenceServiceABC"] = None,
        force_refresh: bool = False,
    ) -> List["LanguageModelInfo"]:
        """
        Get available models from all services, using cached data when available.

        :param service: Optional[InferenceServiceABC] - If specified, only fetch models for this service.
        :param force_refresh: bypass the cache and query the services directly.

        >>> from edsl.inference_services.OpenAIService import OpenAIService
        >>> af = AvailableModelFetcher([OpenAIService()], {})
        >>> af.available(service="openai")
        [LanguageModelInfo(model_name='...', service_name='openai'), ...]

        Returns a list of [model, service_name, index] entries.
        """
        if service:  # they passed a specific service
            matching_models, _ = self.get_available_models_by_service(
                service=service, force_refresh=force_refresh
            )
            return matching_models

        # No specific service: fetch them all.
        # BUG FIX: force_refresh was previously dropped on this path.
        return self._get_all_models(force_refresh=force_refresh)

    def get_available_models_by_service(
        self,
        service: Union["InferenceServiceABC", "InferenceServiceLiteral"],
        force_refresh: bool = False,
    ) -> Tuple["AvailableModels", "InferenceServiceLiteral"]:
        """Get models for a single service, preferring cached data.

        :param service: InferenceServiceABC - e.g., OpenAIService or "openai"
        :return: Tuple[List[LanguageModelInfo], InferenceServiceLiteral]
        """
        if isinstance(service, str):
            service = self._fetch_service_by_service_name(service)

        # BUG FIX: with use_cache=False there is no cache handler; the old
        # code dereferenced it unconditionally and raised AttributeError.
        models_from_cache = None
        if not force_refresh and self.cache_handler is not None:
            models_from_cache = self.cache_handler.models(
                service=service._inference_service_
            )
            if self.verbose:
                print(
                    "Searching cache for models with service name:",
                    service._inference_service_,
                )
                print("Got models from cache:", models_from_cache)

        if models_from_cache:
            return models_from_cache, service._inference_service_
        return self.get_available_models_by_service_fresh(service)

    def get_available_models_by_service_fresh(
        self, service: Union["InferenceServiceABC", "InferenceServiceLiteral"]
    ) -> Tuple["AvailableModels", "InferenceServiceLiteral"]:
        """Get models for a single service. This method always fetches fresh data.

        :param service: InferenceServiceABC - e.g., OpenAIService or "openai"
        :return: Tuple[List[LanguageModelInfo], InferenceServiceLiteral]
        """
        if isinstance(service, str):
            service = self._fetch_service_by_service_name(service)

        service_models: "ModelNamesList" = (
            self.service_availability.get_service_available(service, warn=False)
        )
        service_name = service._inference_service_

        if not service_models:
            import warnings

            warnings.warn(f"No models found for service {service_name}")
            return [], service_name

        models_list = AvailableModels(
            [
                LanguageModelInfo(
                    model_name=model_name,
                    service_name=service_name,
                )
                for model_name in service_models
            ]
        )
        # BUG FIX: only update the cache when one exists (use_cache=True).
        if self.cache_handler is not None:
            self.cache_handler.add_models_to_cache(models_list)
        return models_list, service_name

    def _fetch_service_by_service_name(
        self, service_name: "InferenceServiceLiteral"
    ) -> "InferenceServiceABC":
        """The service name is the _inference_service_ attribute of the service."""
        if service_name in self._service_map:
            return self._service_map[service_name]
        raise ValueError(f"Service {service_name} not found")

    def _get_all_models(self, force_refresh=False) -> List["LanguageModelInfo"]:
        """Query every service concurrently and merge the results."""
        if not self.services:  # ThreadPoolExecutor rejects max_workers=0
            return AvailableModels([])

        all_models = []
        with ThreadPoolExecutor(max_workers=min(len(self.services), 10)) as executor:
            future_to_service = {
                executor.submit(
                    self.get_available_models_by_service, service, force_refresh
                ): service
                for service in self.services
            }

            for future in as_completed(future_to_service):
                # BUG FIX: resolve the service from the map up front; the old
                # code referenced a name that was unbound whenever
                # future.result() raised, turning any failure into a NameError.
                service_name = future_to_service[future]._inference_service_
                try:
                    models, service_name = future.result()
                    all_models.extend(models)

                    # Add any additional models for this service
                    for model in self.added_models.get(service_name, []):
                        all_models.append(
                            LanguageModelInfo(
                                model_name=model, service_name=service_name
                            )
                        )

                except Exception as exc:
                    print(f"Service query failed for service {service_name}: {exc}")
                    continue

        return AvailableModels(all_models)
190
+
191
+
192
def main():
    """Manual smoke test: fetch and print every model OpenAI reports."""
    from edsl.inference_services.OpenAIService import OpenAIService

    verbose_fetcher = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
    # print(verbose_fetcher.available(service="openai"))

    quiet_fetcher = AvailableModelFetcher([OpenAIService()], {})
    all_models = quiet_fetcher._get_all_models(force_refresh=True)
    print(all_models)
201
+
202
+
203
+ if __name__ == "__main__":
204
+ import doctest
205
+
206
+ doctest.testmod(optionflags=doctest.ELLIPSIS)
207
+ # main()
208
+
209
+ # from edsl.inference_services.OpenAIService import OpenAIService
210
+
211
+ # af = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
212
+ # # print(af.available(service="openai"))
213
+
214
+ # all_models = AvailableModelFetcher([OpenAIService()], {})._get_all_models()
215
+ # print(all_models)
@@ -1,120 +1,118 @@
1
- import os
2
- from typing import Any, List, Optional
3
- import re
4
- import boto3
5
- from botocore.exceptions import ClientError
6
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
- from edsl.language_models.LanguageModel import LanguageModel
8
- import json
9
- from edsl.utilities.utilities import fix_partial_correct_response
10
-
11
-
12
- class AwsBedrockService(InferenceServiceABC):
13
- """AWS Bedrock service class."""
14
-
15
- _inference_service_ = "bedrock"
16
- _env_key_name_ = (
17
- "AWS_ACCESS_KEY_ID" # or any other environment key for AWS credentials
18
- )
19
- key_sequence = ["output", "message", "content", 0, "text"]
20
- input_token_name = "inputTokens"
21
- output_token_name = "outputTokens"
22
- usage_sequence = ["usage"]
23
- model_exclude_list = [
24
- "ai21.j2-grande-instruct",
25
- "ai21.j2-jumbo-instruct",
26
- "ai21.j2-mid",
27
- "ai21.j2-mid-v1",
28
- "ai21.j2-ultra",
29
- "ai21.j2-ultra-v1",
30
- ]
31
- _models_list_cache: List[str] = []
32
-
33
- @classmethod
34
- def available(cls):
35
- """Fetch available models from AWS Bedrock."""
36
-
37
- region = os.getenv("AWS_REGION", "us-east-1")
38
-
39
- if not cls._models_list_cache:
40
- client = boto3.client("bedrock", region_name=region)
41
- all_models_ids = [
42
- x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
43
- ]
44
- else:
45
- all_models_ids = cls._models_list_cache
46
-
47
- return [m for m in all_models_ids if m not in cls.model_exclude_list]
48
-
49
- @classmethod
50
- def create_model(
51
- cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
52
- ) -> LanguageModel:
53
- if model_class_name is None:
54
- model_class_name = cls.to_class_name(model_name)
55
-
56
- class LLM(LanguageModel):
57
- """
58
- Child class of LanguageModel for interacting with AWS Bedrock models.
59
- """
60
-
61
- key_sequence = cls.key_sequence
62
- usage_sequence = cls.usage_sequence
63
- _inference_service_ = cls._inference_service_
64
- _model_ = model_name
65
- _parameters_ = {
66
- "temperature": 0.5,
67
- "max_tokens": 512,
68
- "top_p": 0.9,
69
- }
70
- input_token_name = cls.input_token_name
71
- output_token_name = cls.output_token_name
72
- _rpm = cls.get_rpm(cls)
73
- _tpm = cls.get_tpm(cls)
74
-
75
- async def async_execute_model_call(
76
- self,
77
- user_prompt: str,
78
- system_prompt: str = "",
79
- files_list: Optional[List["FileStore"]] = None,
80
- ) -> dict[str, Any]:
81
- """Calls the AWS Bedrock API and returns the API response."""
82
-
83
- api_token = (
84
- self.api_token
85
- ) # call to check the if env variables are set.
86
-
87
- region = os.getenv("AWS_REGION", "us-east-1")
88
- client = boto3.client("bedrock-runtime", region_name=region)
89
-
90
- conversation = [
91
- {
92
- "role": "user",
93
- "content": [{"text": user_prompt}],
94
- }
95
- ]
96
- system = [
97
- {
98
- "text": system_prompt,
99
- }
100
- ]
101
- try:
102
- response = client.converse(
103
- modelId=self._model_,
104
- messages=conversation,
105
- inferenceConfig={
106
- "maxTokens": self.max_tokens,
107
- "temperature": self.temperature,
108
- "topP": self.top_p,
109
- },
110
- # system=system,
111
- additionalModelRequestFields={},
112
- )
113
- return response
114
- except (ClientError, Exception) as e:
115
- print(e)
116
- return {"error": str(e)}
117
-
118
- LLM.__name__ = model_class_name
119
-
120
- return LLM
1
+ import os
2
+ from typing import Any, List, Optional
3
+ import re
4
+ import boto3
5
+ from botocore.exceptions import ClientError
6
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
+ from edsl.language_models.LanguageModel import LanguageModel
8
+ import json
9
+ from edsl.utilities.utilities import fix_partial_correct_response
10
+
11
+
12
class AwsBedrockService(InferenceServiceABC):
    """AWS Bedrock service class.

    Builds LanguageModel subclasses that call Bedrock's `converse` API.
    """

    _inference_service_ = "bedrock"
    _env_key_name_ = (
        "AWS_ACCESS_KEY_ID"  # or any other environment key for AWS credentials
    )
    # Path into the converse() response dict where the generated text lives.
    key_sequence = ["output", "message", "content", 0, "text"]
    input_token_name = "inputTokens"
    output_token_name = "outputTokens"
    usage_sequence = ["usage"]
    # Legacy Jurassic-2 models that should not be offered to callers.
    model_exclude_list = [
        "ai21.j2-grande-instruct",
        "ai21.j2-jumbo-instruct",
        "ai21.j2-mid",
        "ai21.j2-mid-v1",
        "ai21.j2-ultra",
        "ai21.j2-ultra-v1",
    ]
    # Class-level cache of model ids so repeated available() calls skip AWS.
    _models_list_cache: List[str] = []

    @classmethod
    def available(cls):
        """Fetch available model ids from AWS Bedrock, excluding legacy models.

        Results are cached on the class after the first successful fetch.
        """

        region = os.getenv("AWS_REGION", "us-east-1")

        if not cls._models_list_cache:
            client = boto3.client("bedrock", region_name=region)
            all_models_ids = [
                x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
            ]
            # Populate the cache so subsequent calls avoid the AWS round-trip.
            # (Previously the cache was checked but never written, so it was
            # permanently empty and every call hit the API.)
            cls._models_list_cache = all_models_ids
        else:
            all_models_ids = cls._models_list_cache

        return [m for m in all_models_ids if m not in cls.model_exclude_list]

    @classmethod
    def create_model(
        cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
    ) -> LanguageModel:
        """Return a LanguageModel subclass bound to the given Bedrock model id."""
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            """
            Child class of LanguageModel for interacting with AWS Bedrock models.
            """

            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            _inference_service_ = cls._inference_service_
            _model_ = model_name
            _parameters_ = {
                "temperature": 0.5,
                "max_tokens": 512,
                "top_p": 0.9,
            }
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional[List["FileStore"]] = None,
            ) -> dict[str, Any]:
                """Calls the AWS Bedrock API and returns the API response."""

                api_token = (
                    self.api_token
                )  # call to check the if env variables are set.

                region = os.getenv("AWS_REGION", "us-east-1")
                client = boto3.client("bedrock-runtime", region_name=region)

                conversation = [
                    {
                        "role": "user",
                        "content": [{"text": user_prompt}],
                    }
                ]
                # NOTE(review): `system` is built but never sent — the
                # `system=system` kwarg below is commented out, so the system
                # prompt is currently ignored by Bedrock. Confirm intent.
                system = [
                    {
                        "text": system_prompt,
                    }
                ]
                try:
                    response = client.converse(
                        modelId=self._model_,
                        messages=conversation,
                        inferenceConfig={
                            "maxTokens": self.max_tokens,
                            "temperature": self.temperature,
                            "topP": self.top_p,
                        },
                        # system=system,
                        additionalModelRequestFields={},
                    )
                    return response
                except Exception as e:
                    # ClientError is a subclass of Exception, so the former
                    # `(ClientError, Exception)` tuple was redundant.
                    print(e)
                    return {"error": str(e)}

        LLM.__name__ = model_class_name

        return LLM