edsl 0.1.37.dev4__py3-none-any.whl → 0.1.37.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. edsl/Base.py +303 -303
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +48 -48
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +855 -804
  7. edsl/agents/AgentList.py +350 -345
  8. edsl/agents/Invigilator.py +222 -222
  9. edsl/agents/InvigilatorBase.py +284 -305
  10. edsl/agents/PromptConstructor.py +353 -312
  11. edsl/agents/__init__.py +3 -3
  12. edsl/agents/descriptors.py +99 -86
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +289 -289
  26. edsl/config.py +149 -149
  27. edsl/conjure/AgentConstructionMixin.py +160 -152
  28. edsl/conjure/Conjure.py +62 -62
  29. edsl/conjure/InputData.py +659 -659
  30. edsl/conjure/InputDataCSV.py +48 -48
  31. edsl/conjure/InputDataMixinQuestionStats.py +182 -182
  32. edsl/conjure/InputDataPyRead.py +91 -91
  33. edsl/conjure/InputDataSPSS.py +8 -8
  34. edsl/conjure/InputDataStata.py +8 -8
  35. edsl/conjure/QuestionOptionMixin.py +76 -76
  36. edsl/conjure/QuestionTypeMixin.py +23 -23
  37. edsl/conjure/RawQuestion.py +65 -65
  38. edsl/conjure/SurveyResponses.py +7 -7
  39. edsl/conjure/__init__.py +9 -9
  40. edsl/conjure/naming_utilities.py +263 -263
  41. edsl/conjure/utilities.py +201 -201
  42. edsl/conversation/Conversation.py +290 -238
  43. edsl/conversation/car_buying.py +58 -58
  44. edsl/conversation/chips.py +95 -0
  45. edsl/conversation/mug_negotiation.py +81 -81
  46. edsl/conversation/next_speaker_utilities.py +93 -93
  47. edsl/coop/PriceFetcher.py +54 -54
  48. edsl/coop/__init__.py +2 -2
  49. edsl/coop/coop.py +958 -827
  50. edsl/coop/utils.py +131 -131
  51. edsl/data/Cache.py +527 -527
  52. edsl/data/CacheEntry.py +228 -228
  53. edsl/data/CacheHandler.py +149 -149
  54. edsl/data/RemoteCacheSync.py +97 -97
  55. edsl/data/SQLiteDict.py +292 -292
  56. edsl/data/__init__.py +4 -4
  57. edsl/data/orm.py +10 -10
  58. edsl/data_transfer_models.py +73 -73
  59. edsl/enums.py +173 -173
  60. edsl/exceptions/BaseException.py +21 -0
  61. edsl/exceptions/__init__.py +54 -50
  62. edsl/exceptions/agents.py +38 -40
  63. edsl/exceptions/configuration.py +16 -16
  64. edsl/exceptions/coop.py +10 -10
  65. edsl/exceptions/data.py +14 -14
  66. edsl/exceptions/general.py +34 -34
  67. edsl/exceptions/jobs.py +33 -33
  68. edsl/exceptions/language_models.py +63 -63
  69. edsl/exceptions/prompts.py +15 -15
  70. edsl/exceptions/questions.py +91 -91
  71. edsl/exceptions/results.py +29 -26
  72. edsl/exceptions/scenarios.py +22 -0
  73. edsl/exceptions/surveys.py +37 -34
  74. edsl/inference_services/AnthropicService.py +87 -87
  75. edsl/inference_services/AwsBedrock.py +120 -120
  76. edsl/inference_services/AzureAI.py +217 -217
  77. edsl/inference_services/DeepInfraService.py +18 -18
  78. edsl/inference_services/GoogleService.py +156 -156
  79. edsl/inference_services/GroqService.py +20 -20
  80. edsl/inference_services/InferenceServiceABC.py +147 -147
  81. edsl/inference_services/InferenceServicesCollection.py +97 -74
  82. edsl/inference_services/MistralAIService.py +123 -123
  83. edsl/inference_services/OllamaService.py +18 -18
  84. edsl/inference_services/OpenAIService.py +224 -224
  85. edsl/inference_services/TestService.py +89 -89
  86. edsl/inference_services/TogetherAIService.py +170 -170
  87. edsl/inference_services/models_available_cache.py +118 -118
  88. edsl/inference_services/rate_limits_cache.py +25 -25
  89. edsl/inference_services/registry.py +39 -39
  90. edsl/inference_services/write_available.py +10 -10
  91. edsl/jobs/Answers.py +56 -56
  92. edsl/jobs/Jobs.py +1347 -1135
  93. edsl/jobs/__init__.py +1 -1
  94. edsl/jobs/buckets/BucketCollection.py +63 -63
  95. edsl/jobs/buckets/ModelBuckets.py +65 -65
  96. edsl/jobs/buckets/TokenBucket.py +248 -248
  97. edsl/jobs/interviews/Interview.py +661 -661
  98. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  99. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -182
  100. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  101. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  102. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  103. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  104. edsl/jobs/interviews/ReportErrors.py +66 -66
  105. edsl/jobs/interviews/interview_status_enum.py +9 -9
  106. edsl/jobs/runners/JobsRunnerAsyncio.py +338 -338
  107. edsl/jobs/runners/JobsRunnerStatus.py +332 -332
  108. edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
  109. edsl/jobs/tasks/TaskCreators.py +64 -64
  110. edsl/jobs/tasks/TaskHistory.py +442 -441
  111. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  112. edsl/jobs/tasks/task_status_enum.py +163 -163
  113. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  114. edsl/jobs/tokens/TokenUsage.py +34 -34
  115. edsl/language_models/KeyLookup.py +30 -0
  116. edsl/language_models/LanguageModel.py +706 -718
  117. edsl/language_models/ModelList.py +102 -102
  118. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  119. edsl/language_models/__init__.py +3 -2
  120. edsl/language_models/fake_openai_call.py +15 -15
  121. edsl/language_models/fake_openai_service.py +61 -61
  122. edsl/language_models/registry.py +137 -137
  123. edsl/language_models/repair.py +156 -156
  124. edsl/language_models/unused/ReplicateBase.py +83 -83
  125. edsl/language_models/utilities.py +64 -64
  126. edsl/notebooks/Notebook.py +259 -259
  127. edsl/notebooks/__init__.py +1 -1
  128. edsl/prompts/Prompt.py +357 -353
  129. edsl/prompts/__init__.py +2 -2
  130. edsl/questions/AnswerValidatorMixin.py +289 -289
  131. edsl/questions/QuestionBase.py +656 -616
  132. edsl/questions/QuestionBaseGenMixin.py +161 -161
  133. edsl/questions/QuestionBasePromptsMixin.py +234 -266
  134. edsl/questions/QuestionBudget.py +227 -227
  135. edsl/questions/QuestionCheckBox.py +359 -359
  136. edsl/questions/QuestionExtract.py +183 -183
  137. edsl/questions/QuestionFreeText.py +114 -114
  138. edsl/questions/QuestionFunctional.py +159 -159
  139. edsl/questions/QuestionList.py +231 -231
  140. edsl/questions/QuestionMultipleChoice.py +286 -286
  141. edsl/questions/QuestionNumerical.py +153 -153
  142. edsl/questions/QuestionRank.py +324 -324
  143. edsl/questions/Quick.py +41 -41
  144. edsl/questions/RegisterQuestionsMeta.py +71 -71
  145. edsl/questions/ResponseValidatorABC.py +174 -174
  146. edsl/questions/SimpleAskMixin.py +73 -73
  147. edsl/questions/__init__.py +26 -26
  148. edsl/questions/compose_questions.py +98 -98
  149. edsl/questions/decorators.py +21 -21
  150. edsl/questions/derived/QuestionLikertFive.py +76 -76
  151. edsl/questions/derived/QuestionLinearScale.py +87 -87
  152. edsl/questions/derived/QuestionTopK.py +91 -91
  153. edsl/questions/derived/QuestionYesNo.py +82 -82
  154. edsl/questions/descriptors.py +413 -418
  155. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  156. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  157. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  158. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  159. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  160. edsl/questions/prompt_templates/question_list.jinja +17 -17
  161. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  162. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  163. edsl/questions/question_registry.py +147 -147
  164. edsl/questions/settings.py +12 -12
  165. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  166. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  167. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  168. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  169. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  170. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  171. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  172. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  173. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  174. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  175. edsl/questions/templates/list/question_presentation.jinja +5 -5
  176. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  177. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  178. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  179. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  180. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  181. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  182. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  183. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  184. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  185. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  186. edsl/results/Dataset.py +293 -293
  187. edsl/results/DatasetExportMixin.py +717 -693
  188. edsl/results/DatasetTree.py +145 -145
  189. edsl/results/Result.py +450 -435
  190. edsl/results/Results.py +1071 -1160
  191. edsl/results/ResultsDBMixin.py +238 -238
  192. edsl/results/ResultsExportMixin.py +43 -43
  193. edsl/results/ResultsFetchMixin.py +33 -33
  194. edsl/results/ResultsGGMixin.py +121 -121
  195. edsl/results/ResultsToolsMixin.py +98 -98
  196. edsl/results/Selector.py +135 -118
  197. edsl/results/__init__.py +2 -2
  198. edsl/results/tree_explore.py +115 -115
  199. edsl/scenarios/FileStore.py +458 -458
  200. edsl/scenarios/Scenario.py +546 -510
  201. edsl/scenarios/ScenarioHtmlMixin.py +64 -59
  202. edsl/scenarios/ScenarioList.py +1112 -1101
  203. edsl/scenarios/ScenarioListExportMixin.py +52 -52
  204. edsl/scenarios/ScenarioListPdfMixin.py +261 -261
  205. edsl/scenarios/__init__.py +4 -4
  206. edsl/shared.py +1 -1
  207. edsl/study/ObjectEntry.py +173 -173
  208. edsl/study/ProofOfWork.py +113 -113
  209. edsl/study/SnapShot.py +80 -80
  210. edsl/study/Study.py +528 -528
  211. edsl/study/__init__.py +4 -4
  212. edsl/surveys/DAG.py +148 -148
  213. edsl/surveys/Memory.py +31 -31
  214. edsl/surveys/MemoryPlan.py +244 -244
  215. edsl/surveys/Rule.py +330 -324
  216. edsl/surveys/RuleCollection.py +387 -387
  217. edsl/surveys/Survey.py +1795 -1772
  218. edsl/surveys/SurveyCSS.py +261 -261
  219. edsl/surveys/SurveyExportMixin.py +259 -259
  220. edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
  221. edsl/surveys/SurveyQualtricsImport.py +284 -284
  222. edsl/surveys/__init__.py +3 -3
  223. edsl/surveys/base.py +53 -53
  224. edsl/surveys/descriptors.py +56 -56
  225. edsl/surveys/instructions/ChangeInstruction.py +47 -47
  226. edsl/surveys/instructions/Instruction.py +51 -51
  227. edsl/surveys/instructions/InstructionCollection.py +77 -77
  228. edsl/templates/error_reporting/base.html +23 -23
  229. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  230. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  231. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  232. edsl/templates/error_reporting/interview_details.html +115 -115
  233. edsl/templates/error_reporting/interviews.html +9 -9
  234. edsl/templates/error_reporting/overview.html +4 -4
  235. edsl/templates/error_reporting/performance_plot.html +1 -1
  236. edsl/templates/error_reporting/report.css +73 -73
  237. edsl/templates/error_reporting/report.html +117 -117
  238. edsl/templates/error_reporting/report.js +25 -25
  239. edsl/tools/__init__.py +1 -1
  240. edsl/tools/clusters.py +192 -192
  241. edsl/tools/embeddings.py +27 -27
  242. edsl/tools/embeddings_plotting.py +118 -118
  243. edsl/tools/plotting.py +112 -112
  244. edsl/tools/summarize.py +18 -18
  245. edsl/utilities/SystemInfo.py +28 -28
  246. edsl/utilities/__init__.py +22 -22
  247. edsl/utilities/ast_utilities.py +25 -25
  248. edsl/utilities/data/Registry.py +6 -6
  249. edsl/utilities/data/__init__.py +1 -1
  250. edsl/utilities/data/scooter_results.json +1 -1
  251. edsl/utilities/decorators.py +77 -77
  252. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  253. edsl/utilities/interface.py +627 -627
  254. edsl/utilities/repair_functions.py +28 -28
  255. edsl/utilities/restricted_python.py +70 -70
  256. edsl/utilities/utilities.py +409 -391
  257. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev5.dist-info}/LICENSE +21 -21
  258. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev5.dist-info}/METADATA +1 -1
  259. edsl-0.1.37.dev5.dist-info/RECORD +283 -0
  260. edsl-0.1.37.dev4.dist-info/RECORD +0 -279
  261. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev5.dist-info}/WHEEL +0 -0
