edsl 0.1.39.dev3__py3-none-any.whl → 0.1.39.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (341)
  1. edsl/Base.py +413 -332
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +57 -49
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +1071 -867
  7. edsl/agents/AgentList.py +551 -413
  8. edsl/agents/Invigilator.py +284 -233
  9. edsl/agents/InvigilatorBase.py +257 -270
  10. edsl/agents/PromptConstructor.py +272 -354
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/__init__.py +2 -3
  14. edsl/agents/descriptors.py +99 -99
  15. edsl/agents/prompt_helpers.py +129 -129
  16. edsl/agents/question_option_processor.py +172 -0
  17. edsl/auto/AutoStudy.py +130 -117
  18. edsl/auto/StageBase.py +243 -230
  19. edsl/auto/StageGenerateSurvey.py +178 -178
  20. edsl/auto/StageLabelQuestions.py +125 -125
  21. edsl/auto/StagePersona.py +61 -61
  22. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  23. edsl/auto/StagePersonaDimensionValues.py +74 -74
  24. edsl/auto/StagePersonaDimensions.py +69 -69
  25. edsl/auto/StageQuestions.py +74 -73
  26. edsl/auto/SurveyCreatorPipeline.py +21 -21
  27. edsl/auto/utilities.py +218 -224
  28. edsl/base/Base.py +279 -279
  29. edsl/config.py +177 -157
  30. edsl/conversation/Conversation.py +290 -290
  31. edsl/conversation/car_buying.py +59 -58
  32. edsl/conversation/chips.py +95 -95
  33. edsl/conversation/mug_negotiation.py +81 -81
  34. edsl/conversation/next_speaker_utilities.py +93 -93
  35. edsl/coop/CoopFunctionsMixin.py +15 -0
  36. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  37. edsl/coop/PriceFetcher.py +54 -54
  38. edsl/coop/__init__.py +2 -2
  39. edsl/coop/coop.py +1106 -1028
  40. edsl/coop/utils.py +131 -131
  41. edsl/data/Cache.py +573 -555
  42. edsl/data/CacheEntry.py +230 -233
  43. edsl/data/CacheHandler.py +168 -149
  44. edsl/data/RemoteCacheSync.py +186 -78
  45. edsl/data/SQLiteDict.py +292 -292
  46. edsl/data/__init__.py +5 -4
  47. edsl/data/orm.py +10 -10
  48. edsl/data_transfer_models.py +74 -73
  49. edsl/enums.py +202 -175
  50. edsl/exceptions/BaseException.py +21 -21
  51. edsl/exceptions/__init__.py +54 -54
  52. edsl/exceptions/agents.py +54 -42
  53. edsl/exceptions/cache.py +5 -5
  54. edsl/exceptions/configuration.py +16 -16
  55. edsl/exceptions/coop.py +10 -10
  56. edsl/exceptions/data.py +14 -14
  57. edsl/exceptions/general.py +34 -34
  58. edsl/exceptions/inference_services.py +5 -0
  59. edsl/exceptions/jobs.py +33 -33
  60. edsl/exceptions/language_models.py +63 -63
  61. edsl/exceptions/prompts.py +15 -15
  62. edsl/exceptions/questions.py +109 -91
  63. edsl/exceptions/results.py +29 -29
  64. edsl/exceptions/scenarios.py +29 -22
  65. edsl/exceptions/surveys.py +37 -37
  66. edsl/inference_services/AnthropicService.py +106 -87
  67. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  68. edsl/inference_services/AvailableModelFetcher.py +215 -0
  69. edsl/inference_services/AwsBedrock.py +118 -120
  70. edsl/inference_services/AzureAI.py +215 -217
  71. edsl/inference_services/DeepInfraService.py +18 -18
  72. edsl/inference_services/GoogleService.py +143 -148
  73. edsl/inference_services/GroqService.py +20 -20
  74. edsl/inference_services/InferenceServiceABC.py +80 -147
  75. edsl/inference_services/InferenceServicesCollection.py +138 -97
  76. edsl/inference_services/MistralAIService.py +120 -123
  77. edsl/inference_services/OllamaService.py +18 -18
  78. edsl/inference_services/OpenAIService.py +236 -224
  79. edsl/inference_services/PerplexityService.py +160 -163
  80. edsl/inference_services/ServiceAvailability.py +135 -0
  81. edsl/inference_services/TestService.py +90 -89
  82. edsl/inference_services/TogetherAIService.py +172 -170
  83. edsl/inference_services/data_structures.py +134 -0
  84. edsl/inference_services/models_available_cache.py +118 -118
  85. edsl/inference_services/rate_limits_cache.py +25 -25
  86. edsl/inference_services/registry.py +41 -41
  87. edsl/inference_services/write_available.py +10 -10
  88. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  89. edsl/jobs/Answers.py +43 -56
  90. edsl/jobs/FetchInvigilator.py +47 -0
  91. edsl/jobs/InterviewTaskManager.py +98 -0
  92. edsl/jobs/InterviewsConstructor.py +50 -0
  93. edsl/jobs/Jobs.py +823 -898
  94. edsl/jobs/JobsChecks.py +172 -147
  95. edsl/jobs/JobsComponentConstructor.py +189 -0
  96. edsl/jobs/JobsPrompts.py +270 -268
  97. edsl/jobs/JobsRemoteInferenceHandler.py +311 -239
  98. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  99. edsl/jobs/RequestTokenEstimator.py +30 -0
  100. edsl/jobs/__init__.py +1 -1
  101. edsl/jobs/async_interview_runner.py +138 -0
  102. edsl/jobs/buckets/BucketCollection.py +104 -63
  103. edsl/jobs/buckets/ModelBuckets.py +65 -65
  104. edsl/jobs/buckets/TokenBucket.py +283 -251
  105. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  106. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  107. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  108. edsl/jobs/data_structures.py +120 -0
  109. edsl/jobs/decorators.py +35 -0
  110. edsl/jobs/interviews/Interview.py +396 -661
  111. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  112. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  113. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  114. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  115. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  116. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  117. edsl/jobs/interviews/ReportErrors.py +66 -66
  118. edsl/jobs/interviews/interview_status_enum.py +9 -9
  119. edsl/jobs/jobs_status_enums.py +9 -0
  120. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  121. edsl/jobs/results_exceptions_handler.py +98 -0
  122. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -466
  123. edsl/jobs/runners/JobsRunnerStatus.py +297 -330
  124. edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
  125. edsl/jobs/tasks/TaskCreators.py +64 -64
  126. edsl/jobs/tasks/TaskHistory.py +470 -450
  127. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  128. edsl/jobs/tasks/task_status_enum.py +161 -163
  129. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  130. edsl/jobs/tokens/TokenUsage.py +34 -34
  131. edsl/language_models/ComputeCost.py +63 -0
  132. edsl/language_models/LanguageModel.py +626 -668
  133. edsl/language_models/ModelList.py +164 -155
  134. edsl/language_models/PriceManager.py +127 -0
  135. edsl/language_models/RawResponseHandler.py +106 -0
  136. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  137. edsl/language_models/ServiceDataSources.py +0 -0
  138. edsl/language_models/__init__.py +2 -3
  139. edsl/language_models/fake_openai_call.py +15 -15
  140. edsl/language_models/fake_openai_service.py +61 -61
  141. edsl/language_models/key_management/KeyLookup.py +63 -0
  142. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  143. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  144. edsl/language_models/key_management/__init__.py +0 -0
  145. edsl/language_models/key_management/models.py +131 -0
  146. edsl/language_models/model.py +256 -0
  147. edsl/language_models/repair.py +156 -156
  148. edsl/language_models/utilities.py +65 -64
  149. edsl/notebooks/Notebook.py +263 -258
  150. edsl/notebooks/NotebookToLaTeX.py +142 -0
  151. edsl/notebooks/__init__.py +1 -1
  152. edsl/prompts/Prompt.py +352 -362
  153. edsl/prompts/__init__.py +2 -2
  154. edsl/questions/ExceptionExplainer.py +77 -0
  155. edsl/questions/HTMLQuestion.py +103 -0
  156. edsl/questions/QuestionBase.py +518 -664
  157. edsl/questions/QuestionBasePromptsMixin.py +221 -217
  158. edsl/questions/QuestionBudget.py +227 -227
  159. edsl/questions/QuestionCheckBox.py +359 -359
  160. edsl/questions/QuestionExtract.py +180 -182
  161. edsl/questions/QuestionFreeText.py +113 -114
  162. edsl/questions/QuestionFunctional.py +166 -166
  163. edsl/questions/QuestionList.py +223 -231
  164. edsl/questions/QuestionMatrix.py +265 -0
  165. edsl/questions/QuestionMultipleChoice.py +330 -286
  166. edsl/questions/QuestionNumerical.py +151 -153
  167. edsl/questions/QuestionRank.py +314 -324
  168. edsl/questions/Quick.py +41 -41
  169. edsl/questions/SimpleAskMixin.py +74 -73
  170. edsl/questions/__init__.py +27 -26
  171. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
  172. edsl/questions/compose_questions.py +98 -98
  173. edsl/questions/data_structures.py +20 -0
  174. edsl/questions/decorators.py +21 -21
  175. edsl/questions/derived/QuestionLikertFive.py +76 -76
  176. edsl/questions/derived/QuestionLinearScale.py +90 -87
  177. edsl/questions/derived/QuestionTopK.py +93 -93
  178. edsl/questions/derived/QuestionYesNo.py +82 -82
  179. edsl/questions/descriptors.py +427 -413
  180. edsl/questions/loop_processor.py +149 -0
  181. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  182. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  183. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  184. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  185. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  186. edsl/questions/prompt_templates/question_list.jinja +17 -17
  187. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  188. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  189. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
  190. edsl/questions/question_registry.py +177 -177
  191. edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
  192. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
  193. edsl/questions/response_validator_factory.py +34 -0
  194. edsl/questions/settings.py +12 -12
  195. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  196. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  197. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  198. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  199. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  200. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  201. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  202. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  203. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  204. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  205. edsl/questions/templates/list/question_presentation.jinja +5 -5
  206. edsl/questions/templates/matrix/__init__.py +1 -0
  207. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  208. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  209. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  210. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  211. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  212. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  213. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  214. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  215. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  216. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  217. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  218. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  219. edsl/results/CSSParameterizer.py +108 -108
  220. edsl/results/Dataset.py +587 -424
  221. edsl/results/DatasetExportMixin.py +594 -731
  222. edsl/results/DatasetTree.py +295 -275
  223. edsl/results/MarkdownToDocx.py +122 -0
  224. edsl/results/MarkdownToPDF.py +111 -0
  225. edsl/results/Result.py +557 -465
  226. edsl/results/Results.py +1183 -1165
  227. edsl/results/ResultsExportMixin.py +45 -43
  228. edsl/results/ResultsGGMixin.py +121 -121
  229. edsl/results/TableDisplay.py +125 -198
  230. edsl/results/TextEditor.py +50 -0
  231. edsl/results/__init__.py +2 -2
  232. edsl/results/file_exports.py +252 -0
  233. edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
  234. edsl/results/{Selector.py → results_selector.py} +145 -135
  235. edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
  236. edsl/results/smart_objects.py +96 -0
  237. edsl/results/table_data_class.py +12 -0
  238. edsl/results/table_display.css +77 -77
  239. edsl/results/table_renderers.py +118 -0
  240. edsl/results/tree_explore.py +115 -115
  241. edsl/scenarios/ConstructDownloadLink.py +109 -0
  242. edsl/scenarios/DocumentChunker.py +102 -0
  243. edsl/scenarios/DocxScenario.py +16 -0
  244. edsl/scenarios/FileStore.py +511 -632
  245. edsl/scenarios/PdfExtractor.py +40 -0
  246. edsl/scenarios/Scenario.py +498 -601
  247. edsl/scenarios/ScenarioHtmlMixin.py +65 -64
  248. edsl/scenarios/ScenarioList.py +1458 -1287
  249. edsl/scenarios/ScenarioListExportMixin.py +45 -52
  250. edsl/scenarios/ScenarioListPdfMixin.py +239 -261
  251. edsl/scenarios/__init__.py +3 -4
  252. edsl/scenarios/directory_scanner.py +96 -0
  253. edsl/scenarios/file_methods.py +85 -0
  254. edsl/scenarios/handlers/__init__.py +13 -0
  255. edsl/scenarios/handlers/csv.py +38 -0
  256. edsl/scenarios/handlers/docx.py +76 -0
  257. edsl/scenarios/handlers/html.py +37 -0
  258. edsl/scenarios/handlers/json.py +111 -0
  259. edsl/scenarios/handlers/latex.py +5 -0
  260. edsl/scenarios/handlers/md.py +51 -0
  261. edsl/scenarios/handlers/pdf.py +68 -0
  262. edsl/scenarios/handlers/png.py +39 -0
  263. edsl/scenarios/handlers/pptx.py +105 -0
  264. edsl/scenarios/handlers/py.py +294 -0
  265. edsl/scenarios/handlers/sql.py +313 -0
  266. edsl/scenarios/handlers/sqlite.py +149 -0
  267. edsl/scenarios/handlers/txt.py +33 -0
  268. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +131 -127
  269. edsl/scenarios/scenario_selector.py +156 -0
  270. edsl/shared.py +1 -1
  271. edsl/study/ObjectEntry.py +173 -173
  272. edsl/study/ProofOfWork.py +113 -113
  273. edsl/study/SnapShot.py +80 -80
  274. edsl/study/Study.py +521 -528
  275. edsl/study/__init__.py +4 -4
  276. edsl/surveys/ConstructDAG.py +92 -0
  277. edsl/surveys/DAG.py +148 -148
  278. edsl/surveys/EditSurvey.py +221 -0
  279. edsl/surveys/InstructionHandler.py +100 -0
  280. edsl/surveys/Memory.py +31 -31
  281. edsl/surveys/MemoryManagement.py +72 -0
  282. edsl/surveys/MemoryPlan.py +244 -244
  283. edsl/surveys/Rule.py +327 -326
  284. edsl/surveys/RuleCollection.py +385 -387
  285. edsl/surveys/RuleManager.py +172 -0
  286. edsl/surveys/Simulator.py +75 -0
  287. edsl/surveys/Survey.py +1280 -1801
  288. edsl/surveys/SurveyCSS.py +273 -261
  289. edsl/surveys/SurveyExportMixin.py +259 -259
  290. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -179
  291. edsl/surveys/SurveyQualtricsImport.py +284 -284
  292. edsl/surveys/SurveyToApp.py +141 -0
  293. edsl/surveys/__init__.py +5 -3
  294. edsl/surveys/base.py +53 -53
  295. edsl/surveys/descriptors.py +60 -56
  296. edsl/surveys/instructions/ChangeInstruction.py +48 -49
  297. edsl/surveys/instructions/Instruction.py +56 -65
  298. edsl/surveys/instructions/InstructionCollection.py +82 -77
  299. edsl/templates/error_reporting/base.html +23 -23
  300. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  301. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  302. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  303. edsl/templates/error_reporting/interview_details.html +115 -115
  304. edsl/templates/error_reporting/interviews.html +19 -19
  305. edsl/templates/error_reporting/overview.html +4 -4
  306. edsl/templates/error_reporting/performance_plot.html +1 -1
  307. edsl/templates/error_reporting/report.css +73 -73
  308. edsl/templates/error_reporting/report.html +117 -117
  309. edsl/templates/error_reporting/report.js +25 -25
  310. edsl/tools/__init__.py +1 -1
  311. edsl/tools/clusters.py +192 -192
  312. edsl/tools/embeddings.py +27 -27
  313. edsl/tools/embeddings_plotting.py +118 -118
  314. edsl/tools/plotting.py +112 -112
  315. edsl/tools/summarize.py +18 -18
  316. edsl/utilities/PrettyList.py +56 -0
  317. edsl/utilities/SystemInfo.py +28 -28
  318. edsl/utilities/__init__.py +22 -22
  319. edsl/utilities/ast_utilities.py +25 -25
  320. edsl/utilities/data/Registry.py +6 -6
  321. edsl/utilities/data/__init__.py +1 -1
  322. edsl/utilities/data/scooter_results.json +1 -1
  323. edsl/utilities/decorators.py +77 -77
  324. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  325. edsl/utilities/interface.py +627 -627
  326. edsl/utilities/is_notebook.py +18 -0
  327. edsl/utilities/is_valid_variable_name.py +11 -0
  328. edsl/utilities/naming_utilities.py +263 -263
  329. edsl/utilities/remove_edsl_version.py +24 -0
  330. edsl/utilities/repair_functions.py +28 -28
  331. edsl/utilities/restricted_python.py +70 -70
  332. edsl/utilities/utilities.py +436 -424
  333. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev5.dist-info}/LICENSE +21 -21
  334. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev5.dist-info}/METADATA +13 -11
  335. edsl-0.1.39.dev5.dist-info/RECORD +358 -0
  336. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev5.dist-info}/WHEEL +1 -1
  337. edsl/language_models/KeyLookup.py +0 -30
  338. edsl/language_models/registry.py +0 -190
  339. edsl/language_models/unused/ReplicateBase.py +0 -83
  340. edsl/results/ResultsDBMixin.py +0 -238
  341. edsl-0.1.39.dev3.dist-info/RECORD +0 -277
