edsl 0.1.36.dev2__py3-none-any.whl → 0.1.36.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. edsl/Base.py +303 -298
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +47 -47
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +804 -800
  7. edsl/agents/AgentList.py +337 -337
  8. edsl/agents/Invigilator.py +222 -222
  9. edsl/agents/InvigilatorBase.py +294 -294
  10. edsl/agents/PromptConstructor.py +312 -311
  11. edsl/agents/__init__.py +3 -3
  12. edsl/agents/descriptors.py +86 -86
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +289 -289
  26. edsl/config.py +149 -149
  27. edsl/conjure/AgentConstructionMixin.py +152 -152
  28. edsl/conjure/Conjure.py +62 -62
  29. edsl/conjure/InputData.py +659 -659
  30. edsl/conjure/InputDataCSV.py +48 -48
  31. edsl/conjure/InputDataMixinQuestionStats.py +182 -182
  32. edsl/conjure/InputDataPyRead.py +91 -91
  33. edsl/conjure/InputDataSPSS.py +8 -8
  34. edsl/conjure/InputDataStata.py +8 -8
  35. edsl/conjure/QuestionOptionMixin.py +76 -76
  36. edsl/conjure/QuestionTypeMixin.py +23 -23
  37. edsl/conjure/RawQuestion.py +65 -65
  38. edsl/conjure/SurveyResponses.py +7 -7
  39. edsl/conjure/__init__.py +9 -9
  40. edsl/conjure/naming_utilities.py +263 -263
  41. edsl/conjure/utilities.py +201 -201
  42. edsl/conversation/Conversation.py +238 -238
  43. edsl/conversation/car_buying.py +58 -58
  44. edsl/conversation/mug_negotiation.py +81 -81
  45. edsl/conversation/next_speaker_utilities.py +93 -93
  46. edsl/coop/PriceFetcher.py +54 -58
  47. edsl/coop/__init__.py +2 -2
  48. edsl/coop/coop.py +849 -815
  49. edsl/coop/utils.py +131 -131
  50. edsl/data/Cache.py +527 -527
  51. edsl/data/CacheEntry.py +228 -228
  52. edsl/data/CacheHandler.py +149 -149
  53. edsl/data/RemoteCacheSync.py +84 -0
  54. edsl/data/SQLiteDict.py +292 -292
  55. edsl/data/__init__.py +4 -4
  56. edsl/data/orm.py +10 -10
  57. edsl/data_transfer_models.py +73 -73
  58. edsl/enums.py +173 -173
  59. edsl/exceptions/__init__.py +50 -50
  60. edsl/exceptions/agents.py +40 -40
  61. edsl/exceptions/configuration.py +16 -16
  62. edsl/exceptions/coop.py +10 -2
  63. edsl/exceptions/data.py +14 -14
  64. edsl/exceptions/general.py +34 -34
  65. edsl/exceptions/jobs.py +33 -33
  66. edsl/exceptions/language_models.py +63 -63
  67. edsl/exceptions/prompts.py +15 -15
  68. edsl/exceptions/questions.py +91 -91
  69. edsl/exceptions/results.py +26 -26
  70. edsl/exceptions/surveys.py +34 -34
  71. edsl/inference_services/AnthropicService.py +87 -87
  72. edsl/inference_services/AwsBedrock.py +115 -115
  73. edsl/inference_services/AzureAI.py +217 -217
  74. edsl/inference_services/DeepInfraService.py +18 -18
  75. edsl/inference_services/GoogleService.py +156 -156
  76. edsl/inference_services/GroqService.py +20 -20
  77. edsl/inference_services/InferenceServiceABC.py +147 -119
  78. edsl/inference_services/InferenceServicesCollection.py +72 -68
  79. edsl/inference_services/MistralAIService.py +123 -123
  80. edsl/inference_services/OllamaService.py +18 -18
  81. edsl/inference_services/OpenAIService.py +224 -224
  82. edsl/inference_services/TestService.py +89 -89
  83. edsl/inference_services/TogetherAIService.py +170 -170
  84. edsl/inference_services/models_available_cache.py +118 -94
  85. edsl/inference_services/rate_limits_cache.py +25 -25
  86. edsl/inference_services/registry.py +39 -39
  87. edsl/inference_services/write_available.py +10 -10
  88. edsl/jobs/Answers.py +56 -56
  89. edsl/jobs/Jobs.py +1112 -1089
  90. edsl/jobs/__init__.py +1 -1
  91. edsl/jobs/buckets/BucketCollection.py +63 -63
  92. edsl/jobs/buckets/ModelBuckets.py +65 -65
  93. edsl/jobs/buckets/TokenBucket.py +248 -248
  94. edsl/jobs/interviews/Interview.py +651 -633
  95. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -90
  96. edsl/jobs/interviews/InterviewExceptionEntry.py +182 -164
  97. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  98. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  99. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  100. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  101. edsl/jobs/interviews/ReportErrors.py +66 -66
  102. edsl/jobs/interviews/interview_status_enum.py +9 -9
  103. edsl/jobs/runners/JobsRunnerAsyncio.py +337 -343
  104. edsl/jobs/runners/JobsRunnerStatus.py +332 -332
  105. edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
  106. edsl/jobs/tasks/TaskCreators.py +64 -64
  107. edsl/jobs/tasks/TaskHistory.py +441 -425
  108. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  109. edsl/jobs/tasks/task_status_enum.py +163 -163
  110. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  111. edsl/jobs/tokens/TokenUsage.py +34 -34
  112. edsl/language_models/LanguageModel.py +718 -718
  113. edsl/language_models/ModelList.py +102 -102
  114. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  115. edsl/language_models/__init__.py +2 -2
  116. edsl/language_models/fake_openai_call.py +15 -15
  117. edsl/language_models/fake_openai_service.py +61 -61
  118. edsl/language_models/registry.py +137 -137
  119. edsl/language_models/repair.py +156 -156
  120. edsl/language_models/unused/ReplicateBase.py +83 -83
  121. edsl/language_models/utilities.py +64 -64
  122. edsl/notebooks/Notebook.py +259 -259
  123. edsl/notebooks/__init__.py +1 -1
  124. edsl/prompts/Prompt.py +358 -358
  125. edsl/prompts/__init__.py +2 -2
  126. edsl/questions/AnswerValidatorMixin.py +289 -289
  127. edsl/questions/QuestionBase.py +616 -616
  128. edsl/questions/QuestionBaseGenMixin.py +161 -161
  129. edsl/questions/QuestionBasePromptsMixin.py +266 -266
  130. edsl/questions/QuestionBudget.py +227 -227
  131. edsl/questions/QuestionCheckBox.py +359 -359
  132. edsl/questions/QuestionExtract.py +183 -183
  133. edsl/questions/QuestionFreeText.py +113 -113
  134. edsl/questions/QuestionFunctional.py +159 -155
  135. edsl/questions/QuestionList.py +231 -231
  136. edsl/questions/QuestionMultipleChoice.py +286 -286
  137. edsl/questions/QuestionNumerical.py +153 -153
  138. edsl/questions/QuestionRank.py +324 -324
  139. edsl/questions/Quick.py +41 -41
  140. edsl/questions/RegisterQuestionsMeta.py +71 -71
  141. edsl/questions/ResponseValidatorABC.py +174 -174
  142. edsl/questions/SimpleAskMixin.py +73 -73
  143. edsl/questions/__init__.py +26 -26
  144. edsl/questions/compose_questions.py +98 -98
  145. edsl/questions/decorators.py +21 -21
  146. edsl/questions/derived/QuestionLikertFive.py +76 -76
  147. edsl/questions/derived/QuestionLinearScale.py +87 -87
  148. edsl/questions/derived/QuestionTopK.py +91 -91
  149. edsl/questions/derived/QuestionYesNo.py +82 -82
  150. edsl/questions/descriptors.py +418 -418
  151. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  152. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  153. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  154. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  155. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  156. edsl/questions/prompt_templates/question_list.jinja +17 -17
  157. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  158. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  159. edsl/questions/question_registry.py +147 -147
  160. edsl/questions/settings.py +12 -12
  161. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  162. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  163. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  164. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  165. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  166. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  167. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  168. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  169. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  170. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  171. edsl/questions/templates/list/question_presentation.jinja +5 -5
  172. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  173. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  174. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  175. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  176. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  177. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  178. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  179. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  180. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  181. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  182. edsl/results/Dataset.py +293 -281
  183. edsl/results/DatasetExportMixin.py +693 -693
  184. edsl/results/DatasetTree.py +145 -145
  185. edsl/results/Result.py +433 -431
  186. edsl/results/Results.py +1158 -1146
  187. edsl/results/ResultsDBMixin.py +238 -238
  188. edsl/results/ResultsExportMixin.py +43 -43
  189. edsl/results/ResultsFetchMixin.py +33 -33
  190. edsl/results/ResultsGGMixin.py +121 -121
  191. edsl/results/ResultsToolsMixin.py +98 -98
  192. edsl/results/Selector.py +118 -118
  193. edsl/results/__init__.py +2 -2
  194. edsl/results/tree_explore.py +115 -115
  195. edsl/scenarios/FileStore.py +443 -443
  196. edsl/scenarios/Scenario.py +507 -496
  197. edsl/scenarios/ScenarioHtmlMixin.py +59 -59
  198. edsl/scenarios/ScenarioList.py +1101 -1101
  199. edsl/scenarios/ScenarioListExportMixin.py +52 -52
  200. edsl/scenarios/ScenarioListPdfMixin.py +261 -261
  201. edsl/scenarios/__init__.py +2 -2
  202. edsl/shared.py +1 -1
  203. edsl/study/ObjectEntry.py +173 -173
  204. edsl/study/ProofOfWork.py +113 -113
  205. edsl/study/SnapShot.py +80 -80
  206. edsl/study/Study.py +528 -528
  207. edsl/study/__init__.py +4 -4
  208. edsl/surveys/DAG.py +148 -148
  209. edsl/surveys/Memory.py +31 -31
  210. edsl/surveys/MemoryPlan.py +244 -244
  211. edsl/surveys/Rule.py +324 -324
  212. edsl/surveys/RuleCollection.py +387 -387
  213. edsl/surveys/Survey.py +1772 -1769
  214. edsl/surveys/SurveyCSS.py +261 -261
  215. edsl/surveys/SurveyExportMixin.py +259 -259
  216. edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
  217. edsl/surveys/SurveyQualtricsImport.py +284 -284
  218. edsl/surveys/__init__.py +3 -3
  219. edsl/surveys/base.py +53 -53
  220. edsl/surveys/descriptors.py +56 -56
  221. edsl/surveys/instructions/ChangeInstruction.py +47 -47
  222. edsl/surveys/instructions/Instruction.py +51 -34
  223. edsl/surveys/instructions/InstructionCollection.py +77 -77
  224. edsl/templates/error_reporting/base.html +23 -23
  225. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  226. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  227. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  228. edsl/templates/error_reporting/interview_details.html +115 -115
  229. edsl/templates/error_reporting/interviews.html +9 -9
  230. edsl/templates/error_reporting/overview.html +4 -4
  231. edsl/templates/error_reporting/performance_plot.html +1 -1
  232. edsl/templates/error_reporting/report.css +73 -73
  233. edsl/templates/error_reporting/report.html +117 -117
  234. edsl/templates/error_reporting/report.js +25 -25
  235. edsl/tools/__init__.py +1 -1
  236. edsl/tools/clusters.py +192 -192
  237. edsl/tools/embeddings.py +27 -27
  238. edsl/tools/embeddings_plotting.py +118 -118
  239. edsl/tools/plotting.py +112 -112
  240. edsl/tools/summarize.py +18 -18
  241. edsl/utilities/SystemInfo.py +28 -28
  242. edsl/utilities/__init__.py +22 -22
  243. edsl/utilities/ast_utilities.py +25 -25
  244. edsl/utilities/data/Registry.py +6 -6
  245. edsl/utilities/data/__init__.py +1 -1
  246. edsl/utilities/data/scooter_results.json +1 -1
  247. edsl/utilities/decorators.py +77 -77
  248. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  249. edsl/utilities/interface.py +627 -627
  250. edsl/utilities/repair_functions.py +28 -28
  251. edsl/utilities/restricted_python.py +70 -70
  252. edsl/utilities/utilities.py +391 -391
  253. {edsl-0.1.36.dev2.dist-info → edsl-0.1.36.dev6.dist-info}/LICENSE +21 -21
  254. {edsl-0.1.36.dev2.dist-info → edsl-0.1.36.dev6.dist-info}/METADATA +1 -1
  255. edsl-0.1.36.dev6.dist-info/RECORD +279 -0
  256. edsl-0.1.36.dev2.dist-info/RECORD +0 -278
  257. {edsl-0.1.36.dev2.dist-info → edsl-0.1.36.dev6.dist-info}/WHEEL +0 -0
