edsl 0.1.38.dev3__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (341) hide show
  1. edsl/Base.py +413 -303
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +57 -49
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +1071 -858
  7. edsl/agents/AgentList.py +551 -362
  8. edsl/agents/Invigilator.py +284 -222
  9. edsl/agents/InvigilatorBase.py +257 -284
  10. edsl/agents/PromptConstructor.py +272 -353
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/__init__.py +2 -3
  14. edsl/agents/descriptors.py +99 -99
  15. edsl/agents/prompt_helpers.py +129 -129
  16. edsl/agents/question_option_processor.py +172 -0
  17. edsl/auto/AutoStudy.py +130 -117
  18. edsl/auto/StageBase.py +243 -230
  19. edsl/auto/StageGenerateSurvey.py +178 -178
  20. edsl/auto/StageLabelQuestions.py +125 -125
  21. edsl/auto/StagePersona.py +61 -61
  22. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  23. edsl/auto/StagePersonaDimensionValues.py +74 -74
  24. edsl/auto/StagePersonaDimensions.py +69 -69
  25. edsl/auto/StageQuestions.py +74 -73
  26. edsl/auto/SurveyCreatorPipeline.py +21 -21
  27. edsl/auto/utilities.py +218 -224
  28. edsl/base/Base.py +279 -279
  29. edsl/config.py +177 -149
  30. edsl/conversation/Conversation.py +290 -290
  31. edsl/conversation/car_buying.py +59 -58
  32. edsl/conversation/chips.py +95 -95
  33. edsl/conversation/mug_negotiation.py +81 -81
  34. edsl/conversation/next_speaker_utilities.py +93 -93
  35. edsl/coop/CoopFunctionsMixin.py +15 -0
  36. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  37. edsl/coop/PriceFetcher.py +54 -54
  38. edsl/coop/__init__.py +2 -2
  39. edsl/coop/coop.py +1106 -961
  40. edsl/coop/utils.py +131 -131
  41. edsl/data/Cache.py +573 -530
  42. edsl/data/CacheEntry.py +230 -228
  43. edsl/data/CacheHandler.py +168 -149
  44. edsl/data/RemoteCacheSync.py +186 -97
  45. edsl/data/SQLiteDict.py +292 -292
  46. edsl/data/__init__.py +5 -4
  47. edsl/data/orm.py +10 -10
  48. edsl/data_transfer_models.py +74 -73
  49. edsl/enums.py +202 -173
  50. edsl/exceptions/BaseException.py +21 -21
  51. edsl/exceptions/__init__.py +54 -54
  52. edsl/exceptions/agents.py +54 -42
  53. edsl/exceptions/cache.py +5 -5
  54. edsl/exceptions/configuration.py +16 -16
  55. edsl/exceptions/coop.py +10 -10
  56. edsl/exceptions/data.py +14 -14
  57. edsl/exceptions/general.py +34 -34
  58. edsl/exceptions/inference_services.py +5 -0
  59. edsl/exceptions/jobs.py +33 -33
  60. edsl/exceptions/language_models.py +63 -63
  61. edsl/exceptions/prompts.py +15 -15
  62. edsl/exceptions/questions.py +109 -91
  63. edsl/exceptions/results.py +29 -29
  64. edsl/exceptions/scenarios.py +29 -22
  65. edsl/exceptions/surveys.py +37 -37
  66. edsl/inference_services/AnthropicService.py +106 -87
  67. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  68. edsl/inference_services/AvailableModelFetcher.py +215 -0
  69. edsl/inference_services/AwsBedrock.py +118 -120
  70. edsl/inference_services/AzureAI.py +215 -217
  71. edsl/inference_services/DeepInfraService.py +18 -18
  72. edsl/inference_services/GoogleService.py +143 -156
  73. edsl/inference_services/GroqService.py +20 -20
  74. edsl/inference_services/InferenceServiceABC.py +80 -147
  75. edsl/inference_services/InferenceServicesCollection.py +138 -97
  76. edsl/inference_services/MistralAIService.py +120 -123
  77. edsl/inference_services/OllamaService.py +18 -18
  78. edsl/inference_services/OpenAIService.py +236 -224
  79. edsl/inference_services/PerplexityService.py +160 -0
  80. edsl/inference_services/ServiceAvailability.py +135 -0
  81. edsl/inference_services/TestService.py +90 -89
  82. edsl/inference_services/TogetherAIService.py +172 -170
  83. edsl/inference_services/data_structures.py +134 -0
  84. edsl/inference_services/models_available_cache.py +118 -118
  85. edsl/inference_services/rate_limits_cache.py +25 -25
  86. edsl/inference_services/registry.py +41 -39
  87. edsl/inference_services/write_available.py +10 -10
  88. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  89. edsl/jobs/Answers.py +43 -56
  90. edsl/jobs/FetchInvigilator.py +47 -0
  91. edsl/jobs/InterviewTaskManager.py +98 -0
  92. edsl/jobs/InterviewsConstructor.py +50 -0
  93. edsl/jobs/Jobs.py +823 -1358
  94. edsl/jobs/JobsChecks.py +172 -0
  95. edsl/jobs/JobsComponentConstructor.py +189 -0
  96. edsl/jobs/JobsPrompts.py +270 -0
  97. edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
  98. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  99. edsl/jobs/RequestTokenEstimator.py +30 -0
  100. edsl/jobs/__init__.py +1 -1
  101. edsl/jobs/async_interview_runner.py +138 -0
  102. edsl/jobs/buckets/BucketCollection.py +104 -63
  103. edsl/jobs/buckets/ModelBuckets.py +65 -65
  104. edsl/jobs/buckets/TokenBucket.py +283 -251
  105. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  106. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  107. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  108. edsl/jobs/data_structures.py +120 -0
  109. edsl/jobs/decorators.py +35 -0
  110. edsl/jobs/interviews/Interview.py +396 -661
  111. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  112. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  113. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  114. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  115. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  116. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  117. edsl/jobs/interviews/ReportErrors.py +66 -66
  118. edsl/jobs/interviews/interview_status_enum.py +9 -9
  119. edsl/jobs/jobs_status_enums.py +9 -0
  120. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  121. edsl/jobs/results_exceptions_handler.py +98 -0
  122. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -361
  123. edsl/jobs/runners/JobsRunnerStatus.py +298 -332
  124. edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
  125. edsl/jobs/tasks/TaskCreators.py +64 -64
  126. edsl/jobs/tasks/TaskHistory.py +470 -451
  127. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  128. edsl/jobs/tasks/task_status_enum.py +161 -163
  129. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  130. edsl/jobs/tokens/TokenUsage.py +34 -34
  131. edsl/language_models/ComputeCost.py +63 -0
  132. edsl/language_models/LanguageModel.py +626 -708
  133. edsl/language_models/ModelList.py +164 -109
  134. edsl/language_models/PriceManager.py +127 -0
  135. edsl/language_models/RawResponseHandler.py +106 -0
  136. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  137. edsl/language_models/ServiceDataSources.py +0 -0
  138. edsl/language_models/__init__.py +2 -3
  139. edsl/language_models/fake_openai_call.py +15 -15
  140. edsl/language_models/fake_openai_service.py +61 -61
  141. edsl/language_models/key_management/KeyLookup.py +63 -0
  142. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  143. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  144. edsl/language_models/key_management/__init__.py +0 -0
  145. edsl/language_models/key_management/models.py +131 -0
  146. edsl/language_models/model.py +256 -0
  147. edsl/language_models/repair.py +156 -156
  148. edsl/language_models/utilities.py +65 -64
  149. edsl/notebooks/Notebook.py +263 -258
  150. edsl/notebooks/NotebookToLaTeX.py +142 -0
  151. edsl/notebooks/__init__.py +1 -1
  152. edsl/prompts/Prompt.py +352 -357
  153. edsl/prompts/__init__.py +2 -2
  154. edsl/questions/ExceptionExplainer.py +77 -0
  155. edsl/questions/HTMLQuestion.py +103 -0
  156. edsl/questions/QuestionBase.py +518 -660
  157. edsl/questions/QuestionBasePromptsMixin.py +221 -217
  158. edsl/questions/QuestionBudget.py +227 -227
  159. edsl/questions/QuestionCheckBox.py +359 -359
  160. edsl/questions/QuestionExtract.py +180 -183
  161. edsl/questions/QuestionFreeText.py +113 -114
  162. edsl/questions/QuestionFunctional.py +166 -166
  163. edsl/questions/QuestionList.py +223 -231
  164. edsl/questions/QuestionMatrix.py +265 -0
  165. edsl/questions/QuestionMultipleChoice.py +330 -286
  166. edsl/questions/QuestionNumerical.py +151 -153
  167. edsl/questions/QuestionRank.py +314 -324
  168. edsl/questions/Quick.py +41 -41
  169. edsl/questions/SimpleAskMixin.py +74 -73
  170. edsl/questions/__init__.py +27 -26
  171. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
  172. edsl/questions/compose_questions.py +98 -98
  173. edsl/questions/data_structures.py +20 -0
  174. edsl/questions/decorators.py +21 -21
  175. edsl/questions/derived/QuestionLikertFive.py +76 -76
  176. edsl/questions/derived/QuestionLinearScale.py +90 -87
  177. edsl/questions/derived/QuestionTopK.py +93 -93
  178. edsl/questions/derived/QuestionYesNo.py +82 -82
  179. edsl/questions/descriptors.py +427 -413
  180. edsl/questions/loop_processor.py +149 -0
  181. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  182. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  183. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  184. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  185. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  186. edsl/questions/prompt_templates/question_list.jinja +17 -17
  187. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  188. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  189. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
  190. edsl/questions/question_registry.py +177 -147
  191. edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
  192. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
  193. edsl/questions/response_validator_factory.py +34 -0
  194. edsl/questions/settings.py +12 -12
  195. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  196. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  197. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  198. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  199. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  200. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  201. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  202. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  203. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  204. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  205. edsl/questions/templates/list/question_presentation.jinja +5 -5
  206. edsl/questions/templates/matrix/__init__.py +1 -0
  207. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  208. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  209. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  210. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  211. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  212. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  213. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  214. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  215. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  216. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  217. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  218. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  219. edsl/results/CSSParameterizer.py +108 -0
  220. edsl/results/Dataset.py +587 -293
  221. edsl/results/DatasetExportMixin.py +594 -717
  222. edsl/results/DatasetTree.py +295 -145
  223. edsl/results/MarkdownToDocx.py +122 -0
  224. edsl/results/MarkdownToPDF.py +111 -0
  225. edsl/results/Result.py +557 -456
  226. edsl/results/Results.py +1183 -1071
  227. edsl/results/ResultsExportMixin.py +45 -43
  228. edsl/results/ResultsGGMixin.py +121 -121
  229. edsl/results/TableDisplay.py +125 -0
  230. edsl/results/TextEditor.py +50 -0
  231. edsl/results/__init__.py +2 -2
  232. edsl/results/file_exports.py +252 -0
  233. edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
  234. edsl/results/{Selector.py → results_selector.py} +145 -135
  235. edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
  236. edsl/results/smart_objects.py +96 -0
  237. edsl/results/table_data_class.py +12 -0
  238. edsl/results/table_display.css +78 -0
  239. edsl/results/table_renderers.py +118 -0
  240. edsl/results/tree_explore.py +115 -115
  241. edsl/scenarios/ConstructDownloadLink.py +109 -0
  242. edsl/scenarios/DocumentChunker.py +102 -0
  243. edsl/scenarios/DocxScenario.py +16 -0
  244. edsl/scenarios/FileStore.py +543 -458
  245. edsl/scenarios/PdfExtractor.py +40 -0
  246. edsl/scenarios/Scenario.py +498 -544
  247. edsl/scenarios/ScenarioHtmlMixin.py +65 -64
  248. edsl/scenarios/ScenarioList.py +1458 -1112
  249. edsl/scenarios/ScenarioListExportMixin.py +45 -52
  250. edsl/scenarios/ScenarioListPdfMixin.py +239 -261
  251. edsl/scenarios/__init__.py +3 -4
  252. edsl/scenarios/directory_scanner.py +96 -0
  253. edsl/scenarios/file_methods.py +85 -0
  254. edsl/scenarios/handlers/__init__.py +13 -0
  255. edsl/scenarios/handlers/csv.py +49 -0
  256. edsl/scenarios/handlers/docx.py +76 -0
  257. edsl/scenarios/handlers/html.py +37 -0
  258. edsl/scenarios/handlers/json.py +111 -0
  259. edsl/scenarios/handlers/latex.py +5 -0
  260. edsl/scenarios/handlers/md.py +51 -0
  261. edsl/scenarios/handlers/pdf.py +68 -0
  262. edsl/scenarios/handlers/png.py +39 -0
  263. edsl/scenarios/handlers/pptx.py +105 -0
  264. edsl/scenarios/handlers/py.py +294 -0
  265. edsl/scenarios/handlers/sql.py +313 -0
  266. edsl/scenarios/handlers/sqlite.py +149 -0
  267. edsl/scenarios/handlers/txt.py +33 -0
  268. edsl/scenarios/scenario_join.py +131 -0
  269. edsl/scenarios/scenario_selector.py +156 -0
  270. edsl/shared.py +1 -1
  271. edsl/study/ObjectEntry.py +173 -173
  272. edsl/study/ProofOfWork.py +113 -113
  273. edsl/study/SnapShot.py +80 -80
  274. edsl/study/Study.py +521 -528
  275. edsl/study/__init__.py +4 -4
  276. edsl/surveys/ConstructDAG.py +92 -0
  277. edsl/surveys/DAG.py +148 -148
  278. edsl/surveys/EditSurvey.py +221 -0
  279. edsl/surveys/InstructionHandler.py +100 -0
  280. edsl/surveys/Memory.py +31 -31
  281. edsl/surveys/MemoryManagement.py +72 -0
  282. edsl/surveys/MemoryPlan.py +244 -244
  283. edsl/surveys/Rule.py +327 -326
  284. edsl/surveys/RuleCollection.py +385 -387
  285. edsl/surveys/RuleManager.py +172 -0
  286. edsl/surveys/Simulator.py +75 -0
  287. edsl/surveys/Survey.py +1280 -1787
  288. edsl/surveys/SurveyCSS.py +273 -261
  289. edsl/surveys/SurveyExportMixin.py +259 -259
  290. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -121
  291. edsl/surveys/SurveyQualtricsImport.py +284 -284
  292. edsl/surveys/SurveyToApp.py +141 -0
  293. edsl/surveys/__init__.py +5 -3
  294. edsl/surveys/base.py +53 -53
  295. edsl/surveys/descriptors.py +60 -56
  296. edsl/surveys/instructions/ChangeInstruction.py +48 -49
  297. edsl/surveys/instructions/Instruction.py +56 -53
  298. edsl/surveys/instructions/InstructionCollection.py +82 -77
  299. edsl/templates/error_reporting/base.html +23 -23
  300. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  301. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  302. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  303. edsl/templates/error_reporting/interview_details.html +115 -115
  304. edsl/templates/error_reporting/interviews.html +19 -10
  305. edsl/templates/error_reporting/overview.html +4 -4
  306. edsl/templates/error_reporting/performance_plot.html +1 -1
  307. edsl/templates/error_reporting/report.css +73 -73
  308. edsl/templates/error_reporting/report.html +117 -117
  309. edsl/templates/error_reporting/report.js +25 -25
  310. edsl/tools/__init__.py +1 -1
  311. edsl/tools/clusters.py +192 -192
  312. edsl/tools/embeddings.py +27 -27
  313. edsl/tools/embeddings_plotting.py +118 -118
  314. edsl/tools/plotting.py +112 -112
  315. edsl/tools/summarize.py +18 -18
  316. edsl/utilities/PrettyList.py +56 -0
  317. edsl/utilities/SystemInfo.py +28 -28
  318. edsl/utilities/__init__.py +22 -22
  319. edsl/utilities/ast_utilities.py +25 -25
  320. edsl/utilities/data/Registry.py +6 -6
  321. edsl/utilities/data/__init__.py +1 -1
  322. edsl/utilities/data/scooter_results.json +1 -1
  323. edsl/utilities/decorators.py +77 -77
  324. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  325. edsl/utilities/interface.py +627 -627
  326. edsl/utilities/is_notebook.py +18 -0
  327. edsl/utilities/is_valid_variable_name.py +11 -0
  328. edsl/utilities/naming_utilities.py +263 -263
  329. edsl/utilities/remove_edsl_version.py +24 -0
  330. edsl/utilities/repair_functions.py +28 -28
  331. edsl/utilities/restricted_python.py +70 -70
  332. edsl/utilities/utilities.py +436 -409
  333. {edsl-0.1.38.dev3.dist-info → edsl-0.1.39.dist-info}/LICENSE +21 -21
  334. {edsl-0.1.38.dev3.dist-info → edsl-0.1.39.dist-info}/METADATA +13 -10
  335. edsl-0.1.39.dist-info/RECORD +358 -0
  336. {edsl-0.1.38.dev3.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  337. edsl/language_models/KeyLookup.py +0 -30
  338. edsl/language_models/registry.py +0 -137
  339. edsl/language_models/unused/ReplicateBase.py +0 -83
  340. edsl/results/ResultsDBMixin.py +0 -238
  341. edsl-0.1.38.dev3.dist-info/RECORD +0 -269
@@ -1,156 +1,143 @@
1
- import os
2
- from typing import Any, Dict, List, Optional
3
- import google
4
- import google.generativeai as genai
5
- from google.generativeai.types import GenerationConfig
6
- from google.api_core.exceptions import InvalidArgument
7
-
8
- from edsl.exceptions import MissingAPIKeyError
9
- from edsl.language_models.LanguageModel import LanguageModel
10
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
11
-
12
- safety_settings = [
13
- {
14
- "category": "HARM_CATEGORY_HARASSMENT",
15
- "threshold": "BLOCK_NONE",
16
- },
17
- {
18
- "category": "HARM_CATEGORY_HATE_SPEECH",
19
- "threshold": "BLOCK_NONE",
20
- },
21
- {
22
- "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
23
- "threshold": "BLOCK_NONE",
24
- },
25
- {
26
- "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
27
- "threshold": "BLOCK_NONE",
28
- },
29
- ]
30
-
31
-
32
- class GoogleService(InferenceServiceABC):
33
- _inference_service_ = "google"
34
- key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
35
- usage_sequence = ["usage_metadata"]
36
- input_token_name = "prompt_token_count"
37
- output_token_name = "candidates_token_count"
38
-
39
- model_exclude_list = []
40
-
41
- # @classmethod
42
- # def available(cls) -> List[str]:
43
- # return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
44
-
45
- @classmethod
46
- def available(cls) -> List[str]:
47
- model_list = []
48
- for m in genai.list_models():
49
- if "generateContent" in m.supported_generation_methods:
50
- model_list.append(m.name.split("/")[-1])
51
- return model_list
52
-
53
- @classmethod
54
- def create_model(
55
- cls, model_name: str = "gemini-pro", model_class_name=None
56
- ) -> LanguageModel:
57
- if model_class_name is None:
58
- model_class_name = cls.to_class_name(model_name)
59
-
60
- class LLM(LanguageModel):
61
- _model_ = model_name
62
- key_sequence = cls.key_sequence
63
- usage_sequence = cls.usage_sequence
64
- input_token_name = cls.input_token_name
65
- output_token_name = cls.output_token_name
66
- _inference_service_ = cls._inference_service_
67
-
68
- _tpm = cls.get_tpm(cls)
69
- _rpm = cls.get_rpm(cls)
70
-
71
- _parameters_ = {
72
- "temperature": 0.5,
73
- "topP": 1,
74
- "topK": 1,
75
- "maxOutputTokens": 2048,
76
- "stopSequences": [],
77
- }
78
-
79
- api_token = None
80
- model = None
81
-
82
- @classmethod
83
- def initialize(cls):
84
- if cls.api_token is None:
85
- cls.api_token = os.getenv("GOOGLE_API_KEY")
86
- if not cls.api_token:
87
- raise MissingAPIKeyError(
88
- "GOOGLE_API_KEY environment variable is not set"
89
- )
90
- genai.configure(api_key=cls.api_token)
91
- cls.generative_model = genai.GenerativeModel(
92
- cls._model_, safety_settings=safety_settings
93
- )
94
-
95
- def __init__(self, *args, **kwargs):
96
- super().__init__(*args, **kwargs)
97
- self.initialize()
98
-
99
- def get_generation_config(self) -> GenerationConfig:
100
- return GenerationConfig(
101
- temperature=self.temperature,
102
- top_p=self.topP,
103
- top_k=self.topK,
104
- max_output_tokens=self.maxOutputTokens,
105
- stop_sequences=self.stopSequences,
106
- )
107
-
108
- async def async_execute_model_call(
109
- self,
110
- user_prompt: str,
111
- system_prompt: str = "",
112
- files_list: Optional["Files"] = None,
113
- ) -> Dict[str, Any]:
114
- generation_config = self.get_generation_config()
115
-
116
- if files_list is None:
117
- files_list = []
118
-
119
- if (
120
- system_prompt is not None
121
- and system_prompt != ""
122
- and self._model_ != "gemini-pro"
123
- ):
124
- try:
125
- self.generative_model = genai.GenerativeModel(
126
- self._model_,
127
- safety_settings=safety_settings,
128
- system_instruction=system_prompt,
129
- )
130
- except InvalidArgument as e:
131
- print(
132
- f"This model, {self._model_}, does not support system_instruction"
133
- )
134
- print("Will add system_prompt to user_prompt")
135
- user_prompt = f"{system_prompt}\n{user_prompt}"
136
-
137
- combined_prompt = [user_prompt]
138
- for file in files_list:
139
- if "google" not in file.external_locations:
140
- _ = file.upload_google()
141
- gen_ai_file = google.generativeai.types.file_types.File(
142
- file.external_locations["google"]
143
- )
144
- combined_prompt.append(gen_ai_file)
145
-
146
- response = await self.generative_model.generate_content_async(
147
- combined_prompt, generation_config=generation_config
148
- )
149
- return response.to_dict()
150
-
151
- LLM.__name__ = model_name
152
- return LLM
153
-
154
-
155
- if __name__ == "__main__":
156
- pass
1
+ # import os
2
+ from typing import Any, Dict, List, Optional
3
+ import google
4
+ import google.generativeai as genai
5
+ from google.generativeai.types import GenerationConfig
6
+ from google.api_core.exceptions import InvalidArgument
7
+
8
+ # from edsl.exceptions.general import MissingAPIKeyError
9
+ from edsl.language_models.LanguageModel import LanguageModel
10
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
11
+ from edsl.coop import Coop
12
+
13
+ safety_settings = [
14
+ {
15
+ "category": "HARM_CATEGORY_HARASSMENT",
16
+ "threshold": "BLOCK_NONE",
17
+ },
18
+ {
19
+ "category": "HARM_CATEGORY_HATE_SPEECH",
20
+ "threshold": "BLOCK_NONE",
21
+ },
22
+ {
23
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
24
+ "threshold": "BLOCK_NONE",
25
+ },
26
+ {
27
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
28
+ "threshold": "BLOCK_NONE",
29
+ },
30
+ ]
31
+
32
+
33
+ class GoogleService(InferenceServiceABC):
34
+ _inference_service_ = "google"
35
+ key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
36
+ usage_sequence = ["usage_metadata"]
37
+ input_token_name = "prompt_token_count"
38
+ output_token_name = "candidates_token_count"
39
+
40
+ model_exclude_list = []
41
+
42
+ @classmethod
43
+ def get_model_list(cls):
44
+ model_list = []
45
+ for m in genai.list_models():
46
+ if "generateContent" in m.supported_generation_methods:
47
+ model_list.append(m.name.split("/")[-1])
48
+ return model_list
49
+
50
+ @classmethod
51
+ def available(cls) -> List[str]:
52
+ return cls.get_model_list()
53
+
54
+ @classmethod
55
+ def create_model(
56
+ cls, model_name: str = "gemini-pro", model_class_name=None
57
+ ) -> LanguageModel:
58
+ if model_class_name is None:
59
+ model_class_name = cls.to_class_name(model_name)
60
+
61
+ class LLM(LanguageModel):
62
+ _model_ = model_name
63
+ key_sequence = cls.key_sequence
64
+ usage_sequence = cls.usage_sequence
65
+ input_token_name = cls.input_token_name
66
+ output_token_name = cls.output_token_name
67
+ _inference_service_ = cls._inference_service_
68
+
69
+ _parameters_ = {
70
+ "temperature": 0.5,
71
+ "topP": 1,
72
+ "topK": 1,
73
+ "maxOutputTokens": 2048,
74
+ "stopSequences": [],
75
+ }
76
+
77
+ model = None
78
+
79
+ def __init__(self, *args, **kwargs):
80
+ super().__init__(*args, **kwargs)
81
+
82
+ def get_generation_config(self) -> GenerationConfig:
83
+ return GenerationConfig(
84
+ temperature=self.temperature,
85
+ top_p=self.topP,
86
+ top_k=self.topK,
87
+ max_output_tokens=self.maxOutputTokens,
88
+ stop_sequences=self.stopSequences,
89
+ )
90
+
91
+ async def async_execute_model_call(
92
+ self,
93
+ user_prompt: str,
94
+ system_prompt: str = "",
95
+ files_list: Optional["Files"] = None,
96
+ ) -> Dict[str, Any]:
97
+ generation_config = self.get_generation_config()
98
+
99
+ if files_list is None:
100
+ files_list = []
101
+ genai.configure(api_key=self.api_token)
102
+ if (
103
+ system_prompt is not None
104
+ and system_prompt != ""
105
+ and self._model_ != "gemini-pro"
106
+ ):
107
+ try:
108
+ self.generative_model = genai.GenerativeModel(
109
+ self._model_,
110
+ safety_settings=safety_settings,
111
+ system_instruction=system_prompt,
112
+ )
113
+ except InvalidArgument as e:
114
+ print(
115
+ f"This model, {self._model_}, does not support system_instruction"
116
+ )
117
+ print("Will add system_prompt to user_prompt")
118
+ user_prompt = f"{system_prompt}\n{user_prompt}"
119
+ else:
120
+ self.generative_model = genai.GenerativeModel(
121
+ self._model_,
122
+ safety_settings=safety_settings,
123
+ )
124
+ combined_prompt = [user_prompt]
125
+ for file in files_list:
126
+ if "google" not in file.external_locations:
127
+ _ = file.upload_google()
128
+ gen_ai_file = google.generativeai.types.file_types.File(
129
+ file.external_locations["google"]
130
+ )
131
+ combined_prompt.append(gen_ai_file)
132
+
133
+ response = await self.generative_model.generate_content_async(
134
+ combined_prompt, generation_config=generation_config
135
+ )
136
+ return response.to_dict()
137
+
138
+ LLM.__name__ = model_name
139
+ return LLM
140
+
141
+
142
+ if __name__ == "__main__":
143
+ pass
@@ -1,20 +1,20 @@
1
- from typing import Any, List
2
- from edsl.inference_services.OpenAIService import OpenAIService
3
-
4
- import groq
5
-
6
-
7
- class GroqService(OpenAIService):
8
- """DeepInfra service class."""
9
-
10
- _inference_service_ = "groq"
11
- _env_key_name_ = "GROQ_API_KEY"
12
-
13
- _sync_client_ = groq.Groq
14
- _async_client_ = groq.AsyncGroq
15
-
16
- model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]
17
-
18
- # _base_url_ = "https://api.deepinfra.com/v1/openai"
19
- _base_url_ = None
20
- _models_list_cache: List[str] = []
1
+ from typing import Any, List
2
+ from edsl.inference_services.OpenAIService import OpenAIService
3
+
4
+ import groq
5
+
6
+
7
+ class GroqService(OpenAIService):
8
+ """DeepInfra service class."""
9
+
10
+ _inference_service_ = "groq"
11
+ _env_key_name_ = "GROQ_API_KEY"
12
+
13
+ _sync_client_ = groq.Groq
14
+ _async_client_ = groq.AsyncGroq
15
+
16
+ model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]
17
+
18
+ # _base_url_ = "https://api.deepinfra.com/v1/openai"
19
+ _base_url_ = None
20
+ _models_list_cache: List[str] = []
@@ -1,147 +1,80 @@
1
- from abc import abstractmethod, ABC
2
- import os
3
- import re
4
- from datetime import datetime, timedelta
5
- from edsl.config import CONFIG
6
-
7
-
8
- class InferenceServiceABC(ABC):
9
- """
10
- Abstract class for inference services.
11
- Anthropic: https://docs.anthropic.com/en/api/rate-limits
12
- """
13
-
14
- _coop_config_vars = None
15
-
16
- default_levels = {
17
- "google": {"tpm": 2_000_000, "rpm": 15},
18
- "openai": {"tpm": 2_000_000, "rpm": 10_000},
19
- "anthropic": {"tpm": 2_000_000, "rpm": 500},
20
- }
21
-
22
- def __init_subclass__(cls):
23
- """
24
- Check that the subclass has the required attributes.
25
- - `key_sequence` attribute determines...
26
- - `model_exclude_list` attribute determines...
27
- """
28
- if not hasattr(cls, "key_sequence"):
29
- raise NotImplementedError(
30
- f"Class {cls.__name__} must have a 'key_sequence' attribute."
31
- )
32
- if not hasattr(cls, "model_exclude_list"):
33
- raise NotImplementedError(
34
- f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
35
- )
36
-
37
- @classmethod
38
- def _should_refresh_coop_config_vars(cls):
39
- """
40
- Returns True if config vars have been fetched over 24 hours ago, and False otherwise.
41
- """
42
-
43
- if cls._last_config_fetch is None:
44
- return True
45
- return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)
46
-
47
- @classmethod
48
- def _get_limt(cls, limit_type: str) -> int:
49
- key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
50
- if key in os.environ:
51
- return int(os.getenv(key))
52
-
53
- if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
54
- try:
55
- from edsl import Coop
56
-
57
- c = Coop()
58
- cls._coop_config_vars = c.fetch_rate_limit_config_vars()
59
- cls._last_config_fetch = datetime.now()
60
- if key in cls._coop_config_vars:
61
- return cls._coop_config_vars[key]
62
- except Exception:
63
- cls._coop_config_vars = None
64
- else:
65
- if key in cls._coop_config_vars:
66
- return cls._coop_config_vars[key]
67
-
68
- if cls._inference_service_ in cls.default_levels:
69
- return int(cls.default_levels[cls._inference_service_][limit_type])
70
-
71
- return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
72
-
73
- def get_tpm(cls) -> int:
74
- """
75
- Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
76
- """
77
- return cls._get_limt(limit_type="tpm")
78
-
79
- def get_rpm(cls):
80
- """
81
- Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
82
- """
83
- return cls._get_limt(limit_type="rpm")
84
-
85
- @abstractmethod
86
- def available() -> list[str]:
87
- """
88
- Returns a list of available models for the service.
89
- """
90
- pass
91
-
92
- @abstractmethod
93
- def create_model():
94
- """
95
- Returns a LanguageModel object.
96
- """
97
- pass
98
-
99
- @staticmethod
100
- def to_class_name(s):
101
- """
102
- Converts a string to a valid class name.
103
-
104
- >>> InferenceServiceABC.to_class_name("hello world")
105
- 'HelloWorld'
106
- """
107
-
108
- s = re.sub(r"[^a-zA-Z0-9 ]", "", s)
109
- s = "".join(word.title() for word in s.split())
110
- if s and s[0].isdigit():
111
- s = "Class" + s
112
- return s
113
-
114
-
115
- if __name__ == "__main__":
116
- pass
117
- # deep_infra_service = DeepInfraService("deep_infra", "DEEP_INFRA_API_KEY")
118
- # deep_infra_service.available()
119
- # m = deep_infra_service.create_model("microsoft/WizardLM-2-7B")
120
- # response = m().hello()
121
- # print(response)
122
-
123
- # anthropic_service = AnthropicService("anthropic", "ANTHROPIC_API_KEY")
124
- # anthropic_service.available()
125
- # m = anthropic_service.create_model("claude-3-opus-20240229")
126
- # response = m().hello()
127
- # print(response)
128
- # factory = OpenAIService("openai", "OPENAI_API")
129
- # factory.available()
130
- # m = factory.create_model("gpt-3.5-turbo")
131
- # response = m().hello()
132
-
133
- # from edsl import QuestionFreeText
134
- # results = QuestionFreeText.example().by(m()).run()
135
-
136
- # collection = InferenceServicesCollection([
137
- # OpenAIService,
138
- # AnthropicService,
139
- # DeepInfraService
140
- # ])
141
-
142
- # available = collection.available()
143
- # factory = collection.create_model_factory(*available[0])
144
- # m = factory()
145
- # from edsl import QuestionFreeText
146
- # results = QuestionFreeText.example().by(m).run()
147
- # print(results)
1
+ from abc import abstractmethod, ABC
2
+ import re
3
+ from datetime import datetime, timedelta
4
+ from edsl.config import CONFIG
5
+
6
+
7
+ class InferenceServiceABC(ABC):
8
+ """
9
+ Abstract class for inference services.
10
+ """
11
+
12
+ _coop_config_vars = None
13
+
14
+ def __init_subclass__(cls):
15
+ """
16
+ Check that the subclass has the required attributes.
17
+ - `key_sequence` attribute determines...
18
+ - `model_exclude_list` attribute determines...
19
+ """
20
+ must_have_attributes = [
21
+ "key_sequence",
22
+ "model_exclude_list",
23
+ "usage_sequence",
24
+ "input_token_name",
25
+ "output_token_name",
26
+ ]
27
+ for attr in must_have_attributes:
28
+ if not hasattr(cls, attr):
29
+ raise NotImplementedError(
30
+ f"Class {cls.__name__} must have a '{attr}' attribute."
31
+ )
32
+
33
+ @property
34
+ def service_name(self):
35
+ return self._inference_service_
36
+
37
+ @classmethod
38
+ def _should_refresh_coop_config_vars(cls):
39
+ """
40
+ Returns True if config vars have been fetched over 24 hours ago, and False otherwise.
41
+ """
42
+
43
+ if cls._last_config_fetch is None:
44
+ return True
45
+ return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)
46
+
47
+ @abstractmethod
48
+ def available() -> list[str]:
49
+ """
50
+ Returns a list of available models for the service.
51
+ """
52
+ pass
53
+
54
+ @abstractmethod
55
+ def create_model():
56
+ """
57
+ Returns a LanguageModel object.
58
+ """
59
+ pass
60
+
61
+ @staticmethod
62
+ def to_class_name(s):
63
+ """
64
+ Converts a string to a valid class name.
65
+
66
+ >>> InferenceServiceABC.to_class_name("hello world")
67
+ 'HelloWorld'
68
+ """
69
+
70
+ s = re.sub(r"[^a-zA-Z0-9 ]", "", s)
71
+ s = "".join(word.title() for word in s.split())
72
+ if s and s[0].isdigit():
73
+ s = "Class" + s
74
+ return s
75
+
76
+
77
+ if __name__ == "__main__":
78
+ import doctest
79
+
80
+ doctest.testmod()