@@ -1,156 +1,156 @@
1
- import os
2
- from typing import Any, Dict, List, Optional
3
- import google
4
- import google.generativeai as genai
5
- from google.generativeai.types import GenerationConfig
6
- from google.api_core.exceptions import InvalidArgument
7
-
8
- from edsl.exceptions import MissingAPIKeyError
9
- from edsl.language_models.LanguageModel import LanguageModel
10
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
11
-
12
# Safety settings passed to every Gemini model built in this module: each
# listed harm category is configured with the "BLOCK_NONE" threshold.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
30
-
31
-
32
class GoogleService(InferenceServiceABC):
    """Inference service for Google's Gemini models via ``google.generativeai``."""

    _inference_service_ = "google"
    # Path into ``response.to_dict()`` (see async_execute_model_call) that
    # presumably locates the generated text.
    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
    # Path to the token-usage metadata in the response dict, and the field
    # names inside it for prompt and completion token counts.
    usage_sequence = ["usage_metadata"]
    input_token_name = "prompt_token_count"
    output_token_name = "candidates_token_count"

    model_exclude_list = []

    @classmethod
    def available(cls) -> List[str]:
        """Return the Gemini models that support ``generateContent``.

        Each entry is the last ``/``-separated segment of the SDK model name.
        """
        return [
            m.name.split("/")[-1]
            for m in genai.list_models()
            if "generateContent" in m.supported_generation_methods
        ]

    @classmethod
    def create_model(
        cls, model_name: str = "gemini-pro", model_class_name=None
    ) -> LanguageModel:
        """Build a ``LanguageModel`` subclass bound to the Gemini model *model_name*.

        :param model_name: Gemini model identifier, e.g. ``"gemini-pro"``.
        :param model_class_name: Accepted for interface compatibility but
            currently unused; the generated class's ``__name__`` is *model_name*.
        :returns: The generated ``LanguageModel`` subclass (not an instance).
        """

        class LLM(LanguageModel):
            _model_ = model_name
            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
            _inference_service_ = cls._inference_service_

            # Rate limits are resolved once, when this class is created.
            _tpm = cls.get_tpm(cls)
            _rpm = cls.get_rpm(cls)

            # Default generation parameters. get_generation_config reads them
            # back as attributes (self.temperature, ...); presumably
            # LanguageModel populates those from this dict — confirm there.
            _parameters_ = {
                "temperature": 0.5,
                "topP": 1,
                "topK": 1,
                "maxOutputTokens": 2048,
                "stopSequences": [],
            }

            api_token = None
            model = None

            @classmethod
            def initialize(cls):
                """Configure the SDK and build the default (instruction-free) model.

                :raises MissingAPIKeyError: if GOOGLE_API_KEY is unset or empty.
                """
                # `not` also catches an empty key cached by an earlier failed
                # attempt, which `is None` would have let through.
                if not cls.api_token:
                    cls.api_token = os.getenv("GOOGLE_API_KEY")
                    if not cls.api_token:
                        raise MissingAPIKeyError(
                            "GOOGLE_API_KEY environment variable is not set"
                        )
                genai.configure(api_key=cls.api_token)
                cls.generative_model = genai.GenerativeModel(
                    cls._model_, safety_settings=safety_settings
                )

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.initialize()

            def get_generation_config(self) -> GenerationConfig:
                """Translate this model's parameters into a Gemini GenerationConfig."""
                return GenerationConfig(
                    temperature=self.temperature,
                    top_p=self.topP,
                    top_k=self.topK,
                    max_output_tokens=self.maxOutputTokens,
                    stop_sequences=self.stopSequences,
                )

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional["Files"] = None,
            ) -> Dict[str, Any]:
                """Send the prompt (plus any files) to Gemini and return the response dict.

                The system prompt is sent as a ``system_instruction`` when
                possible. For ``gemini-pro``, or when the API rejects the
                instruction with ``InvalidArgument``, it is prepended to the
                user prompt instead.

                :param user_prompt: The user-facing prompt text.
                :param system_prompt: Optional system prompt; empty/None means none.
                :param files_list: Optional file objects. Each one not yet
                    uploaded (no "google" entry in ``external_locations``) is
                    uploaded via ``upload_google()``.
                :returns: ``response.to_dict()`` of the Gemini response.
                """
                import warnings

                generation_config = self.get_generation_config()

                file_parts = []
                for file in files_list or []:
                    if "google" not in file.external_locations:
                        file.upload_google()
                    file_parts.append(
                        google.generativeai.types.file_types.File(
                            file.external_locations["google"]
                        )
                    )

                plain_contents = [user_prompt, *file_parts]
                merged_contents = [f"{system_prompt}\n{user_prompt}", *file_parts]

                if not system_prompt:
                    response = await self.generative_model.generate_content_async(
                        plain_contents, generation_config=generation_config
                    )
                elif self._model_ == "gemini-pro":
                    # gemini-pro takes no system_instruction: inline the system
                    # prompt instead of silently dropping it.
                    response = await self.generative_model.generate_content_async(
                        merged_contents, generation_config=generation_config
                    )
                else:
                    # A per-call model keeps this call's system instruction from
                    # leaking into later calls on the same instance.
                    instructed_model = genai.GenerativeModel(
                        self._model_,
                        safety_settings=safety_settings,
                        system_instruction=system_prompt,
                    )
                    try:
                        response = await instructed_model.generate_content_async(
                            plain_contents, generation_config=generation_config
                        )
                    except InvalidArgument:
                        # Constructing the model is local; the API only rejects
                        # an unsupported system_instruction at call time.
                        warnings.warn(
                            f"This model, {self._model_}, does not support "
                            "system_instruction; adding system_prompt to user_prompt"
                        )
                        response = await self.generative_model.generate_content_async(
                            merged_contents, generation_config=generation_config
                        )
                return response.to_dict()

        LLM.__name__ = model_name
        return LLM