@@ -1,156 +1,156 @@
1
- import os
2
- from typing import Any, Dict, List, Optional
3
- import google
4
- import google.generativeai as genai
5
- from google.generativeai.types import GenerationConfig
6
- from google.api_core.exceptions import InvalidArgument
7
-
8
- from edsl.exceptions import MissingAPIKeyError
9
- from edsl.language_models.LanguageModel import LanguageModel
10
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
11
-
# Gemini safety configuration: every harm category below is mapped to the
# "BLOCK_NONE" threshold, i.e. content filtering is turned off for each one.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
30
-
31
-
class GoogleService(InferenceServiceABC):
    """Inference service for Google Gemini models via ``google.generativeai``."""

    _inference_service_ = "google"
    # Path into the dict returned by ``async_execute_model_call`` (a
    # ``response.to_dict()``) that reaches the generated text.
    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
    # Where token usage lives in that dict, and the names of the prompt and
    # completion token counters inside it.
    usage_sequence = ["usage_metadata"]
    input_token_name = "prompt_token_count"
    output_token_name = "candidates_token_count"

    model_exclude_list = []

    @classmethod
    def available(cls) -> List[str]:
        """Return the names of models that support ``generateContent``.

        ``genai.list_models()`` reports fully qualified names (e.g.
        ``"models/<name>"``); only the segment after the last ``/`` is kept.
        """
        return [
            m.name.split("/")[-1]
            for m in genai.list_models()
            if "generateContent" in m.supported_generation_methods
        ]

    @classmethod
    def create_model(
        cls, model_name: str = "gemini-pro", model_class_name=None
    ) -> LanguageModel:
        """Build a ``LanguageModel`` subclass bound to one Gemini model.

        :param model_name: Gemini model name, e.g. ``"gemini-pro"``.
        :param model_class_name: Optional class name; derived from
            ``model_name`` when omitted.
        :returns: A new ``LanguageModel`` subclass (not an instance).
        """
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            _model_ = model_name
            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
            _inference_service_ = cls._inference_service_

            # get_tpm/get_rpm are plain functions on the base class, so the
            # service class is passed explicitly in the ``self`` position.
            _tpm = cls.get_tpm(cls)
            _rpm = cls.get_rpm(cls)

            _parameters_ = {
                "temperature": 0.5,
                "topP": 1,
                "topK": 1,
                "maxOutputTokens": 2048,
                "stopSequences": [],
            }

            api_token = None
            model = None

            @classmethod
            def initialize(cls):
                """Configure the Gemini client and build the default model.

                The API key is read from ``GOOGLE_API_KEY`` (re-read while the
                cached key is missing or empty) and cached on the class.

                :raises MissingAPIKeyError: if no API key is available.
                """
                if not cls.api_token:
                    cls.api_token = os.getenv("GOOGLE_API_KEY")
                if not cls.api_token:
                    raise MissingAPIKeyError(
                        "GOOGLE_API_KEY environment variable is not set"
                    )
                genai.configure(api_key=cls.api_token)
                # Default model, used for calls that carry no system prompt.
                cls.generative_model = genai.GenerativeModel(
                    cls._model_, safety_settings=safety_settings
                )

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.initialize()

            def get_generation_config(self) -> GenerationConfig:
                """Translate the sampling parameters into a ``GenerationConfig``.

                Reads ``temperature``, ``topP``, ``topK``, ``maxOutputTokens`` and
                ``stopSequences`` from the instance; presumably ``LanguageModel``
                exposes the ``_parameters_`` entries as attributes — confirm.
                """
                return GenerationConfig(
                    temperature=self.temperature,
                    top_p=self.topP,
                    top_k=self.topK,
                    max_output_tokens=self.maxOutputTokens,
                    stop_sequences=self.stopSequences,
                )

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional["Files"] = None,
            ) -> Dict[str, Any]:
                """Send one prompt (plus optional files) to Gemini.

                :param user_prompt: The prompt text.
                :param system_prompt: Optional system instruction. When the model
                    cannot take one (``gemini-pro``, or model construction raises
                    ``InvalidArgument``) it is prepended to ``user_prompt``.
                :param files_list: Optional file objects; any without a
                    ``"google"`` external location is uploaded first.
                :returns: The Gemini response as a dict (``response.to_dict()``).
                """
                generation_config = self.get_generation_config()

                if files_list is None:
                    files_list = []

                # Choose the model for this call only: storing a system-prompt
                # model on ``self`` would make that prompt apply to later calls.
                generative_model = self.generative_model
                if system_prompt is not None and system_prompt != "":
                    supports_system_instruction = self._model_ != "gemini-pro"
                    if supports_system_instruction:
                        try:
                            generative_model = genai.GenerativeModel(
                                self._model_,
                                safety_settings=safety_settings,
                                system_instruction=system_prompt,
                            )
                        except InvalidArgument:
                            # NOTE(review): the API may reject an unsupported
                            # system_instruction only at request time rather than
                            # in the constructor — confirm this fallback fires.
                            print(
                                f"This model, {self._model_}, does not support system_instruction"
                            )
                            print("Will add system_prompt to user_prompt")
                            supports_system_instruction = False
                    if not supports_system_instruction:
                        # Fold the system prompt in rather than dropping it.
                        user_prompt = f"{system_prompt}\n{user_prompt}"

                combined_prompt = [user_prompt]
                for file in files_list:
                    if "google" not in file.external_locations:
                        # Presumably records external_locations["google"] — confirm.
                        _ = file.upload_google()
                    gen_ai_file = google.generativeai.types.file_types.File(
                        file.external_locations["google"]
                    )
                    combined_prompt.append(gen_ai_file)

                response = await generative_model.generate_content_async(
                    combined_prompt, generation_config=generation_config
                )
                return response.to_dict()

        # NOTE(review): ``model_class_name`` is computed above but the class is
        # named after ``model_name`` — confirm which name callers expect.
        LLM.__name__ = model_name
        return LLM


