edsl 0.1.39.dev3__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (344)
  1. edsl/Base.py +413 -332
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +57 -49
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +1071 -867
  7. edsl/agents/AgentList.py +551 -413
  8. edsl/agents/Invigilator.py +284 -233
  9. edsl/agents/InvigilatorBase.py +257 -270
  10. edsl/agents/PromptConstructor.py +272 -354
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/__init__.py +2 -3
  14. edsl/agents/descriptors.py +99 -99
  15. edsl/agents/prompt_helpers.py +129 -129
  16. edsl/agents/question_option_processor.py +172 -0
  17. edsl/auto/AutoStudy.py +130 -117
  18. edsl/auto/StageBase.py +243 -230
  19. edsl/auto/StageGenerateSurvey.py +178 -178
  20. edsl/auto/StageLabelQuestions.py +125 -125
  21. edsl/auto/StagePersona.py +61 -61
  22. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  23. edsl/auto/StagePersonaDimensionValues.py +74 -74
  24. edsl/auto/StagePersonaDimensions.py +69 -69
  25. edsl/auto/StageQuestions.py +74 -73
  26. edsl/auto/SurveyCreatorPipeline.py +21 -21
  27. edsl/auto/utilities.py +218 -224
  28. edsl/base/Base.py +279 -279
  29. edsl/config.py +177 -157
  30. edsl/conversation/Conversation.py +290 -290
  31. edsl/conversation/car_buying.py +59 -58
  32. edsl/conversation/chips.py +95 -95
  33. edsl/conversation/mug_negotiation.py +81 -81
  34. edsl/conversation/next_speaker_utilities.py +93 -93
  35. edsl/coop/CoopFunctionsMixin.py +15 -0
  36. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  37. edsl/coop/PriceFetcher.py +54 -54
  38. edsl/coop/__init__.py +2 -2
  39. edsl/coop/coop.py +1106 -1028
  40. edsl/coop/utils.py +131 -131
  41. edsl/data/Cache.py +573 -555
  42. edsl/data/CacheEntry.py +230 -233
  43. edsl/data/CacheHandler.py +168 -149
  44. edsl/data/RemoteCacheSync.py +186 -78
  45. edsl/data/SQLiteDict.py +292 -292
  46. edsl/data/__init__.py +5 -4
  47. edsl/data/hack.py +10 -0
  48. edsl/data/orm.py +10 -10
  49. edsl/data_transfer_models.py +74 -73
  50. edsl/enums.py +202 -175
  51. edsl/exceptions/BaseException.py +21 -21
  52. edsl/exceptions/__init__.py +54 -54
  53. edsl/exceptions/agents.py +54 -42
  54. edsl/exceptions/cache.py +5 -5
  55. edsl/exceptions/configuration.py +16 -16
  56. edsl/exceptions/coop.py +10 -10
  57. edsl/exceptions/data.py +14 -14
  58. edsl/exceptions/general.py +34 -34
  59. edsl/exceptions/inference_services.py +5 -0
  60. edsl/exceptions/jobs.py +33 -33
  61. edsl/exceptions/language_models.py +63 -63
  62. edsl/exceptions/prompts.py +15 -15
  63. edsl/exceptions/questions.py +109 -91
  64. edsl/exceptions/results.py +29 -29
  65. edsl/exceptions/scenarios.py +29 -22
  66. edsl/exceptions/surveys.py +37 -37
  67. edsl/inference_services/AnthropicService.py +106 -87
  68. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  69. edsl/inference_services/AvailableModelFetcher.py +215 -0
  70. edsl/inference_services/AwsBedrock.py +118 -120
  71. edsl/inference_services/AzureAI.py +215 -217
  72. edsl/inference_services/DeepInfraService.py +18 -18
  73. edsl/inference_services/GoogleService.py +143 -148
  74. edsl/inference_services/GroqService.py +20 -20
  75. edsl/inference_services/InferenceServiceABC.py +80 -147
  76. edsl/inference_services/InferenceServicesCollection.py +138 -97
  77. edsl/inference_services/MistralAIService.py +120 -123
  78. edsl/inference_services/OllamaService.py +18 -18
  79. edsl/inference_services/OpenAIService.py +236 -224
  80. edsl/inference_services/PerplexityService.py +160 -163
  81. edsl/inference_services/ServiceAvailability.py +135 -0
  82. edsl/inference_services/TestService.py +90 -89
  83. edsl/inference_services/TogetherAIService.py +172 -170
  84. edsl/inference_services/data_structures.py +134 -0
  85. edsl/inference_services/models_available_cache.py +118 -118
  86. edsl/inference_services/rate_limits_cache.py +25 -25
  87. edsl/inference_services/registry.py +41 -41
  88. edsl/inference_services/write_available.py +10 -10
  89. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  90. edsl/jobs/Answers.py +43 -56
  91. edsl/jobs/FetchInvigilator.py +47 -0
  92. edsl/jobs/InterviewTaskManager.py +98 -0
  93. edsl/jobs/InterviewsConstructor.py +50 -0
  94. edsl/jobs/Jobs.py +823 -898
  95. edsl/jobs/JobsChecks.py +172 -147
  96. edsl/jobs/JobsComponentConstructor.py +189 -0
  97. edsl/jobs/JobsPrompts.py +270 -268
  98. edsl/jobs/JobsRemoteInferenceHandler.py +311 -239
  99. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  100. edsl/jobs/RequestTokenEstimator.py +30 -0
  101. edsl/jobs/__init__.py +1 -1
  102. edsl/jobs/async_interview_runner.py +138 -0
  103. edsl/jobs/buckets/BucketCollection.py +104 -63
  104. edsl/jobs/buckets/ModelBuckets.py +65 -65
  105. edsl/jobs/buckets/TokenBucket.py +283 -251
  106. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  107. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  108. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  109. edsl/jobs/data_structures.py +120 -0
  110. edsl/jobs/decorators.py +35 -0
  111. edsl/jobs/interviews/Interview.py +396 -661
  112. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  113. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  114. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  115. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  116. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  117. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  118. edsl/jobs/interviews/ReportErrors.py +66 -66
  119. edsl/jobs/interviews/interview_status_enum.py +9 -9
  120. edsl/jobs/jobs_status_enums.py +9 -0
  121. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  122. edsl/jobs/results_exceptions_handler.py +98 -0
  123. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -466
  124. edsl/jobs/runners/JobsRunnerStatus.py +297 -330
  125. edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
  126. edsl/jobs/tasks/TaskCreators.py +64 -64
  127. edsl/jobs/tasks/TaskHistory.py +470 -450
  128. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  129. edsl/jobs/tasks/task_status_enum.py +161 -163
  130. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  131. edsl/jobs/tokens/TokenUsage.py +34 -34
  132. edsl/language_models/ComputeCost.py +63 -0
  133. edsl/language_models/LanguageModel.py +626 -668
  134. edsl/language_models/ModelList.py +164 -155
  135. edsl/language_models/PriceManager.py +127 -0
  136. edsl/language_models/RawResponseHandler.py +106 -0
  137. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  138. edsl/language_models/ServiceDataSources.py +0 -0
  139. edsl/language_models/__init__.py +2 -3
  140. edsl/language_models/fake_openai_call.py +15 -15
  141. edsl/language_models/fake_openai_service.py +61 -61
  142. edsl/language_models/key_management/KeyLookup.py +63 -0
  143. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  144. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  145. edsl/language_models/key_management/__init__.py +0 -0
  146. edsl/language_models/key_management/models.py +131 -0
  147. edsl/language_models/model.py +256 -0
  148. edsl/language_models/repair.py +156 -156
  149. edsl/language_models/utilities.py +65 -64
  150. edsl/notebooks/Notebook.py +263 -258
  151. edsl/notebooks/NotebookToLaTeX.py +142 -0
  152. edsl/notebooks/__init__.py +1 -1
  153. edsl/prompts/Prompt.py +352 -362
  154. edsl/prompts/__init__.py +2 -2
  155. edsl/questions/ExceptionExplainer.py +77 -0
  156. edsl/questions/HTMLQuestion.py +103 -0
  157. edsl/questions/QuestionBase.py +518 -664
  158. edsl/questions/QuestionBasePromptsMixin.py +221 -217
  159. edsl/questions/QuestionBudget.py +227 -227
  160. edsl/questions/QuestionCheckBox.py +359 -359
  161. edsl/questions/QuestionExtract.py +180 -182
  162. edsl/questions/QuestionFreeText.py +113 -114
  163. edsl/questions/QuestionFunctional.py +166 -166
  164. edsl/questions/QuestionList.py +223 -231
  165. edsl/questions/QuestionMatrix.py +265 -0
  166. edsl/questions/QuestionMultipleChoice.py +330 -286
  167. edsl/questions/QuestionNumerical.py +151 -153
  168. edsl/questions/QuestionRank.py +314 -324
  169. edsl/questions/Quick.py +41 -41
  170. edsl/questions/SimpleAskMixin.py +74 -73
  171. edsl/questions/__init__.py +27 -26
  172. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
  173. edsl/questions/compose_questions.py +98 -98
  174. edsl/questions/data_structures.py +20 -0
  175. edsl/questions/decorators.py +21 -21
  176. edsl/questions/derived/QuestionLikertFive.py +76 -76
  177. edsl/questions/derived/QuestionLinearScale.py +90 -87
  178. edsl/questions/derived/QuestionTopK.py +93 -93
  179. edsl/questions/derived/QuestionYesNo.py +82 -82
  180. edsl/questions/descriptors.py +427 -413
  181. edsl/questions/loop_processor.py +149 -0
  182. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  183. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  184. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  185. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  186. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  187. edsl/questions/prompt_templates/question_list.jinja +17 -17
  188. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  189. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  190. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
  191. edsl/questions/question_registry.py +177 -177
  192. edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
  193. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
  194. edsl/questions/response_validator_factory.py +34 -0
  195. edsl/questions/settings.py +12 -12
  196. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  197. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  198. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  199. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  200. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  201. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  202. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  203. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  204. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  205. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  206. edsl/questions/templates/list/question_presentation.jinja +5 -5
  207. edsl/questions/templates/matrix/__init__.py +1 -0
  208. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  209. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  210. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  211. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  212. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  213. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  214. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  215. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  216. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  217. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  218. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  219. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  220. edsl/results/CSSParameterizer.py +108 -108
  221. edsl/results/Dataset.py +587 -424
  222. edsl/results/DatasetExportMixin.py +594 -731
  223. edsl/results/DatasetTree.py +295 -275
  224. edsl/results/MarkdownToDocx.py +122 -0
  225. edsl/results/MarkdownToPDF.py +111 -0
  226. edsl/results/Result.py +557 -465
  227. edsl/results/Results.py +1183 -1165
  228. edsl/results/ResultsExportMixin.py +45 -43
  229. edsl/results/ResultsGGMixin.py +121 -121
  230. edsl/results/TableDisplay.py +125 -198
  231. edsl/results/TextEditor.py +50 -0
  232. edsl/results/__init__.py +2 -2
  233. edsl/results/file_exports.py +252 -0
  234. edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
  235. edsl/results/{Selector.py → results_selector.py} +145 -135
  236. edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
  237. edsl/results/smart_objects.py +96 -0
  238. edsl/results/table_data_class.py +12 -0
  239. edsl/results/table_display.css +77 -77
  240. edsl/results/table_renderers.py +118 -0
  241. edsl/results/tree_explore.py +115 -115
  242. edsl/scenarios/ConstructDownloadLink.py +109 -0
  243. edsl/scenarios/DocumentChunker.py +102 -0
  244. edsl/scenarios/DocxScenario.py +16 -0
  245. edsl/scenarios/FileStore.py +511 -632
  246. edsl/scenarios/PdfExtractor.py +40 -0
  247. edsl/scenarios/Scenario.py +498 -601
  248. edsl/scenarios/ScenarioHtmlMixin.py +65 -64
  249. edsl/scenarios/ScenarioList.py +1458 -1287
  250. edsl/scenarios/ScenarioListExportMixin.py +45 -52
  251. edsl/scenarios/ScenarioListPdfMixin.py +239 -261
  252. edsl/scenarios/__init__.py +3 -4
  253. edsl/scenarios/directory_scanner.py +96 -0
  254. edsl/scenarios/file_methods.py +85 -0
  255. edsl/scenarios/handlers/__init__.py +13 -0
  256. edsl/scenarios/handlers/csv.py +38 -0
  257. edsl/scenarios/handlers/docx.py +76 -0
  258. edsl/scenarios/handlers/html.py +37 -0
  259. edsl/scenarios/handlers/json.py +111 -0
  260. edsl/scenarios/handlers/latex.py +5 -0
  261. edsl/scenarios/handlers/md.py +51 -0
  262. edsl/scenarios/handlers/pdf.py +68 -0
  263. edsl/scenarios/handlers/png.py +39 -0
  264. edsl/scenarios/handlers/pptx.py +105 -0
  265. edsl/scenarios/handlers/py.py +294 -0
  266. edsl/scenarios/handlers/sql.py +313 -0
  267. edsl/scenarios/handlers/sqlite.py +149 -0
  268. edsl/scenarios/handlers/txt.py +33 -0
  269. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +131 -127
  270. edsl/scenarios/scenario_selector.py +156 -0
  271. edsl/shared.py +1 -1
  272. edsl/study/ObjectEntry.py +173 -173
  273. edsl/study/ProofOfWork.py +113 -113
  274. edsl/study/SnapShot.py +80 -80
  275. edsl/study/Study.py +521 -528
  276. edsl/study/__init__.py +4 -4
  277. edsl/surveys/ConstructDAG.py +92 -0
  278. edsl/surveys/DAG.py +148 -148
  279. edsl/surveys/EditSurvey.py +221 -0
  280. edsl/surveys/InstructionHandler.py +100 -0
  281. edsl/surveys/Memory.py +31 -31
  282. edsl/surveys/MemoryManagement.py +72 -0
  283. edsl/surveys/MemoryPlan.py +244 -244
  284. edsl/surveys/Rule.py +327 -326
  285. edsl/surveys/RuleCollection.py +385 -387
  286. edsl/surveys/RuleManager.py +172 -0
  287. edsl/surveys/Simulator.py +75 -0
  288. edsl/surveys/Survey.py +1280 -1801
  289. edsl/surveys/SurveyCSS.py +273 -261
  290. edsl/surveys/SurveyExportMixin.py +259 -259
  291. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -179
  292. edsl/surveys/SurveyQualtricsImport.py +284 -284
  293. edsl/surveys/SurveyToApp.py +141 -0
  294. edsl/surveys/__init__.py +5 -3
  295. edsl/surveys/base.py +53 -53
  296. edsl/surveys/descriptors.py +60 -56
  297. edsl/surveys/instructions/ChangeInstruction.py +48 -49
  298. edsl/surveys/instructions/Instruction.py +56 -65
  299. edsl/surveys/instructions/InstructionCollection.py +82 -77
  300. edsl/templates/error_reporting/base.html +23 -23
  301. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  302. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  303. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  304. edsl/templates/error_reporting/interview_details.html +115 -115
  305. edsl/templates/error_reporting/interviews.html +19 -19
  306. edsl/templates/error_reporting/overview.html +4 -4
  307. edsl/templates/error_reporting/performance_plot.html +1 -1
  308. edsl/templates/error_reporting/report.css +73 -73
  309. edsl/templates/error_reporting/report.html +117 -117
  310. edsl/templates/error_reporting/report.js +25 -25
  311. edsl/test_h +1 -0
  312. edsl/tools/__init__.py +1 -1
  313. edsl/tools/clusters.py +192 -192
  314. edsl/tools/embeddings.py +27 -27
  315. edsl/tools/embeddings_plotting.py +118 -118
  316. edsl/tools/plotting.py +112 -112
  317. edsl/tools/summarize.py +18 -18
  318. edsl/utilities/PrettyList.py +56 -0
  319. edsl/utilities/SystemInfo.py +28 -28
  320. edsl/utilities/__init__.py +22 -22
  321. edsl/utilities/ast_utilities.py +25 -25
  322. edsl/utilities/data/Registry.py +6 -6
  323. edsl/utilities/data/__init__.py +1 -1
  324. edsl/utilities/data/scooter_results.json +1 -1
  325. edsl/utilities/decorators.py +77 -77
  326. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  327. edsl/utilities/gcp_bucket/example.py +50 -0
  328. edsl/utilities/interface.py +627 -627
  329. edsl/utilities/is_notebook.py +18 -0
  330. edsl/utilities/is_valid_variable_name.py +11 -0
  331. edsl/utilities/naming_utilities.py +263 -263
  332. edsl/utilities/remove_edsl_version.py +24 -0
  333. edsl/utilities/repair_functions.py +28 -28
  334. edsl/utilities/restricted_python.py +70 -70
  335. edsl/utilities/utilities.py +436 -424
  336. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +21 -21
  337. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +13 -11
  338. edsl-0.1.39.dev4.dist-info/RECORD +361 -0
  339. edsl/language_models/KeyLookup.py +0 -30
  340. edsl/language_models/registry.py +0 -190
  341. edsl/language_models/unused/ReplicateBase.py +0 -83
  342. edsl/results/ResultsDBMixin.py +0 -238
  343. edsl-0.1.39.dev3.dist-info/RECORD +0 -277
  344. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -0,0 +1,184 @@