@@ -1,97 +1,138 @@
1
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
2
- import warnings
3
-
4
-
5
class InferenceServicesCollection:
    """Registry of inference services plus any manually added models.

    ``added_models`` is a class-level dict shared by every instance. It maps a
    service name to the list of extra model names registered via ``add_model``.
    """

    added_models = {}

    def __init__(self, services: "list[InferenceServiceABC]" = None):
        """Store the services to query; ``None`` or empty gives an empty list."""
        self.services = services or []

    @classmethod
    def add_model(cls, service_name, model_name):
        """Record ``model_name`` as an extra model offered by ``service_name``."""
        cls.added_models.setdefault(service_name, []).append(model_name)

    @staticmethod
    def _get_service_available(service, warn: bool = False) -> list[str]:
        """Return the model names offered by ``service``, with fallbacks.

        Lookup order:
        1. ``service.available()``;
        2. the model list fetched from Coop;
        3. the bundled ``models_available`` cache.

        Fallback results are also stored on ``service._models_list_cache``.
        When ``warn`` is true, a ``UserWarning`` is emitted at each fallback.
        """
        try:
            service_models = service.available()
        except Exception:
            if warn:
                warnings.warn(
                    f"""Error getting models for {service._inference_service_}.
Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
Relying on Coop.""",
                    UserWarning,
                )

            # Use the list of models on Coop as a fallback
            try:
                from edsl import Coop

                coop = Coop()
                models_from_coop = coop.fetch_models()
                service_models = models_from_coop.get(service._inference_service_, [])

                # cache results
                service._models_list_cache = service_models

            # Finally, use the available models cache from the Python file
            except Exception:
                if warn:
                    warnings.warn(
                        f"""Error getting models for {service._inference_service_}.
Relying on EDSL cache.""",
                        UserWarning,
                    )

                from edsl.inference_services.models_available_cache import (
                    models_available,
                )

                service_models = models_available.get(service._inference_service_, [])

                # cache results
                service._models_list_cache = service_models

        return service_models

    def available(self):
        """Return ``[model_name, service_name, index]`` rows for all known models.

        Rows combine each service's available models with any ``added_models``
        for that service. They are sorted, and ``index`` is each row's position
        in that sorted order. Rows are lists, not tuples.
        """
        rows = []
        for service in self.services:
            name = service._inference_service_
            for model in self._get_service_available(service):
                rows.append([model, name, -1])
            for model in self.added_models.get(name, []):
                rows.append([model, name, -1])

        rows.sort()
        for position, row in enumerate(rows):
            row[2] = position
        return rows

    def register(self, service):
        """Append ``service`` to the services this collection queries."""
        self.services.append(service)

    def create_model_factory(self, model_name: str, service_name=None, index=None):
        """Return the model class for ``model_name``.

        :param model_name: model to build; ``"test"`` always maps to TestService.
        :param service_name: if given, the service with this name is used
            directly, without checking its model list.
        :param index: unused; kept for backward compatibility.
        :raises Exception: if no service provides the model.
        """
        if model_name == "test":
            # Import lazily: only the "test" model needs TestService.
            from edsl.inference_services.TestService import TestService

            return TestService.create_model(model_name)

        if service_name:
            for service in self.services:
                if service_name == service._inference_service_:
                    return service.create_model(model_name)

        for service in self.services:
            if model_name in self._get_service_available(service):
                if service_name is None or service_name == service._inference_service_:
                    return service.create_model(model_name)

        raise Exception(f"Model {model_name} not found in any of the services")