if __name__ == "__main__":
    pass
1
+ import os
2
+ from typing import Any, Dict, List, Optional
3
+ import google
4
+ import google.generativeai as genai
5
+ from google.generativeai.types import GenerationConfig
6
+ from google.api_core.exceptions import InvalidArgument
7
+
8
+ from edsl.exceptions import MissingAPIKeyError
9
+ from edsl.language_models.LanguageModel import LanguageModel
10
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
11
+
# Gemini safety configuration: every harm category below is mapped to the
# "BLOCK_NONE" threshold, i.e. content filtering is turned off for each one.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
30
+
31
+
class GoogleService(InferenceServiceABC):
    """Inference service for Google Gemini models via ``google.generativeai``."""

    _inference_service_ = "google"
    # Path into the dict returned by ``async_execute_model_call`` (a
    # ``response.to_dict()``) that reaches the generated text.
    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
    # Where token usage lives in that dict, and the names of the prompt and
    # completion token counters inside it.
    usage_sequence = ["usage_metadata"]
    input_token_name = "prompt_token_count"
    output_token_name = "candidates_token_count"

    model_exclude_list = []

    @classmethod
    def available(cls) -> List[str]:
        """Return the names of models that support ``generateContent``.

        ``genai.list_models()`` reports fully qualified names (e.g.
        ``"models/<name>"``); only the segment after the last ``/`` is kept.
        """
        return [
            m.name.split("/")[-1]
            for m in genai.list_models()
            if "generateContent" in m.supported_generation_methods
        ]

    @classmethod
    def create_model(
        cls, model_name: str = "gemini-pro", model_class_name=None
    ) -> LanguageModel:
        """Build a ``LanguageModel`` subclass bound to one Gemini model.

        :param model_name: Gemini model name, e.g. ``"gemini-pro"``.
        :param model_class_name: Optional class name; derived from
            ``model_name`` when omitted.
        :returns: A new ``LanguageModel`` subclass (not an instance).
        """
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            _model_ = model_name
            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
            _inference_service_ = cls._inference_service_

            # get_tpm/get_rpm are plain functions on the base class, so the
            # service class is passed explicitly in the ``self`` position.
            _tpm = cls.get_tpm(cls)
            _rpm = cls.get_rpm(cls)

            _parameters_ = {
                "temperature": 0.5,
                "topP": 1,
                "topK": 1,
                "maxOutputTokens": 2048,
                "stopSequences": [],
            }

            api_token = None
            model = None

            @classmethod
            def initialize(cls):
                """Configure the Gemini client and build the default model.

                The API key is read from ``GOOGLE_API_KEY`` (re-read while the
                cached key is missing or empty) and cached on the class.

                :raises MissingAPIKeyError: if no API key is available.
                """
                if not cls.api_token:
                    cls.api_token = os.getenv("GOOGLE_API_KEY")
                if not cls.api_token:
                    raise MissingAPIKeyError(
                        "GOOGLE_API_KEY environment variable is not set"
                    )
                genai.configure(api_key=cls.api_token)
                # Default model, used for calls that carry no system prompt.
                cls.generative_model = genai.GenerativeModel(
                    cls._model_, safety_settings=safety_settings
                )

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.initialize()

            def get_generation_config(self) -> GenerationConfig:
                """Translate the sampling parameters into a ``GenerationConfig``.

                Reads ``temperature``, ``topP``, ``topK``, ``maxOutputTokens`` and
                ``stopSequences`` from the instance; presumably ``LanguageModel``
                exposes the ``_parameters_`` entries as attributes — confirm.
                """
                return GenerationConfig(
                    temperature=self.temperature,
                    top_p=self.topP,
                    top_k=self.topK,
                    max_output_tokens=self.maxOutputTokens,
                    stop_sequences=self.stopSequences,
                )

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional["Files"] = None,
            ) -> Dict[str, Any]:
                """Send one prompt (plus optional files) to Gemini.

                :param user_prompt: The prompt text.
                :param system_prompt: Optional system instruction. When the model
                    cannot take one (``gemini-pro``, or model construction raises
                    ``InvalidArgument``) it is prepended to ``user_prompt``.
                :param files_list: Optional file objects; any without a
                    ``"google"`` external location is uploaded first.
                :returns: The Gemini response as a dict (``response.to_dict()``).
                """
                generation_config = self.get_generation_config()

                if files_list is None:
                    files_list = []

                # Choose the model for this call only: storing a system-prompt
                # model on ``self`` would make that prompt apply to later calls.
                generative_model = self.generative_model
                if system_prompt is not None and system_prompt != "":
                    supports_system_instruction = self._model_ != "gemini-pro"
                    if supports_system_instruction:
                        try:
                            generative_model = genai.GenerativeModel(
                                self._model_,
                                safety_settings=safety_settings,
                                system_instruction=system_prompt,
                            )
                        except InvalidArgument:
                            # NOTE(review): the API may reject an unsupported
                            # system_instruction only at request time rather than
                            # in the constructor — confirm this fallback fires.
                            print(
                                f"This model, {self._model_}, does not support system_instruction"
                            )
                            print("Will add system_prompt to user_prompt")
                            supports_system_instruction = False
                    if not supports_system_instruction:
                        # Fold the system prompt in rather than dropping it.
                        user_prompt = f"{system_prompt}\n{user_prompt}"

                combined_prompt = [user_prompt]
                for file in files_list:
                    if "google" not in file.external_locations:
                        # Presumably records external_locations["google"] — confirm.
                        _ = file.upload_google()
                    gen_ai_file = google.generativeai.types.file_types.File(
                        file.external_locations["google"]
                    )
                    combined_prompt.append(gen_ai_file)

                response = await generative_model.generate_content_async(
                    combined_prompt, generation_config=generation_config
                )
                return response.to_dict()

        # NOTE(review): ``model_class_name`` is computed above but the class is
        # named after ``model_name`` — confirm which name callers expect.
        LLM.__name__ = model_name
        return LLM


