edsl 0.1.39.dev3__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (344)
  1. edsl/Base.py +413 -332
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +57 -49
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +1071 -867
  7. edsl/agents/AgentList.py +551 -413
  8. edsl/agents/Invigilator.py +284 -233
  9. edsl/agents/InvigilatorBase.py +257 -270
  10. edsl/agents/PromptConstructor.py +272 -354
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/__init__.py +2 -3
  14. edsl/agents/descriptors.py +99 -99
  15. edsl/agents/prompt_helpers.py +129 -129
  16. edsl/agents/question_option_processor.py +172 -0
  17. edsl/auto/AutoStudy.py +130 -117
  18. edsl/auto/StageBase.py +243 -230
  19. edsl/auto/StageGenerateSurvey.py +178 -178
  20. edsl/auto/StageLabelQuestions.py +125 -125
  21. edsl/auto/StagePersona.py +61 -61
  22. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  23. edsl/auto/StagePersonaDimensionValues.py +74 -74
  24. edsl/auto/StagePersonaDimensions.py +69 -69
  25. edsl/auto/StageQuestions.py +74 -73
  26. edsl/auto/SurveyCreatorPipeline.py +21 -21
  27. edsl/auto/utilities.py +218 -224
  28. edsl/base/Base.py +279 -279
  29. edsl/config.py +177 -157
  30. edsl/conversation/Conversation.py +290 -290
  31. edsl/conversation/car_buying.py +59 -58
  32. edsl/conversation/chips.py +95 -95
  33. edsl/conversation/mug_negotiation.py +81 -81
  34. edsl/conversation/next_speaker_utilities.py +93 -93
  35. edsl/coop/CoopFunctionsMixin.py +15 -0
  36. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  37. edsl/coop/PriceFetcher.py +54 -54
  38. edsl/coop/__init__.py +2 -2
  39. edsl/coop/coop.py +1106 -1028
  40. edsl/coop/utils.py +131 -131
  41. edsl/data/Cache.py +573 -555
  42. edsl/data/CacheEntry.py +230 -233
  43. edsl/data/CacheHandler.py +168 -149
  44. edsl/data/RemoteCacheSync.py +186 -78
  45. edsl/data/SQLiteDict.py +292 -292
  46. edsl/data/__init__.py +5 -4
  47. edsl/data/hack.py +10 -0
  48. edsl/data/orm.py +10 -10
  49. edsl/data_transfer_models.py +74 -73
  50. edsl/enums.py +202 -175
  51. edsl/exceptions/BaseException.py +21 -21
  52. edsl/exceptions/__init__.py +54 -54
  53. edsl/exceptions/agents.py +54 -42
  54. edsl/exceptions/cache.py +5 -5
  55. edsl/exceptions/configuration.py +16 -16
  56. edsl/exceptions/coop.py +10 -10
  57. edsl/exceptions/data.py +14 -14
  58. edsl/exceptions/general.py +34 -34
  59. edsl/exceptions/inference_services.py +5 -0
  60. edsl/exceptions/jobs.py +33 -33
  61. edsl/exceptions/language_models.py +63 -63
  62. edsl/exceptions/prompts.py +15 -15
  63. edsl/exceptions/questions.py +109 -91
  64. edsl/exceptions/results.py +29 -29
  65. edsl/exceptions/scenarios.py +29 -22
  66. edsl/exceptions/surveys.py +37 -37
  67. edsl/inference_services/AnthropicService.py +106 -87
  68. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  69. edsl/inference_services/AvailableModelFetcher.py +215 -0
  70. edsl/inference_services/AwsBedrock.py +118 -120
  71. edsl/inference_services/AzureAI.py +215 -217
  72. edsl/inference_services/DeepInfraService.py +18 -18
  73. edsl/inference_services/GoogleService.py +143 -148
  74. edsl/inference_services/GroqService.py +20 -20
  75. edsl/inference_services/InferenceServiceABC.py +80 -147
  76. edsl/inference_services/InferenceServicesCollection.py +138 -97
  77. edsl/inference_services/MistralAIService.py +120 -123
  78. edsl/inference_services/OllamaService.py +18 -18
  79. edsl/inference_services/OpenAIService.py +236 -224
  80. edsl/inference_services/PerplexityService.py +160 -163
  81. edsl/inference_services/ServiceAvailability.py +135 -0
  82. edsl/inference_services/TestService.py +90 -89
  83. edsl/inference_services/TogetherAIService.py +172 -170
  84. edsl/inference_services/data_structures.py +134 -0
  85. edsl/inference_services/models_available_cache.py +118 -118
  86. edsl/inference_services/rate_limits_cache.py +25 -25
  87. edsl/inference_services/registry.py +41 -41
  88. edsl/inference_services/write_available.py +10 -10
  89. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  90. edsl/jobs/Answers.py +43 -56
  91. edsl/jobs/FetchInvigilator.py +47 -0
  92. edsl/jobs/InterviewTaskManager.py +98 -0
  93. edsl/jobs/InterviewsConstructor.py +50 -0
  94. edsl/jobs/Jobs.py +823 -898
  95. edsl/jobs/JobsChecks.py +172 -147
  96. edsl/jobs/JobsComponentConstructor.py +189 -0
  97. edsl/jobs/JobsPrompts.py +270 -268
  98. edsl/jobs/JobsRemoteInferenceHandler.py +311 -239
  99. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  100. edsl/jobs/RequestTokenEstimator.py +30 -0
  101. edsl/jobs/__init__.py +1 -1
  102. edsl/jobs/async_interview_runner.py +138 -0
  103. edsl/jobs/buckets/BucketCollection.py +104 -63
  104. edsl/jobs/buckets/ModelBuckets.py +65 -65
  105. edsl/jobs/buckets/TokenBucket.py +283 -251
  106. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  107. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  108. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  109. edsl/jobs/data_structures.py +120 -0
  110. edsl/jobs/decorators.py +35 -0
  111. edsl/jobs/interviews/Interview.py +396 -661
  112. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  113. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  114. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  115. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  116. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  117. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  118. edsl/jobs/interviews/ReportErrors.py +66 -66
  119. edsl/jobs/interviews/interview_status_enum.py +9 -9
  120. edsl/jobs/jobs_status_enums.py +9 -0
  121. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  122. edsl/jobs/results_exceptions_handler.py +98 -0
  123. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -466
  124. edsl/jobs/runners/JobsRunnerStatus.py +297 -330
  125. edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
  126. edsl/jobs/tasks/TaskCreators.py +64 -64
  127. edsl/jobs/tasks/TaskHistory.py +470 -450
  128. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  129. edsl/jobs/tasks/task_status_enum.py +161 -163
  130. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  131. edsl/jobs/tokens/TokenUsage.py +34 -34
  132. edsl/language_models/ComputeCost.py +63 -0
  133. edsl/language_models/LanguageModel.py +626 -668
  134. edsl/language_models/ModelList.py +164 -155
  135. edsl/language_models/PriceManager.py +127 -0
  136. edsl/language_models/RawResponseHandler.py +106 -0
  137. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  138. edsl/language_models/ServiceDataSources.py +0 -0
  139. edsl/language_models/__init__.py +2 -3
  140. edsl/language_models/fake_openai_call.py +15 -15
  141. edsl/language_models/fake_openai_service.py +61 -61
  142. edsl/language_models/key_management/KeyLookup.py +63 -0
  143. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  144. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  145. edsl/language_models/key_management/__init__.py +0 -0
  146. edsl/language_models/key_management/models.py +131 -0
  147. edsl/language_models/model.py +256 -0
  148. edsl/language_models/repair.py +156 -156
  149. edsl/language_models/utilities.py +65 -64
  150. edsl/notebooks/Notebook.py +263 -258
  151. edsl/notebooks/NotebookToLaTeX.py +142 -0
  152. edsl/notebooks/__init__.py +1 -1
  153. edsl/prompts/Prompt.py +352 -362
  154. edsl/prompts/__init__.py +2 -2
  155. edsl/questions/ExceptionExplainer.py +77 -0
  156. edsl/questions/HTMLQuestion.py +103 -0
  157. edsl/questions/QuestionBase.py +518 -664
  158. edsl/questions/QuestionBasePromptsMixin.py +221 -217
  159. edsl/questions/QuestionBudget.py +227 -227
  160. edsl/questions/QuestionCheckBox.py +359 -359
  161. edsl/questions/QuestionExtract.py +180 -182
  162. edsl/questions/QuestionFreeText.py +113 -114
  163. edsl/questions/QuestionFunctional.py +166 -166
  164. edsl/questions/QuestionList.py +223 -231
  165. edsl/questions/QuestionMatrix.py +265 -0
  166. edsl/questions/QuestionMultipleChoice.py +330 -286
  167. edsl/questions/QuestionNumerical.py +151 -153
  168. edsl/questions/QuestionRank.py +314 -324
  169. edsl/questions/Quick.py +41 -41
  170. edsl/questions/SimpleAskMixin.py +74 -73
  171. edsl/questions/__init__.py +27 -26
  172. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
  173. edsl/questions/compose_questions.py +98 -98
  174. edsl/questions/data_structures.py +20 -0
  175. edsl/questions/decorators.py +21 -21
  176. edsl/questions/derived/QuestionLikertFive.py +76 -76
  177. edsl/questions/derived/QuestionLinearScale.py +90 -87
  178. edsl/questions/derived/QuestionTopK.py +93 -93
  179. edsl/questions/derived/QuestionYesNo.py +82 -82
  180. edsl/questions/descriptors.py +427 -413
  181. edsl/questions/loop_processor.py +149 -0
  182. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  183. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  184. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  185. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  186. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  187. edsl/questions/prompt_templates/question_list.jinja +17 -17
  188. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  189. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  190. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
  191. edsl/questions/question_registry.py +177 -177
  192. edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
  193. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
  194. edsl/questions/response_validator_factory.py +34 -0
  195. edsl/questions/settings.py +12 -12
  196. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  197. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  198. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  199. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  200. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  201. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  202. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  203. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  204. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  205. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  206. edsl/questions/templates/list/question_presentation.jinja +5 -5
  207. edsl/questions/templates/matrix/__init__.py +1 -0
  208. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  209. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  210. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  211. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  212. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  213. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  214. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  215. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  216. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  217. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  218. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  219. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  220. edsl/results/CSSParameterizer.py +108 -108
  221. edsl/results/Dataset.py +587 -424
  222. edsl/results/DatasetExportMixin.py +594 -731
  223. edsl/results/DatasetTree.py +295 -275
  224. edsl/results/MarkdownToDocx.py +122 -0
  225. edsl/results/MarkdownToPDF.py +111 -0
  226. edsl/results/Result.py +557 -465
  227. edsl/results/Results.py +1183 -1165
  228. edsl/results/ResultsExportMixin.py +45 -43
  229. edsl/results/ResultsGGMixin.py +121 -121
  230. edsl/results/TableDisplay.py +125 -198
  231. edsl/results/TextEditor.py +50 -0
  232. edsl/results/__init__.py +2 -2
  233. edsl/results/file_exports.py +252 -0
  234. edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
  235. edsl/results/{Selector.py → results_selector.py} +145 -135
  236. edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
  237. edsl/results/smart_objects.py +96 -0
  238. edsl/results/table_data_class.py +12 -0
  239. edsl/results/table_display.css +77 -77
  240. edsl/results/table_renderers.py +118 -0
  241. edsl/results/tree_explore.py +115 -115
  242. edsl/scenarios/ConstructDownloadLink.py +109 -0
  243. edsl/scenarios/DocumentChunker.py +102 -0
  244. edsl/scenarios/DocxScenario.py +16 -0
  245. edsl/scenarios/FileStore.py +511 -632
  246. edsl/scenarios/PdfExtractor.py +40 -0
  247. edsl/scenarios/Scenario.py +498 -601
  248. edsl/scenarios/ScenarioHtmlMixin.py +65 -64
  249. edsl/scenarios/ScenarioList.py +1458 -1287
  250. edsl/scenarios/ScenarioListExportMixin.py +45 -52
  251. edsl/scenarios/ScenarioListPdfMixin.py +239 -261
  252. edsl/scenarios/__init__.py +3 -4
  253. edsl/scenarios/directory_scanner.py +96 -0
  254. edsl/scenarios/file_methods.py +85 -0
  255. edsl/scenarios/handlers/__init__.py +13 -0
  256. edsl/scenarios/handlers/csv.py +38 -0
  257. edsl/scenarios/handlers/docx.py +76 -0
  258. edsl/scenarios/handlers/html.py +37 -0
  259. edsl/scenarios/handlers/json.py +111 -0
  260. edsl/scenarios/handlers/latex.py +5 -0
  261. edsl/scenarios/handlers/md.py +51 -0
  262. edsl/scenarios/handlers/pdf.py +68 -0
  263. edsl/scenarios/handlers/png.py +39 -0
  264. edsl/scenarios/handlers/pptx.py +105 -0
  265. edsl/scenarios/handlers/py.py +294 -0
  266. edsl/scenarios/handlers/sql.py +313 -0
  267. edsl/scenarios/handlers/sqlite.py +149 -0
  268. edsl/scenarios/handlers/txt.py +33 -0
  269. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +131 -127
  270. edsl/scenarios/scenario_selector.py +156 -0
  271. edsl/shared.py +1 -1
  272. edsl/study/ObjectEntry.py +173 -173
  273. edsl/study/ProofOfWork.py +113 -113
  274. edsl/study/SnapShot.py +80 -80
  275. edsl/study/Study.py +521 -528
  276. edsl/study/__init__.py +4 -4
  277. edsl/surveys/ConstructDAG.py +92 -0
  278. edsl/surveys/DAG.py +148 -148
  279. edsl/surveys/EditSurvey.py +221 -0
  280. edsl/surveys/InstructionHandler.py +100 -0
  281. edsl/surveys/Memory.py +31 -31
  282. edsl/surveys/MemoryManagement.py +72 -0
  283. edsl/surveys/MemoryPlan.py +244 -244
  284. edsl/surveys/Rule.py +327 -326
  285. edsl/surveys/RuleCollection.py +385 -387
  286. edsl/surveys/RuleManager.py +172 -0
  287. edsl/surveys/Simulator.py +75 -0
  288. edsl/surveys/Survey.py +1280 -1801
  289. edsl/surveys/SurveyCSS.py +273 -261
  290. edsl/surveys/SurveyExportMixin.py +259 -259
  291. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -179
  292. edsl/surveys/SurveyQualtricsImport.py +284 -284
  293. edsl/surveys/SurveyToApp.py +141 -0
  294. edsl/surveys/__init__.py +5 -3
  295. edsl/surveys/base.py +53 -53
  296. edsl/surveys/descriptors.py +60 -56
  297. edsl/surveys/instructions/ChangeInstruction.py +48 -49
  298. edsl/surveys/instructions/Instruction.py +56 -65
  299. edsl/surveys/instructions/InstructionCollection.py +82 -77
  300. edsl/templates/error_reporting/base.html +23 -23
  301. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  302. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  303. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  304. edsl/templates/error_reporting/interview_details.html +115 -115
  305. edsl/templates/error_reporting/interviews.html +19 -19
  306. edsl/templates/error_reporting/overview.html +4 -4
  307. edsl/templates/error_reporting/performance_plot.html +1 -1
  308. edsl/templates/error_reporting/report.css +73 -73
  309. edsl/templates/error_reporting/report.html +117 -117
  310. edsl/templates/error_reporting/report.js +25 -25
  311. edsl/test_h +1 -0
  312. edsl/tools/__init__.py +1 -1
  313. edsl/tools/clusters.py +192 -192
  314. edsl/tools/embeddings.py +27 -27
  315. edsl/tools/embeddings_plotting.py +118 -118
  316. edsl/tools/plotting.py +112 -112
  317. edsl/tools/summarize.py +18 -18
  318. edsl/utilities/PrettyList.py +56 -0
  319. edsl/utilities/SystemInfo.py +28 -28
  320. edsl/utilities/__init__.py +22 -22
  321. edsl/utilities/ast_utilities.py +25 -25
  322. edsl/utilities/data/Registry.py +6 -6
  323. edsl/utilities/data/__init__.py +1 -1
  324. edsl/utilities/data/scooter_results.json +1 -1
  325. edsl/utilities/decorators.py +77 -77
  326. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  327. edsl/utilities/gcp_bucket/example.py +50 -0
  328. edsl/utilities/interface.py +627 -627
  329. edsl/utilities/is_notebook.py +18 -0
  330. edsl/utilities/is_valid_variable_name.py +11 -0
  331. edsl/utilities/naming_utilities.py +263 -263
  332. edsl/utilities/remove_edsl_version.py +24 -0
  333. edsl/utilities/repair_functions.py +28 -28
  334. edsl/utilities/restricted_python.py +70 -70
  335. edsl/utilities/utilities.py +436 -424
  336. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +21 -21
  337. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +13 -11
  338. edsl-0.1.39.dev4.dist-info/RECORD +361 -0
  339. edsl/language_models/KeyLookup.py +0 -30
  340. edsl/language_models/registry.py +0 -190
  341. edsl/language_models/unused/ReplicateBase.py +0 -83
  342. edsl/results/ResultsDBMixin.py +0 -238
  343. edsl-0.1.39.dev3.dist-info/RECORD +0 -277
  344. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -1,224 +1,236 @@