1
+ from functools import lru_cache
2
+ from collections import defaultdict
3
+ from typing import Optional, Protocol, Dict, List, Tuple, TYPE_CHECKING, Literal
4
+
5
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
6
+ from edsl.inference_services.AvailableModelFetcher import AvailableModelFetcher
7
+ from edsl.exceptions.inference_services import InferenceServiceError
8
+
9
+ if TYPE_CHECKING:
10
+ from edsl.language_models.LanguageModel import LanguageModel
11
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
12
+
13
+
14
class ModelCreator(Protocol):
    """Structural type: any object that exposes ``create_model(model_name)``."""

    def create_model(self, model_name: str) -> "LanguageModel":
        """Return the language model built for ``model_name``."""
        ...
17
+
18
+
19
+ from edsl.enums import InferenceServiceLiteral
20
+
21
+
22
class ModelResolver:
    """Determine which inference service should serve a given model name."""

    def __init__(
        self,
        services: List["InferenceServiceABC"],
        models_to_services: Dict[str, "InferenceServiceABC"],
        availability_fetcher: "AvailableModelFetcher",
    ):
        """
        Class for determining which service to use for a given model.

        :param services: service objects to search, in order. The list is kept
            by reference, so services appended to it later are still searched.
        :param models_to_services: cache of model name -> service. It is kept by
            reference and updated in place when a model is resolved.
        :param availability_fetcher: object whose
            ``get_available_models_by_service(service)`` returns a
            ``(available_models, service_name)`` pair.
        """
        self.services = services
        self._models_to_services = models_to_services
        self.availability_fetcher = availability_fetcher
        # NOTE(review): this lookup is built once. Services registered later are
        # found by the search loop in resolve_model, but not via explicit service_name.
        self._service_names_to_classes = {
            service._inference_service_: service for service in services
        }

    def resolve_model(
        self, model_name: str, service_name: Optional["InferenceServiceLiteral"] = None
    ) -> "InferenceServiceABC":
        """Returns an InferenceServiceABC object for the given model name.

        :param model_name: The name of the model to resolve. E.g., 'gpt-4o'
        :param service_name: The name of the service to use. E.g., 'openai'
        :return: An InferenceServiceABC object
        :raises InferenceServiceError: if the named service is unknown, or no
            service lists ``model_name``.
        """
        if model_name == "test":
            from edsl.inference_services.TestService import TestService

            return TestService()

        if service_name is not None:
            service = self._service_names_to_classes.get(service_name)
            if not service:
                raise InferenceServiceError(f"Service {service_name} not found")
            return service

        if model_name in self._models_to_services:  # resolved earlier; reuse it
            return self._models_to_services[model_name]

        for service in self.services:
            # The fetcher also returns the service name; it is not needed here.
            # Do not rebind the `service_name` parameter with it.
            available_models, _ = (
                self.availability_fetcher.get_available_models_by_service(service)
            )
            if model_name in available_models:
                self._models_to_services[model_name] = service
                return service

        raise InferenceServiceError(
            f"""Model {model_name} not found in any services.
If you know the service that has this model, use the service_name parameter directly.
E.g., Model("gpt-4o", service_name="openai")
"""
        )