if __name__ == "__main__":
    pass
@@ -1,20 +1,20 @@
1
- from typing import Any, List
2
- from edsl.inference_services.OpenAIService import OpenAIService
3
-
4
- import groq
5
-
6
-
class GroqService(OpenAIService):
    """Groq inference service.

    Reuses the OpenAI-compatible service implementation, overriding the service
    name, the environment-variable name for the API key (presumably read by
    ``OpenAIService``), and the sync/async client classes from ``groq``.
    """

    _inference_service_ = "groq"
    _env_key_name_ = "GROQ_API_KEY"

    _sync_client_ = groq.Groq
    _async_client_ = groq.AsyncGroq

    # Whisper speech-to-text models, presumably excluded because they are not
    # usable for text prompts — confirm.
    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]

    # No custom endpoint: presumably the Groq client's default base URL is used.
    _base_url_ = None
    _models_list_cache: List[str] = []
1
+ from typing import Any, List
2
+ from edsl.inference_services.OpenAIService import OpenAIService
3
+
4
+ import groq
5
+
6
+
class GroqService(OpenAIService):
    """Groq inference service.

    Reuses the OpenAI-compatible service implementation, overriding the service
    name, the environment-variable name for the API key (presumably read by
    ``OpenAIService``), and the sync/async client classes from ``groq``.
    """

    _inference_service_ = "groq"
    _env_key_name_ = "GROQ_API_KEY"

    _sync_client_ = groq.Groq
    _async_client_ = groq.AsyncGroq

    # Whisper speech-to-text models, presumably excluded because they are not
    # usable for text prompts — confirm.
    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]

    # No custom endpoint: presumably the Groq client's default base URL is used.
    _base_url_ = None
    _models_list_cache: List[str] = []