153
-
154
-
155
if __name__ == "__main__":
    # Nothing to execute when run directly; this module only provides definitions.
    ...
1
+ import os
2
+ from typing import Any, Dict, List, Optional
3
+ import google
4
+ import google.generativeai as genai
5
+ from google.generativeai.types import GenerationConfig
6
+ from google.api_core.exceptions import InvalidArgument
7
+
8
+ from edsl.exceptions import MissingAPIKeyError
9
+ from edsl.language_models.LanguageModel import LanguageModel
10
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
11
+
12
# Safety settings passed to every Gemini model built in this module: each
# listed harm category is configured with the "BLOCK_NONE" threshold.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
30
+
31
+
32
class GoogleService(InferenceServiceABC):
    """Inference service for Google's Gemini models via ``google.generativeai``."""

    _inference_service_ = "google"
    # Path into ``response.to_dict()`` (see async_execute_model_call) that
    # presumably locates the generated text.
    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
    # Path to the token-usage metadata in the response dict, and the field
    # names inside it for prompt and completion token counts.
    usage_sequence = ["usage_metadata"]
    input_token_name = "prompt_token_count"
    output_token_name = "candidates_token_count"

    model_exclude_list = []

    @classmethod
    def available(cls) -> List[str]:
        """Return the Gemini models that support ``generateContent``.

        Each entry is the last ``/``-separated segment of the SDK model name.
        """
        return [
            m.name.split("/")[-1]
            for m in genai.list_models()
            if "generateContent" in m.supported_generation_methods
        ]

    @classmethod
    def create_model(
        cls, model_name: str = "gemini-pro", model_class_name=None
    ) -> LanguageModel:
        """Build a ``LanguageModel`` subclass bound to the Gemini model *model_name*.

        :param model_name: Gemini model identifier, e.g. ``"gemini-pro"``.
        :param model_class_name: Accepted for interface compatibility but
            currently unused; the generated class's ``__name__`` is *model_name*.
        :returns: The generated ``LanguageModel`` subclass (not an instance).
        """

        class LLM(LanguageModel):
            _model_ = model_name
            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
            _inference_service_ = cls._inference_service_

            # Rate limits are resolved once, when this class is created.
            _tpm = cls.get_tpm(cls)
            _rpm = cls.get_rpm(cls)

            # Default generation parameters. get_generation_config reads them
            # back as attributes (self.temperature, ...); presumably
            # LanguageModel populates those from this dict — confirm there.
            _parameters_ = {
                "temperature": 0.5,
                "topP": 1,
                "topK": 1,
                "maxOutputTokens": 2048,
                "stopSequences": [],
            }

            api_token = None
            model = None

            @classmethod
            def initialize(cls):
                """Configure the SDK and build the default (instruction-free) model.

                :raises MissingAPIKeyError: if GOOGLE_API_KEY is unset or empty.
                """
                # `not` also catches an empty key cached by an earlier failed
                # attempt, which `is None` would have let through.
                if not cls.api_token:
                    cls.api_token = os.getenv("GOOGLE_API_KEY")
                    if not cls.api_token:
                        raise MissingAPIKeyError(
                            "GOOGLE_API_KEY environment variable is not set"
                        )
                genai.configure(api_key=cls.api_token)
                cls.generative_model = genai.GenerativeModel(
                    cls._model_, safety_settings=safety_settings
                )

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.initialize()

            def get_generation_config(self) -> GenerationConfig:
                """Translate this model's parameters into a Gemini GenerationConfig."""
                return GenerationConfig(
                    temperature=self.temperature,
                    top_p=self.topP,
                    top_k=self.topK,
                    max_output_tokens=self.maxOutputTokens,
                    stop_sequences=self.stopSequences,
                )

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional["Files"] = None,
            ) -> Dict[str, Any]:
                """Send the prompt (plus any files) to Gemini and return the response dict.

                The system prompt is sent as a ``system_instruction`` when
                possible. For ``gemini-pro``, or when the API rejects the
                instruction with ``InvalidArgument``, it is prepended to the
                user prompt instead.

                :param user_prompt: The user-facing prompt text.
                :param system_prompt: Optional system prompt; empty/None means none.
                :param files_list: Optional file objects. Each one not yet
                    uploaded (no "google" entry in ``external_locations``) is
                    uploaded via ``upload_google()``.
                :returns: ``response.to_dict()`` of the Gemini response.
                """
                import warnings

                generation_config = self.get_generation_config()

                file_parts = []
                for file in files_list or []:
                    if "google" not in file.external_locations:
                        file.upload_google()
                    file_parts.append(
                        google.generativeai.types.file_types.File(
                            file.external_locations["google"]
                        )
                    )

                plain_contents = [user_prompt, *file_parts]
                merged_contents = [f"{system_prompt}\n{user_prompt}", *file_parts]

                if not system_prompt:
                    response = await self.generative_model.generate_content_async(
                        plain_contents, generation_config=generation_config
                    )
                elif self._model_ == "gemini-pro":
                    # gemini-pro takes no system_instruction: inline the system
                    # prompt instead of silently dropping it.
                    response = await self.generative_model.generate_content_async(
                        merged_contents, generation_config=generation_config
                    )
                else:
                    # A per-call model keeps this call's system instruction from
                    # leaking into later calls on the same instance.
                    instructed_model = genai.GenerativeModel(
                        self._model_,
                        safety_settings=safety_settings,
                        system_instruction=system_prompt,
                    )
                    try:
                        response = await instructed_model.generate_content_async(
                            plain_contents, generation_config=generation_config
                        )
                    except InvalidArgument:
                        # Constructing the model is local; the API only rejects
                        # an unsupported system_instruction at call time.
                        warnings.warn(
                            f"This model, {self._model_}, does not support "
                            "system_instruction; adding system_prompt to user_prompt"
                        )
                        response = await self.generative_model.generate_content_async(
                            merged_contents, generation_config=generation_config
                        )
                return response.to_dict()

        LLM.__name__ = model_name
        return LLM