80
+
81
+
82
class InferenceServicesCollection:
    """Collection of inference services that can list models and build model classes.

    ``added_models`` is deliberately class-level. It maps a service name to the
    list of extra model names, and is shared by every collection instance.
    """

    added_models = defaultdict(list)  # Moved back to class level

    def __init__(self, services: Optional[List["InferenceServiceABC"]] = None):
        """Wire up the availability fetcher and model resolver for ``services``."""
        self.services = services or []
        # Model name -> service. The resolver shares this dict and fills it in.
        self._models_to_services: Dict[str, "InferenceServiceABC"] = {}

        self.availability_fetcher = AvailableModelFetcher(
            self.services, self.added_models
        )
        self.resolver = ModelResolver(
            self.services, self._models_to_services, self.availability_fetcher
        )

    @classmethod
    def add_model(cls, service_name: str, model_name: str) -> None:
        """Register ``model_name`` as an extra model of ``service_name``.

        Every distinct model is recorded; re-adding an existing name is a no-op.
        """
        models = cls.added_models[service_name]
        if model_name not in models:
            models.append(model_name)

    def service_names_to_classes(self) -> Dict[str, "InferenceServiceABC"]:
        """Map each service's ``_inference_service_`` name to the service."""
        return {service._inference_service_: service for service in self.services}

    def available(
        self,
        service: Optional[str] = None,
    ) -> List[Tuple[str, str, int]]:
        """Delegate to the availability fetcher, optionally filtered by ``service``."""
        return self.availability_fetcher.available(service)

    def reset_cache(self) -> None:
        """Clear the availability fetcher's cache."""
        self.availability_fetcher.reset_cache()

    @property
    def num_cache_entries(self) -> int:
        """Number of entries in the availability fetcher's cache."""
        return self.availability_fetcher.num_cache_entries

    def register(self, service: "InferenceServiceABC") -> None:
        """Append ``service``; the fetcher and resolver share this list."""
        self.services.append(service)

    def create_model_factory(
        self, model_name: str, service_name: Optional["InferenceServiceLiteral"] = None
    ) -> "LanguageModel":
        """Return the model created by the service that serves ``model_name``.

        :raises InferenceServiceError: if ``service_name`` names no registered
            service, or the resolver cannot find the model.
        """
        if service_name is None:  # we try to find the right service
            service = self.resolver.resolve_model(model_name, service_name)
        else:  # if they passed a service, we'll use that
            service = self.service_names_to_classes().get(service_name)

        if not service:  # but if we can't find it, we'll raise an error
            raise InferenceServiceError(f"Service {service_name} not found")

        return service.create_model(model_name)