@@ -1,119 +1,147 @@
1
- from abc import abstractmethod, ABC
2
- import os
3
- import re
4
- from edsl.config import CONFIG
5
-
6
-
class InferenceServiceABC(ABC):
    """
    Abstract base class for inference services.

    Concrete services must declare ``key_sequence`` and ``model_exclude_list``
    (enforced when the subclass is created) and implement ``available`` and
    ``create_model``.

    Rate limits (tokens/requests per minute) are resolved by ``_get_limt``
    from, in order: an environment variable, ``default_levels``, and the
    ``EDSL_SERVICE_<LIMIT>_BASELINE`` config value.

    Anthropic: https://docs.anthropic.com/en/api/rate-limits
    """

    # Built-in per-service limits, used when no environment override exists.
    default_levels = {
        "google": {"tpm": 2_000_000, "rpm": 15},
        "openai": {"tpm": 2_000_000, "rpm": 10_000},
        "anthropic": {"tpm": 2_000_000, "rpm": 500},
    }

    def __init_subclass__(cls, **kwargs):
        """
        Validate that a new subclass declares the required class attributes.

        - ``key_sequence``: path used to pull the answer out of a raw model
          response (see the concrete services).
        - ``model_exclude_list``: model names the service should not offer
          (presumably filtered out of ``available`` — confirm).

        Class keyword arguments are forwarded to ``super()`` so this base
        cooperates with other classes that customize ``__init_subclass__``.

        :raises NotImplementedError: if either attribute is missing.
        """
        super().__init_subclass__(**kwargs)
        for required in ("key_sequence", "model_exclude_list"):
            if not hasattr(cls, required):
                raise NotImplementedError(
                    f"Class {cls.__name__} must have a '{required}' attribute."
                )

    @classmethod
    def _get_limt(cls, limit_type: str) -> int:
        """
        Resolve a rate limit for this service.

        Lookup order:

        1. environment variable ``EDSL_SERVICE_<LIMIT>_<SERVICE>``;
        2. ``default_levels[<service>][limit_type]``;
        3. config value ``EDSL_SERVICE_<LIMIT>_BASELINE``.

        The method name is a historical misspelling of "limit", kept because
        callers use it.

        :param limit_type: ``"tpm"`` or ``"rpm"``.
        """
        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
        if key in os.environ:
            return int(os.getenv(key))

        if cls._inference_service_ in cls.default_levels:
            return int(cls.default_levels[cls._inference_service_][limit_type])

        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))

    def get_tpm(cls) -> int:
        """
        Return the tokens-per-minute limit for the service.

        See ``_get_limt`` for the lookup order. Not a classmethod: callers pass
        the service class explicitly, e.g. ``cls.get_tpm(cls)``.
        """
        return cls._get_limt(limit_type="tpm")

    def get_rpm(cls) -> int:
        """
        Return the requests-per-minute limit for the service.

        See ``_get_limt`` for the lookup order. Not a classmethod: callers pass
        the service class explicitly, e.g. ``cls.get_rpm(cls)``.
        """
        return cls._get_limt(limit_type="rpm")

    @abstractmethod
    def available() -> list[str]:
        """
        Returns a list of available models for the service.
        """
        pass

    @abstractmethod
    def create_model():
        """
        Returns a LanguageModel object.
        """
        pass

    @staticmethod
    def to_class_name(s):
        """
        Convert a string to a valid class name.

        Non-alphanumeric characters (other than spaces) are removed, each
        space-separated word is title-cased and joined, and a ``Class`` prefix
        is added when the result would start with a digit.

        >>> InferenceServiceABC.to_class_name("hello world")
        'HelloWorld'
        >>> InferenceServiceABC.to_class_name("9 lives")
        'Class9Lives'
        """
        s = re.sub(r"[^a-zA-Z0-9 ]", "", s)
        s = "".join(word.title() for word in s.split())
        if s and s[0].isdigit():
            s = "Class" + s
        return s