153
+
154
+
155
if __name__ == "__main__":
    # Nothing to execute when run directly; this module only provides definitions.
    ...
@@ -1,20 +1,20 @@
1
- from typing import Any, List
2
- from edsl.inference_services.OpenAIService import OpenAIService
3
-
4
- import groq
5
-
6
-
7
class GroqService(OpenAIService):
    """Groq service class.

    Reuses the ``OpenAIService`` implementation, swapping in the ``groq``
    SDK's sync and async clients.
    """

    _inference_service_ = "groq"
    _env_key_name_ = "GROQ_API_KEY"

    _sync_client_ = groq.Groq
    _async_client_ = groq.AsyncGroq

    # Models hidden from this service; presumably excluded because the whisper
    # family is speech-to-text rather than chat.
    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]

    # No custom endpoint override; the endpoint is left to the client
    # (presumably the groq SDK default — confirm in OpenAIService).
    _base_url_ = None
    _models_list_cache: List[str] = []
1
+ from typing import Any, List
2
+ from edsl.inference_services.OpenAIService import OpenAIService
3
+
4
+ import groq
5
+
6
+
7
class GroqService(OpenAIService):
    """Groq service class.

    Reuses the ``OpenAIService`` implementation, swapping in the ``groq``
    SDK's sync and async clients.
    """

    _inference_service_ = "groq"
    _env_key_name_ = "GROQ_API_KEY"

    _sync_client_ = groq.Groq
    _async_client_ = groq.AsyncGroq

    # Models hidden from this service; presumably excluded because the whisper
    # family is speech-to-text rather than chat.
    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]

    # No custom endpoint override; the endpoint is left to the client
    # (presumably the groq SDK default — confirm in OpenAIService).
    _base_url_ = None
    _models_list_cache: List[str] = []