133
+
134
+
135
if __name__ == "__main__":
    # Running this module directly executes any doctests it contains.
    from doctest import testmod

    testmod()
@@ -1,123 +1,120 @@
1
- import os
2
- from typing import Any, List, Optional
3
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
4
- from edsl.language_models.LanguageModel import LanguageModel
5
- import asyncio
6
- from mistralai import Mistral
7
-
8
- from edsl.exceptions.language_models import LanguageModelBadResponseError
9
-
10
-
11
class MistralAIService(InferenceServiceABC):
    """Mistral AI service class."""

    key_sequence = ["choices", 0, "message", "content"]
    usage_sequence = ["usage"]

    _inference_service_ = "mistral"
    _env_key_name_ = "MISTRAL_API_KEY"  # Environment variable for Mistral API key
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"

    _sync_client_instance = None
    _async_client_instance = None

    _sync_client = Mistral
    _async_client = Mistral

    _models_list_cache: List[str] = []
    model_exclude_list = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # so subclasses have to create their own instances of the clients
        cls._sync_client_instance = None
        cls._async_client_instance = None

    @classmethod
    def sync_client(cls):
        """Return the class-cached sync client, creating it from the env key on first use."""
        if cls._sync_client_instance is None:
            cls._sync_client_instance = cls._sync_client(
                api_key=os.getenv(cls._env_key_name_)
            )
        return cls._sync_client_instance

    @classmethod
    def async_client(cls):
        """Return the class-cached async client, creating it from the env key on first use."""
        if cls._async_client_instance is None:
            cls._async_client_instance = cls._async_client(
                api_key=os.getenv(cls._env_key_name_)
            )
        return cls._async_client_instance

    @classmethod
    def available(cls) -> list[str]:
        """Return model ids from the Mistral API; cached on the class after the first non-empty fetch."""
        if not cls._models_list_cache:
            cls._models_list_cache = [
                m.id for m in cls.sync_client().models.list().data
            ]

        return cls._models_list_cache

    @classmethod
    def create_model(
        cls, model_name: str = "mistral", model_class_name=None
    ) -> LanguageModel:
        """Build a LanguageModel subclass bound to ``model_name``.

        The class is named ``model_class_name``, or ``cls.to_class_name(model_name)``
        when that is not given.
        """
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            """
            Child class of LanguageModel for interacting with Mistral models.
            """

            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence

            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name

            _inference_service_ = cls._inference_service_
            _model_ = model_name
            # NOTE(review): these defaults are not forwarded to the API call below.
            _parameters_ = {
                "temperature": 0.5,
                "max_tokens": 512,
                "top_p": 0.9,
            }

            _tpm = cls.get_tpm(cls)
            _rpm = cls.get_rpm(cls)

            def sync_client(self):
                return cls.sync_client()

            def async_client(self):
                return cls.async_client()

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional[List["FileStore"]] = None,
            ) -> dict[str, Any]:
                """Calls the Mistral API and returns the API response.

                A non-empty ``system_prompt`` is sent as a leading ``system``
                message. ``files_list`` is accepted for interface compatibility
                but is not sent.

                :raises LanguageModelBadResponseError: if the call fails or
                    returns no response.
                """
                s = self.async_client()

                messages = []
                if system_prompt:
                    messages.append({"content": system_prompt, "role": "system"})
                messages.append({"content": user_prompt, "role": "user"})

                try:
                    res = await s.chat.complete_async(
                        model=model_name,
                        messages=messages,
                    )
                except Exception as e:
                    raise LanguageModelBadResponseError(
                        f"Error with Mistral API: {e}"
                    ) from e

                if res is None:
                    # Guard: a None response would otherwise crash in model_dump().
                    raise LanguageModelBadResponseError(
                        "Error with Mistral API: empty response"
                    )

                return res.model_dump()

        LLM.__name__ = model_class_name

        return LLM