85
-
86
-
if __name__ == "__main__":
    # Run the module's doctests (e.g. ``to_class_name``) when executed directly.
    import doctest

    doctest.testmod()
1
+ from abc import abstractmethod, ABC
2
+ import os
3
+ import re
4
+ from datetime import datetime, timedelta
5
+ from edsl.config import CONFIG
6
+
7
+
class InferenceServiceABC(ABC):
    """
    Abstract base class for inference services.

    Concrete services must declare ``key_sequence`` and ``model_exclude_list``
    (enforced when the subclass is created) and implement ``available`` and
    ``create_model``.

    Rate limits (tokens/requests per minute) are resolved by ``_get_limt``
    from, in order: an environment variable, rate-limit config fetched from
    Coop, ``default_levels``, and the ``EDSL_SERVICE_<LIMIT>_BASELINE`` config
    value.

    Anthropic: https://docs.anthropic.com/en/api/rate-limits
    """

    # Rate-limit config fetched from Coop (None until a fetch succeeds) and the
    # datetime of that fetch. ``_get_limt`` sets both on the class it is
    # called on, so each concrete service caches its own copy.
    _coop_config_vars = None
    _last_config_fetch = None

    # Built-in per-service limits, used when neither an environment variable
    # nor the Coop config provides one.
    default_levels = {
        "google": {"tpm": 2_000_000, "rpm": 15},
        "openai": {"tpm": 2_000_000, "rpm": 10_000},
        "anthropic": {"tpm": 2_000_000, "rpm": 500},
    }

    def __init_subclass__(cls, **kwargs):
        """
        Validate that a new subclass declares the required class attributes.

        - ``key_sequence``: path used to pull the answer out of a raw model
          response (see the concrete services).
        - ``model_exclude_list``: model names the service should not offer
          (presumably filtered out of ``available`` — confirm).

        Class keyword arguments are forwarded to ``super()`` so this base
        cooperates with other classes that customize ``__init_subclass__``.

        :raises NotImplementedError: if either attribute is missing.
        """
        super().__init_subclass__(**kwargs)
        for required in ("key_sequence", "model_exclude_list"):
            if not hasattr(cls, required):
                raise NotImplementedError(
                    f"Class {cls.__name__} must have a '{required}' attribute."
                )

    @classmethod
    def _should_refresh_coop_config_vars(cls) -> bool:
        """
        Return True if the Coop config has never been fetched or was fetched
        more than 24 hours ago, and False otherwise.
        """
        if cls._last_config_fetch is None:
            return True
        return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)

    @classmethod
    def _get_limt(cls, limit_type: str) -> int:
        """
        Resolve a rate limit for this service.

        Lookup order:

        1. environment variable ``EDSL_SERVICE_<LIMIT>_<SERVICE>``;
        2. the same key in the Coop rate-limit config (fetched when missing
           or older than 24 hours);
        3. ``default_levels[<service>][limit_type]``;
        4. config value ``EDSL_SERVICE_<LIMIT>_BASELINE``.

        The method name is a historical misspelling of "limit", kept because
        callers use it.

        :param limit_type: ``"tpm"`` or ``"rpm"``.
        """
        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
        if key in os.environ:
            return int(os.getenv(key))

        if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
            try:
                from edsl import Coop

                c = Coop()
                cls._coop_config_vars = c.fetch_rate_limit_config_vars()
                cls._last_config_fetch = datetime.now()
                if key in cls._coop_config_vars:
                    return int(cls._coop_config_vars[key])
            except Exception:
                # Best effort: any error while contacting Coop or reading its
                # payload drops the cached config and falls through to the
                # defaults below. NOTE(review): no timestamp is recorded on
                # failure, so the next call retries the fetch immediately.
                cls._coop_config_vars = None
        elif key in cls._coop_config_vars:
            return int(cls._coop_config_vars[key])

        if cls._inference_service_ in cls.default_levels:
            return int(cls.default_levels[cls._inference_service_][limit_type])

        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))

    def get_tpm(cls) -> int:
        """
        Return the tokens-per-minute limit for the service.

        See ``_get_limt`` for the lookup order. Not a classmethod: callers pass
        the service class explicitly, e.g. ``cls.get_tpm(cls)``.
        """
        return cls._get_limt(limit_type="tpm")

    def get_rpm(cls) -> int:
        """
        Return the requests-per-minute limit for the service.

        See ``_get_limt`` for the lookup order. Not a classmethod: callers pass
        the service class explicitly, e.g. ``cls.get_rpm(cls)``.
        """
        return cls._get_limt(limit_type="rpm")

    @abstractmethod
    def available() -> list[str]:
        """
        Returns a list of available models for the service.
        """
        pass

    @abstractmethod
    def create_model():
        """
        Returns a LanguageModel object.
        """
        pass

    @staticmethod
    def to_class_name(s):
        """
        Convert a string to a valid class name.

        Non-alphanumeric characters (other than spaces) are removed, each
        space-separated word is title-cased and joined, and a ``Class`` prefix
        is added when the result would start with a digit.

        >>> InferenceServiceABC.to_class_name("hello world")
        'HelloWorld'
        >>> InferenceServiceABC.to_class_name("9 lives")
        'Class9Lives'
        """
        s = re.sub(r"[^a-zA-Z0-9 ]", "", s)
        s = "".join(word.title() for word in s.split())
        if s and s[0].isdigit():
            s = "Class" + s
        return s
113
+
114
+
if __name__ == "__main__":
    # Run the module's doctests (e.g. ``to_class_name``) when executed directly.
    import doctest

    doctest.testmod()