@@ -1,147 +1,147 @@
1
- from abc import abstractmethod, ABC
2
- import os
3
- import re
4
- from datetime import datetime, timedelta
5
- from edsl.config import CONFIG
6
-
7
-
8
class InferenceServiceABC(ABC):
    """
    Abstract class for inference services.
    Anthropic: https://docs.anthropic.com/en/api/rate-limits
    """

    # Rate-limit config fetched from Coop, and when it was last fetched.
    # Declared here so _should_refresh_coop_config_vars never hits an
    # AttributeError before the first successful fetch.
    _coop_config_vars = None
    _last_config_fetch = None

    # Fallback tokens-per-minute / requests-per-minute per service, used when
    # neither the environment nor Coop supplies a value.
    default_levels = {
        "google": {"tpm": 2_000_000, "rpm": 15},
        "openai": {"tpm": 2_000_000, "rpm": 10_000},
        "anthropic": {"tpm": 2_000_000, "rpm": 500},
    }

    def __init_subclass__(cls, **kwargs):
        """
        Check that every subclass declares the attributes services rely on.

        - `key_sequence`: presumably the path into a raw model response that
          locates the generated text (e.g. GoogleService walks
          candidates -> content -> parts -> text).
        - `model_exclude_list`: model names the service should not expose.

        Raises NotImplementedError if either attribute is missing.
        """
        super().__init_subclass__(**kwargs)
        for required in ("key_sequence", "model_exclude_list"):
            if not hasattr(cls, required):
                raise NotImplementedError(
                    f"Class {cls.__name__} must have a '{required}' attribute."
                )

    @classmethod
    def _should_refresh_coop_config_vars(cls):
        """
        Return True if config vars were never fetched or were fetched over 24 hours ago.
        """
        last_fetch = cls._last_config_fetch
        if last_fetch is None:
            return True
        return (datetime.now() - last_fetch) > timedelta(hours=24)

    @classmethod
    def _get_limt(cls, limit_type: str) -> int:
        """
        Resolve a rate limit ("tpm" or "rpm") for this service.

        Lookup order:
        1. environment variable EDSL_SERVICE_<TYPE>_<SERVICE>;
        2. Coop-provided config vars (re-fetched at most every 24 hours);
        3. `default_levels` for this service;
        4. CONFIG entry EDSL_SERVICE_<TYPE>_BASELINE.
        """
        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
        if key in os.environ:
            return int(os.environ[key])

        if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
            try:
                from edsl import Coop

                c = Coop()
                cls._coop_config_vars = c.fetch_rate_limit_config_vars()
                cls._last_config_fetch = datetime.now()
                if key in cls._coop_config_vars:
                    # Coerce like the other sources so callers always get an int.
                    return int(cls._coop_config_vars[key])
            except Exception:
                # Coop is best-effort: on any failure, fall back to local defaults.
                cls._coop_config_vars = None
        elif key in cls._coop_config_vars:
            return int(cls._coop_config_vars[key])

        if cls._inference_service_ in cls.default_levels:
            return int(cls.default_levels[cls._inference_service_][limit_type])

        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))

    def get_tpm(cls) -> int:
        """
        Return the TPM (tokens per minute) limit for the service.

        Not a classmethod: callers pass the service class explicitly,
        e.g. `Service.get_tpm(Service)`. See `_get_limt` for the lookup order.
        """
        return cls._get_limt(limit_type="tpm")

    def get_rpm(cls) -> int:
        """
        Return the RPM (requests per minute) limit for the service.

        Not a classmethod: callers pass the service class explicitly,
        e.g. `Service.get_rpm(Service)`. See `_get_limt` for the lookup order.
        """
        return cls._get_limt(limit_type="rpm")

    @abstractmethod
    def available() -> list[str]:
        """
        Returns a list of available models for the service.
        """
        pass

    @abstractmethod
    def create_model():
        """
        Returns a LanguageModel object.
        """
        pass

    @staticmethod
    def to_class_name(s):
        """
        Converts a string to a valid class name.

        Non-alphanumeric characters other than spaces are dropped, words are
        title-cased and joined, and a leading digit gets a "Class" prefix.

        >>> InferenceServiceABC.to_class_name("hello world")
        'HelloWorld'
        >>> InferenceServiceABC.to_class_name("1 model")
        'Class1Model'
        """

        s = re.sub(r"[^a-zA-Z0-9 ]", "", s)
        s = "".join(word.title() for word in s.split())
        if s and s[0].isdigit():
            s = "Class" + s
        return s