1
- from __future__ import annotations
2
- from typing import Any, List, Optional
3
- import os
4
-
5
- import openai
6
-
7
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
8
- from edsl.language_models import LanguageModel
9
- from edsl.inference_services.rate_limits_cache import rate_limits
10
- from edsl.utilities.utilities import fix_partial_correct_response
11
-
12
- from edsl.config import CONFIG
13
-
14
-
15
class OpenAIService(InferenceServiceABC):
    """OpenAI service class.

    Builds ``LanguageModel`` subclasses that call the OpenAI chat completions
    API. One synchronous and one asynchronous SDK client are created lazily
    per service class, using the API key read from the environment variable
    named by ``_env_key_name_``.
    """

    _inference_service_ = "openai"
    _env_key_name_ = "OPENAI_API_KEY"
    _base_url_ = None

    _sync_client_ = openai.OpenAI
    _async_client_ = openai.AsyncOpenAI

    # Lazily populated by sync_client() / async_client().
    _sync_client_instance = None
    _async_client_instance = None

    # Copied onto each generated model class in create_model().
    key_sequence = ["choices", 0, "message", "content"]
    usage_sequence = ["usage"]
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Subclasses have to create their own instances of the clients.
        cls._sync_client_instance = None
        cls._async_client_instance = None
        # Same for the model-list cache: without this, a subclass would read
        # the parent's already-populated list through attribute inheritance
        # and report the parent's models. A cache declared directly in the
        # subclass body is left untouched.
        if "_models_list_cache" not in cls.__dict__:
            cls._models_list_cache = []

    @classmethod
    def sync_client(cls):
        """Return this class's synchronous client, creating it on first use."""
        if cls._sync_client_instance is None:
            cls._sync_client_instance = cls._sync_client_(
                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
            )
        return cls._sync_client_instance

    @classmethod
    def async_client(cls):
        """Return this class's asynchronous client, creating it on first use."""
        if cls._async_client_instance is None:
            cls._async_client_instance = cls._async_client_(
                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
            )
        return cls._async_client_instance

    model_exclude_list = [
        "whisper-1",
        "davinci-002",
        "dall-e-2",
        "tts-1-hd-1106",
        "tts-1-hd",
        "dall-e-3",
        "tts-1",
        "babbage-002",
        "tts-1-1106",
        "text-embedding-3-large",
        "text-embedding-3-small",
        "text-embedding-ada-002",
        "ft:davinci-002:mit-horton-lab::8OfuHgoo",
        "gpt-3.5-turbo-instruct-0914",
        "gpt-3.5-turbo-instruct",
    ]
    _models_list_cache: List[str] = []

    @classmethod
    def get_model_list(cls):
        """Return the model objects listed by the API.

        Returns the page's ``.data`` attribute when present, otherwise the
        raw value returned by ``models.list()``.
        """
        raw_list = cls.sync_client().models.list()
        if hasattr(raw_list, "data"):
            return raw_list.data
        return raw_list

    @classmethod
    def available(cls) -> List[str]:
        """Return ids of listed models that are not in ``model_exclude_list``.

        The list is cached on the class after the first non-empty result;
        API errors propagate to the caller.
        """
        if not cls._models_list_cache:
            cls._models_list_cache = [
                m.id
                for m in cls.get_model_list()
                if m.id not in cls.model_exclude_list
            ]
        return cls._models_list_cache

    @classmethod
    def create_model(cls, model_name, model_class_name=None) -> LanguageModel:
        """Build a ``LanguageModel`` subclass bound to *model_name*."""
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        # Inside LLM's own classmethods ``cls`` refers to LLM, not to the
        # service, so keep an explicit handle on the service class.
        service = cls

        class LLM(LanguageModel):
            """
            Child class of LanguageModel for interacting with OpenAI models
            """

            key_sequence = service.key_sequence
            usage_sequence = service.usage_sequence
            input_token_name = service.input_token_name
            output_token_name = service.output_token_name

            _rpm = service.get_rpm(service)
            _tpm = service.get_tpm(service)

            _inference_service_ = service._inference_service_
            _model_ = model_name
            _parameters_ = {
                "temperature": 0.5,
                "max_tokens": 1000,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "logprobs": False,
                "top_logprobs": 3,
            }

            def sync_client(self):
                return service.sync_client()

            def async_client(self):
                return service.async_client()

            @classmethod
            def available(cls) -> list[str]:
                # Delegate to the service. ``sync_client`` on this class is an
                # instance method, so it cannot be called from a classmethod.
                return service.available()

            def get_headers(self) -> dict[str, Any]:
                """Send a tiny probe request and return its HTTP response headers."""
                client = self.sync_client()
                response = client.chat.completions.with_raw_response.create(
                    messages=[
                        {
                            "role": "user",
                            "content": "Say this is a test",
                        }
                    ],
                    model=self.model,
                )
                return dict(response.headers)

            def get_rate_limits(self) -> dict[str, Any]:
                """Return ``{"rpm": ..., "tpm": ...}`` from rate-limit headers.

                Uses the cached ``rate_limits["openai"]`` headers when present,
                otherwise issues a probe request. Falls back to 10,000 rpm /
                2,000,000 tpm if the headers cannot be fetched or parsed.
                """
                try:
                    if "openai" in rate_limits:
                        headers = rate_limits["openai"]
                    else:
                        headers = self.get_headers()
                    # Parsing is inside the try so a missing/garbled header
                    # also falls back to the defaults instead of crashing.
                    return {
                        "rpm": int(headers["x-ratelimit-limit-requests"]),
                        "tpm": int(headers["x-ratelimit-limit-tokens"]),
                    }
                except Exception:
                    return {
                        "rpm": 10_000,
                        "tpm": 2_000_000,
                    }

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional[List["Files"]] = None,
                invigilator: Optional[
                    "InvigilatorAI"
                ] = None,  # TBD - can eventually be used for function-calling
            ) -> dict[str, Any]:
                """Calls the OpenAI API and returns the API response.

                Only the first entry of *files_list* is attached, as a JPEG
                data URL. Errors raised by the API call propagate unchanged.
                """
                if files_list:
                    encoded_image = files_list[0].base64_string
                    content = [
                        {"type": "text", "text": user_prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{encoded_image}"
                            },
                        },
                    ]
                else:
                    content = user_prompt
                client = self.async_client()

                is_o1 = "o1" in self.model
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": content},
                ]
                # o1 models get no system message.
                if (system_prompt == "" and self.omit_system_prompt_if_empty) or is_o1:
                    messages = messages[1:]

                params = {
                    "model": self.model,
                    "messages": messages,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                    "top_p": self.top_p,
                    "frequency_penalty": self.frequency_penalty,
                    "presence_penalty": self.presence_penalty,
                    "logprobs": self.logprobs,
                    "top_logprobs": self.top_logprobs if self.logprobs else None,
                }
                if is_o1:
                    params.pop("max_tokens")
                    params["max_completion_tokens"] = self.max_tokens
                    params["temperature"] = 1
                # Let API errors reach the caller. Catching and printing here
                # would leave ``response`` unbound and raise UnboundLocalError.
                response = await client.chat.completions.create(**params)
                return response.model_dump()

        LLM.__name__ = "LanguageModel"

        return LLM