1
+ import os
2
+ from typing import Any, List, Optional
3
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
4
+ from edsl.language_models.LanguageModel import LanguageModel
5
+ import asyncio
6
+ from mistralai import Mistral
7
+
8
+ from edsl.exceptions.language_models import LanguageModelBadResponseError
9
+
10
+
11
class MistralAIService(InferenceServiceABC):
    """Mistral AI service class."""

    key_sequence = ["choices", 0, "message", "content"]
    usage_sequence = ["usage"]

    _inference_service_ = "mistral"
    _env_key_name_ = "MISTRAL_API_KEY"  # Environment variable for Mistral API key
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"

    _sync_client_instance = None
    _async_client_instance = None

    _sync_client = Mistral
    _async_client = Mistral

    _models_list_cache: List[str] = []
    model_exclude_list = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # so subclasses have to create their own instances of the clients
        cls._sync_client_instance = None
        cls._async_client_instance = None

    @classmethod
    def sync_client(cls):
        """Return the class-cached sync client, creating it from the env key on first use."""
        if cls._sync_client_instance is None:
            cls._sync_client_instance = cls._sync_client(
                api_key=os.getenv(cls._env_key_name_)
            )
        return cls._sync_client_instance

    @classmethod
    def async_client(cls):
        """Return the class-cached async client, creating it from the env key on first use."""
        if cls._async_client_instance is None:
            cls._async_client_instance = cls._async_client(
                api_key=os.getenv(cls._env_key_name_)
            )
        return cls._async_client_instance

    @classmethod
    def available(cls) -> list[str]:
        """Return model ids from the Mistral API; cached on the class after the first non-empty fetch."""
        if not cls._models_list_cache:
            cls._models_list_cache = [
                m.id for m in cls.sync_client().models.list().data
            ]

        return cls._models_list_cache

    @classmethod
    def create_model(
        cls, model_name: str = "mistral", model_class_name=None
    ) -> LanguageModel:
        """Build a LanguageModel subclass bound to ``model_name``.

        The class is named ``model_class_name``, or ``cls.to_class_name(model_name)``
        when that is not given.
        """
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            """
            Child class of LanguageModel for interacting with Mistral models.
            """

            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence

            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name

            _inference_service_ = cls._inference_service_
            _model_ = model_name
            # NOTE(review): these defaults are not forwarded to the API call below.
            _parameters_ = {
                "temperature": 0.5,
                "max_tokens": 512,
                "top_p": 0.9,
            }

            def sync_client(self):
                return cls.sync_client()

            def async_client(self):
                return cls.async_client()

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional[List["FileStore"]] = None,
            ) -> dict[str, Any]:
                """Calls the Mistral API and returns the API response.

                A non-empty ``system_prompt`` is sent as a leading ``system``
                message. ``files_list`` is accepted for interface compatibility
                but is not sent.

                :raises LanguageModelBadResponseError: if the call fails or
                    returns no response.
                """
                s = self.async_client()

                messages = []
                if system_prompt:
                    messages.append({"content": system_prompt, "role": "system"})
                messages.append({"content": user_prompt, "role": "user"})

                try:
                    res = await s.chat.complete_async(
                        model=model_name,
                        messages=messages,
                    )
                except Exception as e:
                    raise LanguageModelBadResponseError(
                        f"Error with Mistral API: {e}"
                    ) from e

                if res is None:
                    # Guard: a None response would otherwise crash in model_dump().
                    raise LanguageModelBadResponseError(
                        "Error with Mistral API: empty response"
                    )

                return res.model_dump()

        LLM.__name__ = model_class_name

        return LLM