113
-
114
-
115
if __name__ == "__main__":
    # Run the module's doctests (e.g. InferenceServiceABC.to_class_name).
    # Replaces stale, commented-out usage examples.
    import doctest

    doctest.testmod()
1
+ from abc import abstractmethod, ABC
2
+ import os
3
+ import re
4
+ from datetime import datetime, timedelta
5
+ from edsl.config import CONFIG
6
+
7
+
8
class InferenceServiceABC(ABC):
    """
    Abstract class for inference services.
    Anthropic: https://docs.anthropic.com/en/api/rate-limits
    """

    # Rate-limit config fetched from Coop, and when it was last fetched.
    # Declared here so _should_refresh_coop_config_vars never hits an
    # AttributeError before the first successful fetch.
    _coop_config_vars = None
    _last_config_fetch = None

    # Fallback tokens-per-minute / requests-per-minute per service, used when
    # neither the environment nor Coop supplies a value.
    default_levels = {
        "google": {"tpm": 2_000_000, "rpm": 15},
        "openai": {"tpm": 2_000_000, "rpm": 10_000},
        "anthropic": {"tpm": 2_000_000, "rpm": 500},
    }

    def __init_subclass__(cls, **kwargs):
        """
        Check that every subclass declares the attributes services rely on.

        - `key_sequence`: presumably the path into a raw model response that
          locates the generated text (e.g. GoogleService walks
          candidates -> content -> parts -> text).
        - `model_exclude_list`: model names the service should not expose.

        Raises NotImplementedError if either attribute is missing.
        """
        super().__init_subclass__(**kwargs)
        for required in ("key_sequence", "model_exclude_list"):
            if not hasattr(cls, required):
                raise NotImplementedError(
                    f"Class {cls.__name__} must have a '{required}' attribute."
                )

    @classmethod
    def _should_refresh_coop_config_vars(cls):
        """
        Return True if config vars were never fetched or were fetched over 24 hours ago.
        """
        last_fetch = cls._last_config_fetch
        if last_fetch is None:
            return True
        return (datetime.now() - last_fetch) > timedelta(hours=24)

    @classmethod
    def _get_limt(cls, limit_type: str) -> int:
        """
        Resolve a rate limit ("tpm" or "rpm") for this service.

        Lookup order:
        1. environment variable EDSL_SERVICE_<TYPE>_<SERVICE>;
        2. Coop-provided config vars (re-fetched at most every 24 hours);
        3. `default_levels` for this service;
        4. CONFIG entry EDSL_SERVICE_<TYPE>_BASELINE.
        """
        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
        if key in os.environ:
            return int(os.environ[key])

        if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
            try:
                from edsl import Coop

                c = Coop()
                cls._coop_config_vars = c.fetch_rate_limit_config_vars()
                cls._last_config_fetch = datetime.now()
                if key in cls._coop_config_vars:
                    # Coerce like the other sources so callers always get an int.
                    return int(cls._coop_config_vars[key])
            except Exception:
                # Coop is best-effort: on any failure, fall back to local defaults.
                cls._coop_config_vars = None
        elif key in cls._coop_config_vars:
            return int(cls._coop_config_vars[key])

        if cls._inference_service_ in cls.default_levels:
            return int(cls.default_levels[cls._inference_service_][limit_type])

        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))

    def get_tpm(cls) -> int:
        """
        Return the TPM (tokens per minute) limit for the service.

        Not a classmethod: callers pass the service class explicitly,
        e.g. `Service.get_tpm(Service)`. See `_get_limt` for the lookup order.
        """
        return cls._get_limt(limit_type="tpm")

    def get_rpm(cls) -> int:
        """
        Return the RPM (requests per minute) limit for the service.

        Not a classmethod: callers pass the service class explicitly,
        e.g. `Service.get_rpm(Service)`. See `_get_limt` for the lookup order.
        """
        return cls._get_limt(limit_type="rpm")

    @abstractmethod
    def available() -> list[str]:
        """
        Returns a list of available models for the service.
        """
        pass

    @abstractmethod
    def create_model():
        """
        Returns a LanguageModel object.
        """
        pass

    @staticmethod
    def to_class_name(s):
        """
        Converts a string to a valid class name.

        Non-alphanumeric characters other than spaces are dropped, words are
        title-cased and joined, and a leading digit gets a "Class" prefix.

        >>> InferenceServiceABC.to_class_name("hello world")
        'HelloWorld'
        >>> InferenceServiceABC.to_class_name("1 model")
        'Class1Model'
        """

        s = re.sub(r"[^a-zA-Z0-9 ]", "", s)
        s = "".join(word.title() for word in s.split())
        if s and s[0].isdigit():
            s = "Class" + s
        return s
113
+
114
+
115
if __name__ == "__main__":
    # Run the module's doctests (e.g. InferenceServiceABC.to_class_name).
    # Replaces stale, commented-out usage examples.
    import doctest

    doctest.testmod()