1
+ from typing import List, Optional, get_args, Union
2
+ from pathlib import Path
3
+ import sqlite3
4
+ from datetime import datetime
5
+ import tempfile
6
+ from platformdirs import user_cache_dir
7
+ from dataclasses import dataclass
8
+ import os
9
+
10
+ from edsl.inference_services.data_structures import LanguageModelInfo, AvailableModels
11
+ from edsl.enums import InferenceServiceLiteral
12
+
13
+
14
class AvailableModelCacheHandler:
    """SQLite-backed cache of available (model_name, service_name) pairs.

    Each row carries a timestamp (stored as a POSIX epoch float) and is
    ignored on read once it is older than ``cache_validity_hours``.
    """

    MAX_ROWS = 1000  # row cap enforced by _prune_old_entries (currently disabled)
    CACHE_VALIDITY_HOURS = 48  # default freshness window, in hours

    def __init__(
        self,
        cache_validity_hours: int = 48,
        verbose: bool = False,
        testing_db_name: Optional[str] = None,
    ):
        """
        :param cache_validity_hours: how long cached entries stay fresh.
        :param verbose: print progress information.
        :param testing_db_name: when given, place the DB in a fresh temp
            directory (for tests) instead of the user cache directory.
        """
        self.cache_validity_hours = cache_validity_hours
        self.verbose = verbose

        if testing_db_name:
            self.cache_dir = Path(tempfile.mkdtemp())
            self.db_path = self.cache_dir / testing_db_name
        else:
            self.cache_dir = Path(user_cache_dir("edsl", "model_availability"))
            self.db_path = self.cache_dir / "available_models.db"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if os.path.exists(self.db_path):
            if self.verbose:
                print(f"Using existing cache DB: {self.db_path}")
        else:
            self._initialize_db()

    @property
    def path_to_db(self):
        """Path of the SQLite database file backing this cache."""
        return self.db_path

    def _initialize_db(self):
        """Initialize the SQLite database with the required schema."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Drop the old table if it exists (schema migration).
            cursor.execute("DROP TABLE IF EXISTS model_cache")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS model_cache (
                    timestamp DATETIME NOT NULL,
                    model_name TEXT NOT NULL,
                    service_name TEXT NOT NULL,
                    UNIQUE(model_name, service_name)
                )
                """
            )
            conn.commit()

    def _prune_old_entries(self, conn: sqlite3.Connection):
        """Delete oldest entries when MAX_ROWS is exceeded."""
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM model_cache")
        count = cursor.fetchone()[0]

        if count > self.MAX_ROWS:
            cursor.execute(
                """
                DELETE FROM model_cache
                WHERE rowid IN (
                    SELECT rowid
                    FROM model_cache
                    ORDER BY timestamp ASC
                    LIMIT ?
                )
                """,
                (count - self.MAX_ROWS,),
            )
            conn.commit()

    @classmethod
    def example_models(cls) -> List["LanguageModelInfo"]:
        """Return sample entries, useful for tests and demos."""
        return [
            LanguageModelInfo(
                "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "deep_infra"
            ),
            LanguageModelInfo("openai/gpt-4", "openai"),
        ]

    def add_models_to_cache(self, models_data: List["LanguageModelInfo"]):
        """Add new models to the cache, updating timestamps for existing entries."""
        # Store an epoch float so the freshness comparison in models() is
        # numeric-vs-numeric. (Previously a datetime object was stored, which
        # sqlite3 adapts to an ISO TEXT value; SQLite orders any TEXT above
        # any REAL, so `timestamp > <epoch cutoff>` was always true and
        # cached entries never expired.)
        current_time = datetime.now().timestamp()

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            for model in models_data:
                cursor.execute(
                    """
                    INSERT INTO model_cache (timestamp, model_name, service_name)
                    VALUES (?, ?, ?)
                    ON CONFLICT(model_name, service_name)
                    DO UPDATE SET timestamp = excluded.timestamp
                    """,
                    (current_time, model.model_name, model.service_name),
                )

            # self._prune_old_entries(conn)  # row cap deliberately disabled
            conn.commit()

    def reset_cache(self):
        """Clear all entries from the cache."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM model_cache")
            conn.commit()

    @property
    def num_cache_entries(self):
        """Return the number of entries in the cache."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM model_cache")
            count = cursor.fetchone()[0]
            return count

    def models(
        self,
        service: Optional["InferenceServiceLiteral"],
    ) -> Optional["AvailableModels"]:
        """Return cached models newer than the validity window, or None.

        :param service: if given, restrict results to that service's models.
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Epoch-float cutoff; rows at or below it are considered stale.
            valid_time = datetime.now().timestamp() - (self.cache_validity_hours * 3600)

            if self.verbose:
                print(f"Fetching all with timestamp greater than {valid_time}")

            cursor.execute(
                """
                SELECT DISTINCT model_name, service_name
                FROM model_cache
                WHERE timestamp > ?
                ORDER BY timestamp DESC
                """,
                (valid_time,),
            )

            results = cursor.fetchall()
            if not results:
                if self.verbose:
                    print("No results found in cache DB.")
                return None

            matching_models = [
                LanguageModelInfo(model_name=row[0], service_name=row[1])
                for row in results
            ]

            if self.verbose:
                print(f"Found {len(matching_models)} models in cache DB.")
            if service:
                matching_models = [
                    model for model in matching_models if model.service_name == service
                ]

            return AvailableModels(matching_models)
173
+
174
+
175
if __name__ == "__main__":
    # Execute the module's embedded doctests when run as a script.
    import doctest

    doctest.testmod()
@@ -0,0 +1,215 @@
1
+ from typing import Any, List, Tuple, Optional, Dict, TYPE_CHECKING, Union, Generator
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ from collections import UserList
4
+
5
+ from edsl.inference_services.ServiceAvailability import ServiceAvailability
6
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
+ from edsl.inference_services.data_structures import ModelNamesList
8
+ from edsl.enums import InferenceServiceLiteral
9
+
10
+ from edsl.inference_services.data_structures import LanguageModelInfo
11
+ from edsl.inference_services.AvailableModelCacheHandler import (
12
+ AvailableModelCacheHandler,
13
+ )
14
+
15
+
16
+ from edsl.inference_services.data_structures import AvailableModels
17
+
18
+
19
class AvailableModelFetcher:
    """Fetches available models from the various services, with SQLite caching."""

    # Shared availability checker used to query each service's API.
    service_availability = ServiceAvailability()
    CACHE_VALIDITY_HOURS = 48  # Cache validity period in hours

    def __init__(
        self,
        services: List["InferenceServiceABC"],
        added_models: Dict[str, List[str]],
        verbose: bool = False,
        use_cache: bool = True,
    ):
        """
        :param services: service instances to query for model availability.
        :param added_models: extra model names to report, keyed by service name.
        :param verbose: print progress information.
        :param use_cache: when False, no cache handler is created and every
            lookup queries the services directly.
        """
        self.services = services
        self.added_models = added_models
        # Map each service's literal name (its _inference_service_) to the instance.
        self._service_map = {
            service._inference_service_: service for service in services
        }
        self.verbose = verbose
        self.cache_handler = AvailableModelCacheHandler() if use_cache else None

    @property
    def num_cache_entries(self):
        """Number of rows in the cache DB (requires use_cache=True)."""
        return self.cache_handler.num_cache_entries

    @property
    def path_to_db(self):
        """Path to the cache DB file (requires use_cache=True)."""
        return self.cache_handler.path_to_db

    def reset_cache(self):
        """Clear the model cache, if caching is enabled."""
        if self.cache_handler:
            self.cache_handler.reset_cache()

    def available(
        self,
        service: Optional["InferenceServiceABC"] = None,
        force_refresh: bool = False,
    ) -> List["LanguageModelInfo"]:
        """
        Get available models, using cached data when available.

        :param service: if specified (instance or literal name), only fetch
            models for that service.
        :param force_refresh: bypass the cache and query the services directly.

        >>> from edsl.inference_services.OpenAIService import OpenAIService
        >>> af = AvailableModelFetcher([OpenAIService()], {})
        >>> af.available(service="openai")
        [LanguageModelInfo(model_name='...', service_name='openai'), ...]

        Returns a list of LanguageModelInfo entries.
        """
        if service:  # they passed a specific service
            matching_models, _ = self.get_available_models_by_service(
                service=service, force_refresh=force_refresh
            )
            return matching_models

        # No service given: fetch from all registered services.
        # (force_refresh is forwarded so a full refresh actually refreshes.)
        return self._get_all_models(force_refresh=force_refresh)

    def get_available_models_by_service(
        self,
        service: Union["InferenceServiceABC", "InferenceServiceLiteral"],
        force_refresh: bool = False,
    ) -> Tuple["AvailableModels", "InferenceServiceLiteral"]:
        """Get models for a single service, preferring cached results.

        :param service: InferenceServiceABC instance or literal name, e.g.
            OpenAIService() or "openai"
        :return: Tuple[AvailableModels, InferenceServiceLiteral]
        """
        if isinstance(service, str):
            service = self._fetch_service_by_service_name(service)

        models_from_cache = None
        # Only consult the cache when caching is enabled (cache_handler may
        # be None when use_cache=False) and the caller didn't force a refresh.
        if self.cache_handler is not None and not force_refresh:
            models_from_cache = self.cache_handler.models(
                service=service._inference_service_
            )
            if self.verbose:
                print(
                    "Searching cache for models with service name:",
                    service._inference_service_,
                )
                print("Got models from cache:", models_from_cache)

        if models_from_cache:
            return models_from_cache, service._inference_service_
        return self.get_available_models_by_service_fresh(service)

    def get_available_models_by_service_fresh(
        self, service: Union["InferenceServiceABC", "InferenceServiceLiteral"]
    ) -> Tuple["AvailableModels", "InferenceServiceLiteral"]:
        """Get models for a single service. This method always fetches fresh data.

        :param service: InferenceServiceABC instance or literal name, e.g.
            OpenAIService() or "openai"
        :return: Tuple[AvailableModels, InferenceServiceLiteral]
        """
        if isinstance(service, str):
            service = self._fetch_service_by_service_name(service)

        service_models: "ModelNamesList" = (
            self.service_availability.get_service_available(service, warn=False)
        )
        service_name = service._inference_service_

        if not service_models:
            import warnings

            warnings.warn(f"No models found for service {service_name}")
            return [], service_name

        models_list = AvailableModels(
            [
                LanguageModelInfo(
                    model_name=model_name,
                    service_name=service_name,
                )
                for model_name in service_models
            ]
        )
        if self.cache_handler is not None:
            self.cache_handler.add_models_to_cache(models_list)  # update the cache
        return models_list, service_name

    def _fetch_service_by_service_name(
        self, service_name: "InferenceServiceLiteral"
    ) -> "InferenceServiceABC":
        """Resolve a service's literal name (its _inference_service_) to its instance.

        :raises ValueError: if no registered service has that name.
        """
        if service_name in self._service_map:
            return self._service_map[service_name]
        raise ValueError(f"Service {service_name} not found")

    def _get_all_models(self, force_refresh=False) -> "AvailableModels":
        """Query every registered service concurrently and merge the results."""
        # Guard: ThreadPoolExecutor rejects max_workers=0.
        if not self.services:
            return AvailableModels([])

        all_models = []
        with ThreadPoolExecutor(max_workers=min(len(self.services), 10)) as executor:
            future_to_service = {
                executor.submit(
                    self.get_available_models_by_service, service, force_refresh
                ): service
                for service in self.services
            }

            for future in as_completed(future_to_service):
                # Identify the service from the mapping so the error message
                # is correct even when the very first future raises.
                failed_service_name = future_to_service[future]._inference_service_
                try:
                    models, service_name = future.result()
                except Exception as exc:
                    print(
                        f"Service query failed for service {failed_service_name}: {exc}"
                    )
                    continue

                all_models.extend(models)
                # Add any additional (manually registered) models for this service.
                for model in self.added_models.get(service_name, []):
                    all_models.append(
                        LanguageModelInfo(model_name=model, service_name=service_name)
                    )

        return AvailableModels(all_models)
190
+
191
+
192
def main():
    """Demonstrate a forced all-services refresh against OpenAI."""
    from edsl.inference_services.OpenAIService import OpenAIService

    fetcher = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
    # print(fetcher.available(service="openai"))
    refreshed = AvailableModelFetcher([OpenAIService()], {})._get_all_models(
        force_refresh=True
    )
    print(refreshed)
+
202
+
203
+ if __name__ == "__main__":
204
+ import doctest
205
+
206
+ doctest.testmod(optionflags=doctest.ELLIPSIS)
207
+ # main()
208
+
209
+ # from edsl.inference_services.OpenAIService import OpenAIService
210
+
211
+ # af = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
212
+ # # print(af.available(service="openai"))
213
+
214
+ # all_models = AvailableModelFetcher([OpenAIService()], {})._get_all_models()
215
+ # print(all_models)
@@ -1,120 +1,118 @@
1
- import os
2
- from typing import Any, List, Optional
3
- import re
4
- import boto3
5
- from botocore.exceptions import ClientError
6
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
- from edsl.language_models.LanguageModel import LanguageModel
8
- import json
9
- from edsl.utilities.utilities import fix_partial_correct_response
10
-
11
-
12
- class AwsBedrockService(InferenceServiceABC):
13
- """AWS Bedrock service class."""
14
-
15
- _inference_service_ = "bedrock"
16
- _env_key_name_ = (
17
- "AWS_ACCESS_KEY_ID" # or any other environment key for AWS credentials
18
- )
19
- key_sequence = ["output", "message", "content", 0, "text"]
20
- input_token_name = "inputTokens"
21
- output_token_name = "outputTokens"
22
- usage_sequence = ["usage"]
23
- model_exclude_list = [
24
- "ai21.j2-grande-instruct",
25
- "ai21.j2-jumbo-instruct",
26
- "ai21.j2-mid",
27
- "ai21.j2-mid-v1",
28
- "ai21.j2-ultra",
29
- "ai21.j2-ultra-v1",
30
- ]
31
- _models_list_cache: List[str] = []
32
-
33
- @classmethod
34
- def available(cls):
35
- """Fetch available models from AWS Bedrock."""
36
-
37
- region = os.getenv("AWS_REGION", "us-east-1")
38
-
39
- if not cls._models_list_cache:
40
- client = boto3.client("bedrock", region_name=region)
41
- all_models_ids = [
42
- x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
43
- ]
44
- else:
45
- all_models_ids = cls._models_list_cache
46
-
47
- return [m for m in all_models_ids if m not in cls.model_exclude_list]
48
-
49
- @classmethod
50
- def create_model(
51
- cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
52
- ) -> LanguageModel:
53
- if model_class_name is None:
54
- model_class_name = cls.to_class_name(model_name)
55
-
56
- class LLM(LanguageModel):
57
- """
58
- Child class of LanguageModel for interacting with AWS Bedrock models.
59
- """
60
-
61
- key_sequence = cls.key_sequence
62
- usage_sequence = cls.usage_sequence
63
- _inference_service_ = cls._inference_service_
64
- _model_ = model_name
65
- _parameters_ = {
66
- "temperature": 0.5,
67
- "max_tokens": 512,
68
- "top_p": 0.9,
69
- }
70
- input_token_name = cls.input_token_name
71
- output_token_name = cls.output_token_name
72
- _rpm = cls.get_rpm(cls)
73
- _tpm = cls.get_tpm(cls)
74
-
75
- async def async_execute_model_call(
76
- self,
77
- user_prompt: str,
78
- system_prompt: str = "",
79
- files_list: Optional[List["FileStore"]] = None,
80
- ) -> dict[str, Any]:
81
- """Calls the AWS Bedrock API and returns the API response."""
82
-
83
- api_token = (
84
- self.api_token
85
- ) # call to check the if env variables are set.
86
-
87
- region = os.getenv("AWS_REGION", "us-east-1")
88
- client = boto3.client("bedrock-runtime", region_name=region)
89
-
90
- conversation = [
91
- {
92
- "role": "user",
93
- "content": [{"text": user_prompt}],
94
- }
95
- ]
96
- system = [
97
- {
98
- "text": system_prompt,
99
- }
100
- ]
101
- try:
102
- response = client.converse(
103
- modelId=self._model_,
104
- messages=conversation,
105
- inferenceConfig={
106
- "maxTokens": self.max_tokens,
107
- "temperature": self.temperature,
108
- "topP": self.top_p,
109
- },
110
- # system=system,
111
- additionalModelRequestFields={},
112
- )
113
- return response
114
- except (ClientError, Exception) as e:
115
- print(e)
116
- return {"error": str(e)}
117
-
118
- LLM.__name__ = model_class_name
119
-
120
- return LLM
1
+ import os
2
+ from typing import Any, List, Optional
3
+ import re
4
+ import boto3
5
+ from botocore.exceptions import ClientError
6
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
+ from edsl.language_models.LanguageModel import LanguageModel
8
+ import json
9
+ from edsl.utilities.utilities import fix_partial_correct_response
10
+
11
+
12
+ class AwsBedrockService(InferenceServiceABC):
13
+ """AWS Bedrock service class."""
14
+
15
+ _inference_service_ = "bedrock"
16
+ _env_key_name_ = (
17
+ "AWS_ACCESS_KEY_ID" # or any other environment key for AWS credentials
18
+ )
19
+ key_sequence = ["output", "message", "content", 0, "text"]
20
+ input_token_name = "inputTokens"
21
+ output_token_name = "outputTokens"
22
+ usage_sequence = ["usage"]
23
+ model_exclude_list = [
24
+ "ai21.j2-grande-instruct",
25
+ "ai21.j2-jumbo-instruct",
26
+ "ai21.j2-mid",
27
+ "ai21.j2-mid-v1",
28
+ "ai21.j2-ultra",
29
+ "ai21.j2-ultra-v1",
30
+ ]
31
+ _models_list_cache: List[str] = []
32
+
33
+ @classmethod
34
+ def available(cls):
35
+ """Fetch available models from AWS Bedrock."""
36
+
37
+ region = os.getenv("AWS_REGION", "us-east-1")
38
+
39
+ if not cls._models_list_cache:
40
+ client = boto3.client("bedrock", region_name=region)
41
+ all_models_ids = [
42
+ x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
43
+ ]
44
+ else:
45
+ all_models_ids = cls._models_list_cache
46
+
47
+ return [m for m in all_models_ids if m not in cls.model_exclude_list]
48
+
49
+ @classmethod
50
+ def create_model(
51
+ cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
52
+ ) -> LanguageModel:
53
+ if model_class_name is None:
54
+ model_class_name = cls.to_class_name(model_name)
55
+
56
+ class LLM(LanguageModel):
57
+ """
58
+ Child class of LanguageModel for interacting with AWS Bedrock models.
59
+ """
60
+
61
+ key_sequence = cls.key_sequence
62
+ usage_sequence = cls.usage_sequence
63
+ _inference_service_ = cls._inference_service_
64
+ _model_ = model_name
65
+ _parameters_ = {
66
+ "temperature": 0.5,
67
+ "max_tokens": 512,
68
+ "top_p": 0.9,
69
+ }
70
+ input_token_name = cls.input_token_name
71
+ output_token_name = cls.output_token_name
72
+
73
+ async def async_execute_model_call(
74
+ self,
75
+ user_prompt: str,
76
+ system_prompt: str = "",
77
+ files_list: Optional[List["FileStore"]] = None,
78
+ ) -> dict[str, Any]:
79
+ """Calls the AWS Bedrock API and returns the API response."""
80
+
81
+ api_token = (
82
+ self.api_token
83
+ ) # call to check the if env variables are set.
84
+
85
+ region = os.getenv("AWS_REGION", "us-east-1")
86
+ client = boto3.client("bedrock-runtime", region_name=region)
87
+
88
+ conversation = [
89
+ {
90
+ "role": "user",
91
+ "content": [{"text": user_prompt}],
92
+ }
93
+ ]
94
+ system = [
95
+ {
96
+ "text": system_prompt,
97
+ }
98
+ ]
99
+ try:
100
+ response = client.converse(
101
+ modelId=self._model_,
102
+ messages=conversation,
103
+ inferenceConfig={
104
+ "maxTokens": self.max_tokens,
105
+ "temperature": self.temperature,
106
+ "topP": self.top_p,
107
+ },
108
+ # system=system,
109
+ additionalModelRequestFields={},
110
+ )
111
+ return response
112
+ except (ClientError, Exception) as e:
113
+ print(e)
114
+ return {"error": str(e)}
115
+
116
+ LLM.__name__ = model_class_name
117
+
118
+ return LLM