1
+ from __future__ import annotations
2
+ from typing import Any, List, Optional, Dict, NewType
3
+ import os
4
+
5
+
6
+ import openai
7
+
8
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
9
+ from edsl.language_models.LanguageModel import LanguageModel
10
+ from edsl.inference_services.rate_limits_cache import rate_limits
11
+ from edsl.utilities.utilities import fix_partial_correct_response
12
+
13
+ from edsl.config import CONFIG
14
+
15
APIToken = NewType("APIToken", str)


class OpenAIService(InferenceServiceABC):
    """OpenAI service class.

    Builds ``LanguageModel`` subclasses that call the OpenAI chat completions
    API (or an OpenAI-compatible endpoint when a subclass sets
    ``_base_url_``). SDK clients are cached per API key, so models using
    different keys do not share a client.
    """

    _inference_service_ = "openai"
    _env_key_name_ = "OPENAI_API_KEY"
    _base_url_ = None

    _sync_client_ = openai.OpenAI
    _async_client_ = openai.AsyncOpenAI

    # One cached client per API key; see sync_client() / async_client().
    _sync_client_instances: Dict[APIToken, openai.OpenAI] = {}
    _async_client_instances: Dict[APIToken, openai.AsyncOpenAI] = {}

    # Copied onto each generated model class in create_model().
    key_sequence = ["choices", 0, "message", "content"]
    usage_sequence = ["usage"]
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"

    available_models_url = "https://platform.openai.com/docs/models"

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # So subclasses that use the OpenAI API shape have to create their
        # own instances of the clients.
        cls._sync_client_instances = {}
        cls._async_client_instances = {}
        # Same for the model-list cache: without this, a subclass would read
        # the parent's already-populated list through attribute inheritance
        # and report the parent's models. A cache declared directly in the
        # subclass body is left untouched.
        if "_models_list_cache" not in cls.__dict__:
            cls._models_list_cache = []

    @classmethod
    def sync_client(cls, api_key):
        """Return the synchronous client for *api_key*, creating it on first use."""
        client = cls._sync_client_instances.get(api_key)
        if client is None:
            client = cls._sync_client_(api_key=api_key, base_url=cls._base_url_)
            cls._sync_client_instances[api_key] = client
        return client

    @classmethod
    def async_client(cls, api_key):
        """Return the asynchronous client for *api_key*, creating it on first use."""
        client = cls._async_client_instances.get(api_key)
        if client is None:
            client = cls._async_client_(api_key=api_key, base_url=cls._base_url_)
            cls._async_client_instances[api_key] = client
        return client

    model_exclude_list = [
        "whisper-1",
        "davinci-002",
        "dall-e-2",
        "tts-1-hd-1106",
        "tts-1-hd",
        "dall-e-3",
        "tts-1",
        "babbage-002",
        "tts-1-1106",
        "text-embedding-3-large",
        "text-embedding-3-small",
        "text-embedding-ada-002",
        "ft:davinci-002:mit-horton-lab::8OfuHgoo",
        "gpt-3.5-turbo-instruct-0914",
        "gpt-3.5-turbo-instruct",
    ]
    _models_list_cache: List[str] = []

    @classmethod
    def get_model_list(cls, api_key=None):
        """Return the model objects listed by the API.

        *api_key* defaults to the environment variable named by
        ``_env_key_name_``. Returns the page's ``.data`` attribute when
        present, otherwise the raw value returned by ``models.list()``.
        """
        if api_key is None:
            api_key = os.getenv(cls._env_key_name_)
        raw_list = cls.sync_client(api_key).models.list()
        if hasattr(raw_list, "data"):
            return raw_list.data
        return raw_list

    @classmethod
    def available(cls, api_token=None) -> List[str]:
        """Return ids of listed models that are not in ``model_exclude_list``.

        *api_token* defaults to the environment variable named by
        ``_env_key_name_``. The list is cached on the class after the first
        non-empty result (the cache is not keyed by token); API errors
        propagate to the caller.
        """
        if api_token is None:
            api_token = os.getenv(cls._env_key_name_)
        if not cls._models_list_cache:
            cls._models_list_cache = [
                m.id
                for m in cls.get_model_list(api_key=api_token)
                if m.id not in cls.model_exclude_list
            ]
        return cls._models_list_cache

    @classmethod
    def create_model(cls, model_name, model_class_name=None) -> LanguageModel:
        """Build a ``LanguageModel`` subclass bound to *model_name*."""
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        # Inside LLM's own classmethods ``cls`` refers to LLM, not to the
        # service, so keep an explicit handle on the service class.
        service = cls

        class LLM(LanguageModel):
            """
            Child class of LanguageModel for interacting with OpenAI models
            """

            key_sequence = service.key_sequence
            usage_sequence = service.usage_sequence
            input_token_name = service.input_token_name
            output_token_name = service.output_token_name

            _inference_service_ = service._inference_service_
            _model_ = model_name
            _parameters_ = {
                "temperature": 0.5,
                "max_tokens": 1000,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "logprobs": False,
                "top_logprobs": 3,
            }

            def sync_client(self):
                return service.sync_client(api_key=self.api_token)

            def async_client(self):
                return service.async_client(api_key=self.api_token)

            @classmethod
            def available(cls) -> list[str]:
                # Delegate to the service. ``sync_client`` on this class is an
                # instance method that needs ``self.api_token``, so it cannot
                # be called from a classmethod.
                return service.available()

            def get_headers(self) -> dict[str, Any]:
                """Send a tiny probe request and return its HTTP response headers."""
                client = self.sync_client()
                response = client.chat.completions.with_raw_response.create(
                    messages=[
                        {
                            "role": "user",
                            "content": "Say this is a test",
                        }
                    ],
                    model=self.model,
                )
                return dict(response.headers)

            def get_rate_limits(self) -> dict[str, Any]:
                """Return ``{"rpm": ..., "tpm": ...}`` from rate-limit headers.

                Uses the cached ``rate_limits["openai"]`` headers when present,
                otherwise issues a probe request. Falls back to 10,000 rpm /
                2,000,000 tpm if the headers cannot be fetched or parsed.
                """
                try:
                    if "openai" in rate_limits:
                        headers = rate_limits["openai"]
                    else:
                        headers = self.get_headers()
                    # Parsing is inside the try so a missing/garbled header
                    # also falls back to the defaults instead of crashing.
                    return {
                        "rpm": int(headers["x-ratelimit-limit-requests"]),
                        "tpm": int(headers["x-ratelimit-limit-tokens"]),
                    }
                except Exception:
                    return {
                        "rpm": 10_000,
                        "tpm": 2_000_000,
                    }

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional[List["Files"]] = None,
                invigilator: Optional[
                    "InvigilatorAI"
                ] = None,  # TBD - can eventually be used for function-calling
            ) -> dict[str, Any]:
                """Calls the OpenAI API and returns the API response.

                Every entry of *files_list* is attached as an ``image_url``
                data URL built from its ``mime_type`` and ``base64_string``.
                Errors raised by the API call propagate unchanged.
                """
                if files_list:
                    content = [{"type": "text", "text": user_prompt}]
                    content.extend(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{file_entry.mime_type};base64,{file_entry.base64_string}"
                            },
                        }
                        for file_entry in files_list
                    )
                else:
                    content = user_prompt
                client = self.async_client()

                is_o1 = "o1" in self.model
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": content},
                ]
                # o1 models get no system message.
                if (system_prompt == "" and self.omit_system_prompt_if_empty) or is_o1:
                    messages = messages[1:]

                params = {
                    "model": self.model,
                    "messages": messages,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                    "top_p": self.top_p,
                    "frequency_penalty": self.frequency_penalty,
                    "presence_penalty": self.presence_penalty,
                    "logprobs": self.logprobs,
                    "top_logprobs": self.top_logprobs if self.logprobs else None,
                }
                if is_o1:
                    params.pop("max_tokens")
                    params["max_completion_tokens"] = self.max_tokens
                    params["temperature"] = 1
                # Let API errors reach the caller. Catching and printing here
                # would leave ``response`` unbound and raise UnboundLocalError.
                response = await client.chat.completions.create(**params)
                return response.model_dump()

        LLM.__name__ = "LanguageModel"

        return LLM