@@ -1,18 +1,18 @@
1
- import aiohttp
2
- import json
3
- import requests
4
- from typing import Any, List
5
-
6
- # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
- from edsl.language_models import LanguageModel
8
-
9
- from edsl.inference_services.OpenAIService import OpenAIService
10
-
11
-
12
class OllamaService(OpenAIService):
    """Ollama service class.

    Subclasses OpenAIService and overrides only the service name, API-key
    variable name, base URL (a local endpoint on port 11434) and model cache.
    """

    _inference_service_ = "ollama"
    # NOTE(review): this is DeepInfra's key name, apparently copied from
    # DeepInfraService. Confirm whether Ollama should use its own variable.
    _env_key_name_ = "DEEP_INFRA_API_KEY"
    _base_url_ = "http://localhost:11434/v1"
    _models_list_cache: List[str] = []
1
+ import aiohttp
2
+ import json
3
+ import requests
4
+ from typing import Any, List
5
+
6
+ # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
+ from edsl.language_models import LanguageModel
8
+
9
+ from edsl.inference_services.OpenAIService import OpenAIService
10
+
11
+
12
class OllamaService(OpenAIService):
    """Ollama service class.

    Subclasses OpenAIService and overrides only the service name, API-key
    variable name, base URL (a local endpoint on port 11434) and model cache.
    """

    _inference_service_ = "ollama"
    # NOTE(review): this is DeepInfra's key name, apparently copied from
    # DeepInfraService. Confirm whether Ollama should use its own variable.
    _env_key_name_ = "DEEP_INFRA_API_KEY"
    _base_url_ = "http://localhost:11434/v1"
    _models_list_cache: List[str] = []