edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. edsl/Base.py +116 -197
  2. edsl/__init__.py +7 -15
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +147 -351
  5. edsl/agents/AgentList.py +73 -211
  6. edsl/agents/Invigilator.py +50 -101
  7. edsl/agents/InvigilatorBase.py +70 -62
  8. edsl/agents/PromptConstructor.py +225 -143
  9. edsl/agents/__init__.py +1 -0
  10. edsl/agents/prompt_helpers.py +3 -3
  11. edsl/auto/AutoStudy.py +5 -18
  12. edsl/auto/StageBase.py +40 -53
  13. edsl/auto/StageQuestions.py +1 -2
  14. edsl/auto/utilities.py +6 -0
  15. edsl/config.py +2 -22
  16. edsl/conversation/car_buying.py +1 -2
  17. edsl/coop/PriceFetcher.py +1 -1
  18. edsl/coop/coop.py +47 -125
  19. edsl/coop/utils.py +14 -14
  20. edsl/data/Cache.py +27 -45
  21. edsl/data/CacheEntry.py +15 -12
  22. edsl/data/CacheHandler.py +12 -31
  23. edsl/data/RemoteCacheSync.py +46 -154
  24. edsl/data/__init__.py +3 -4
  25. edsl/data_transfer_models.py +1 -2
  26. edsl/enums.py +0 -27
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +0 -12
  29. edsl/exceptions/questions.py +6 -24
  30. edsl/exceptions/scenarios.py +0 -7
  31. edsl/inference_services/AnthropicService.py +19 -38
  32. edsl/inference_services/AwsBedrock.py +2 -0
  33. edsl/inference_services/AzureAI.py +2 -0
  34. edsl/inference_services/GoogleService.py +12 -7
  35. edsl/inference_services/InferenceServiceABC.py +85 -18
  36. edsl/inference_services/InferenceServicesCollection.py +79 -120
  37. edsl/inference_services/MistralAIService.py +3 -0
  38. edsl/inference_services/OpenAIService.py +35 -47
  39. edsl/inference_services/PerplexityService.py +3 -0
  40. edsl/inference_services/TestService.py +10 -11
  41. edsl/inference_services/TogetherAIService.py +3 -5
  42. edsl/jobs/Answers.py +14 -1
  43. edsl/jobs/Jobs.py +431 -356
  44. edsl/jobs/JobsChecks.py +10 -35
  45. edsl/jobs/JobsPrompts.py +4 -6
  46. edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
  47. edsl/jobs/buckets/BucketCollection.py +3 -44
  48. edsl/jobs/buckets/TokenBucket.py +21 -53
  49. edsl/jobs/interviews/Interview.py +408 -143
  50. edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
  51. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  52. edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
  53. edsl/jobs/tasks/TaskHistory.py +18 -38
  54. edsl/jobs/tasks/task_status_enum.py +2 -0
  55. edsl/language_models/KeyLookup.py +30 -0
  56. edsl/language_models/LanguageModel.py +236 -194
  57. edsl/language_models/ModelList.py +19 -28
  58. edsl/language_models/__init__.py +2 -1
  59. edsl/language_models/registry.py +190 -0
  60. edsl/language_models/repair.py +2 -2
  61. edsl/language_models/unused/ReplicateBase.py +83 -0
  62. edsl/language_models/utilities.py +4 -5
  63. edsl/notebooks/Notebook.py +14 -19
  64. edsl/prompts/Prompt.py +39 -29
  65. edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
  66. edsl/questions/QuestionBase.py +214 -68
  67. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
  68. edsl/questions/QuestionBasePromptsMixin.py +3 -7
  69. edsl/questions/QuestionBudget.py +1 -1
  70. edsl/questions/QuestionCheckBox.py +3 -3
  71. edsl/questions/QuestionExtract.py +7 -5
  72. edsl/questions/QuestionFreeText.py +3 -2
  73. edsl/questions/QuestionList.py +18 -10
  74. edsl/questions/QuestionMultipleChoice.py +23 -67
  75. edsl/questions/QuestionNumerical.py +4 -2
  76. edsl/questions/QuestionRank.py +17 -7
  77. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
  78. edsl/questions/SimpleAskMixin.py +3 -4
  79. edsl/questions/__init__.py +1 -2
  80. edsl/questions/derived/QuestionLinearScale.py +3 -6
  81. edsl/questions/derived/QuestionTopK.py +1 -1
  82. edsl/questions/descriptors.py +3 -17
  83. edsl/questions/question_registry.py +1 -1
  84. edsl/results/CSSParameterizer.py +1 -1
  85. edsl/results/Dataset.py +7 -170
  86. edsl/results/DatasetExportMixin.py +305 -168
  87. edsl/results/DatasetTree.py +8 -28
  88. edsl/results/Result.py +206 -298
  89. edsl/results/Results.py +131 -149
  90. edsl/results/ResultsDBMixin.py +238 -0
  91. edsl/results/ResultsExportMixin.py +0 -2
  92. edsl/results/{results_selector.py → Selector.py} +13 -23
  93. edsl/results/TableDisplay.py +171 -98
  94. edsl/results/__init__.py +1 -1
  95. edsl/scenarios/FileStore.py +239 -150
  96. edsl/scenarios/Scenario.py +193 -90
  97. edsl/scenarios/ScenarioHtmlMixin.py +3 -4
  98. edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
  99. edsl/scenarios/ScenarioList.py +244 -415
  100. edsl/scenarios/ScenarioListExportMixin.py +7 -0
  101. edsl/scenarios/ScenarioListPdfMixin.py +37 -15
  102. edsl/scenarios/__init__.py +2 -1
  103. edsl/study/ObjectEntry.py +1 -1
  104. edsl/study/SnapShot.py +1 -1
  105. edsl/study/Study.py +12 -5
  106. edsl/surveys/Rule.py +4 -5
  107. edsl/surveys/RuleCollection.py +27 -25
  108. edsl/surveys/Survey.py +791 -270
  109. edsl/surveys/SurveyCSS.py +8 -20
  110. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
  111. edsl/surveys/__init__.py +2 -4
  112. edsl/surveys/descriptors.py +2 -6
  113. edsl/surveys/instructions/ChangeInstruction.py +2 -1
  114. edsl/surveys/instructions/Instruction.py +13 -4
  115. edsl/surveys/instructions/InstructionCollection.py +6 -11
  116. edsl/templates/error_reporting/interview_details.html +1 -1
  117. edsl/templates/error_reporting/report.html +1 -1
  118. edsl/tools/plotting.py +1 -1
  119. edsl/utilities/utilities.py +23 -35
  120. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
  121. edsl-0.1.39.dev1.dist-info/RECORD +277 -0
  122. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
  123. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  124. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  125. edsl/agents/question_option_processor.py +0 -172
  126. edsl/coop/CoopFunctionsMixin.py +0 -15
  127. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  128. edsl/exceptions/inference_services.py +0 -5
  129. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  130. edsl/inference_services/AvailableModelFetcher.py +0 -215
  131. edsl/inference_services/ServiceAvailability.py +0 -135
  132. edsl/inference_services/data_structures.py +0 -134
  133. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
  134. edsl/jobs/FetchInvigilator.py +0 -47
  135. edsl/jobs/InterviewTaskManager.py +0 -98
  136. edsl/jobs/InterviewsConstructor.py +0 -50
  137. edsl/jobs/JobsComponentConstructor.py +0 -189
  138. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  139. edsl/jobs/RequestTokenEstimator.py +0 -30
  140. edsl/jobs/async_interview_runner.py +0 -138
  141. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  142. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  143. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  144. edsl/jobs/data_structures.py +0 -120
  145. edsl/jobs/decorators.py +0 -35
  146. edsl/jobs/jobs_status_enums.py +0 -9
  147. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  148. edsl/jobs/results_exceptions_handler.py +0 -98
  149. edsl/language_models/ComputeCost.py +0 -63
  150. edsl/language_models/PriceManager.py +0 -127
  151. edsl/language_models/RawResponseHandler.py +0 -106
  152. edsl/language_models/ServiceDataSources.py +0 -0
  153. edsl/language_models/key_management/KeyLookup.py +0 -63
  154. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  155. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  156. edsl/language_models/key_management/__init__.py +0 -0
  157. edsl/language_models/key_management/models.py +0 -131
  158. edsl/language_models/model.py +0 -256
  159. edsl/notebooks/NotebookToLaTeX.py +0 -142
  160. edsl/questions/ExceptionExplainer.py +0 -77
  161. edsl/questions/HTMLQuestion.py +0 -103
  162. edsl/questions/QuestionMatrix.py +0 -265
  163. edsl/questions/data_structures.py +0 -20
  164. edsl/questions/loop_processor.py +0 -149
  165. edsl/questions/response_validator_factory.py +0 -34
  166. edsl/questions/templates/matrix/__init__.py +0 -1
  167. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  168. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  169. edsl/results/MarkdownToDocx.py +0 -122
  170. edsl/results/MarkdownToPDF.py +0 -111
  171. edsl/results/TextEditor.py +0 -50
  172. edsl/results/file_exports.py +0 -252
  173. edsl/results/smart_objects.py +0 -96
  174. edsl/results/table_data_class.py +0 -12
  175. edsl/results/table_renderers.py +0 -118
  176. edsl/scenarios/ConstructDownloadLink.py +0 -109
  177. edsl/scenarios/DocumentChunker.py +0 -102
  178. edsl/scenarios/DocxScenario.py +0 -16
  179. edsl/scenarios/PdfExtractor.py +0 -40
  180. edsl/scenarios/directory_scanner.py +0 -96
  181. edsl/scenarios/file_methods.py +0 -85
  182. edsl/scenarios/handlers/__init__.py +0 -13
  183. edsl/scenarios/handlers/csv.py +0 -49
  184. edsl/scenarios/handlers/docx.py +0 -76
  185. edsl/scenarios/handlers/html.py +0 -37
  186. edsl/scenarios/handlers/json.py +0 -111
  187. edsl/scenarios/handlers/latex.py +0 -5
  188. edsl/scenarios/handlers/md.py +0 -51
  189. edsl/scenarios/handlers/pdf.py +0 -68
  190. edsl/scenarios/handlers/png.py +0 -39
  191. edsl/scenarios/handlers/pptx.py +0 -105
  192. edsl/scenarios/handlers/py.py +0 -294
  193. edsl/scenarios/handlers/sql.py +0 -313
  194. edsl/scenarios/handlers/sqlite.py +0 -149
  195. edsl/scenarios/handlers/txt.py +0 -33
  196. edsl/scenarios/scenario_selector.py +0 -156
  197. edsl/surveys/ConstructDAG.py +0 -92
  198. edsl/surveys/EditSurvey.py +0 -221
  199. edsl/surveys/InstructionHandler.py +0 -100
  200. edsl/surveys/MemoryManagement.py +0 -72
  201. edsl/surveys/RuleManager.py +0 -172
  202. edsl/surveys/Simulator.py +0 -75
  203. edsl/surveys/SurveyToApp.py +0 -141
  204. edsl/utilities/PrettyList.py +0 -56
  205. edsl/utilities/is_notebook.py +0 -18
  206. edsl/utilities/is_valid_variable_name.py +0 -11
  207. edsl/utilities/remove_edsl_version.py +0 -24
  208. edsl-0.1.39.dist-info/RECORD +0 -358
  209. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  210. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  211. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  212. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
@@ -1,54 +1,54 @@
1
- # from .agents import (
2
- # # AgentAttributeLookupCallbackError,
3
- # AgentCombinationError,
4
- # # AgentLacksLLMError,
5
- # # AgentRespondedWithBadJSONError,
6
- # )
7
- # from .configuration import (
8
- # InvalidEnvironmentVariableError,
9
- # MissingEnvironmentVariableError,
10
- # )
11
- # from .data import (
12
- # DatabaseConnectionError,
13
- # DatabaseCRUDError,
14
- # DatabaseIntegrityError,
15
- # )
1
+ from .agents import (
2
+ # AgentAttributeLookupCallbackError,
3
+ AgentCombinationError,
4
+ # AgentLacksLLMError,
5
+ # AgentRespondedWithBadJSONError,
6
+ )
7
+ from .configuration import (
8
+ InvalidEnvironmentVariableError,
9
+ MissingEnvironmentVariableError,
10
+ )
11
+ from .data import (
12
+ DatabaseConnectionError,
13
+ DatabaseCRUDError,
14
+ DatabaseIntegrityError,
15
+ )
16
16
 
17
- # from .scenarios import (
18
- # ScenarioError,
19
- # )
17
+ from .scenarios import (
18
+ ScenarioError,
19
+ )
20
20
 
21
- # from .general import MissingAPIKeyError
21
+ from .general import MissingAPIKeyError
22
22
 
23
- # from .jobs import JobsRunError, InterviewErrorPriorTaskCanceled, InterviewTimeoutError
23
+ from .jobs import JobsRunError, InterviewErrorPriorTaskCanceled, InterviewTimeoutError
24
24
 
25
- # from .language_models import (
26
- # LanguageModelResponseNotJSONError,
27
- # LanguageModelMissingAttributeError,
28
- # LanguageModelAttributeTypeError,
29
- # LanguageModelDoNotAddError,
30
- # )
31
- # from .questions import (
32
- # QuestionAnswerValidationError,
33
- # QuestionAttributeMissing,
34
- # QuestionCreationValidationError,
35
- # QuestionResponseValidationError,
36
- # QuestionSerializationError,
37
- # QuestionScenarioRenderError,
38
- # )
39
- # from .results import (
40
- # ResultsBadMutationstringError,
41
- # ResultsColumnNotFoundError,
42
- # ResultsInvalidNameError,
43
- # ResultsMutateError,
44
- # )
45
- # from .surveys import (
46
- # SurveyCreationError,
47
- # SurveyHasNoRulesError,
48
- # SurveyRuleCannotEvaluateError,
49
- # SurveyRuleCollectionHasNoRulesAtNodeError,
50
- # SurveyRuleReferenceInRuleToUnknownQuestionError,
51
- # SurveyRuleRefersToFutureStateError,
52
- # SurveyRuleSendsYouBackwardsError,
53
- # SurveyRuleSkipLogicSyntaxError,
54
- # )
25
+ from .language_models import (
26
+ LanguageModelResponseNotJSONError,
27
+ LanguageModelMissingAttributeError,
28
+ LanguageModelAttributeTypeError,
29
+ LanguageModelDoNotAddError,
30
+ )
31
+ from .questions import (
32
+ QuestionAnswerValidationError,
33
+ QuestionAttributeMissing,
34
+ QuestionCreationValidationError,
35
+ QuestionResponseValidationError,
36
+ QuestionSerializationError,
37
+ QuestionScenarioRenderError,
38
+ )
39
+ from .results import (
40
+ ResultsBadMutationstringError,
41
+ ResultsColumnNotFoundError,
42
+ ResultsInvalidNameError,
43
+ ResultsMutateError,
44
+ )
45
+ from .surveys import (
46
+ SurveyCreationError,
47
+ SurveyHasNoRulesError,
48
+ SurveyRuleCannotEvaluateError,
49
+ SurveyRuleCollectionHasNoRulesAtNodeError,
50
+ SurveyRuleReferenceInRuleToUnknownQuestionError,
51
+ SurveyRuleRefersToFutureStateError,
52
+ SurveyRuleSendsYouBackwardsError,
53
+ SurveyRuleSkipLogicSyntaxError,
54
+ )
edsl/exceptions/agents.py CHANGED
@@ -1,18 +1,6 @@
1
1
  from edsl.exceptions.BaseException import BaseException
2
2
 
3
3
 
4
- # from edsl.utilities.utilities import is_notebook
5
-
6
- # from IPython.core.error import UsageError
7
-
8
- # class AgentListErrorAlternative(UsageError):
9
- # def __init__(self, message):
10
- # super().__init__(message)
11
-
12
- import sys
13
- from edsl.utilities.is_notebook import is_notebook
14
-
15
-
16
4
  class AgentListError(BaseException):
17
5
  relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html#agent-lists"
18
6
 
@@ -1,6 +1,6 @@
1
1
  from typing import Any, SupportsIndex
2
+ from jinja2 import Template
2
3
  import json
3
- from pydantic import ValidationError
4
4
 
5
5
 
6
6
  class QuestionErrors(Exception):
@@ -20,35 +20,17 @@ class QuestionAnswerValidationError(QuestionErrors):
20
20
  For example, if the question is a multiple choice question, the answer should be drawn from the list of options provided.
21
21
  """
22
22
 
23
- def __init__(
24
- self,
25
- message="Invalid answer.",
26
- pydantic_error: ValidationError = None,
27
- data: dict = None,
28
- model=None,
29
- ):
23
+ def __init__(self, message="Invalid answer.", data=None, model=None):
30
24
  self.message = message
31
- self.pydantic_error = pydantic_error
32
25
  self.data = data
33
26
  self.model = model
34
27
  super().__init__(self.message)
35
28
 
36
29
  def __str__(self):
37
- if isinstance(self.message, ValidationError):
38
- # If it's a ValidationError, just return the core error message
39
- return str(self.message)
40
- elif hasattr(self.message, "errors"):
41
- # Handle the case where it's already been converted to a string but has errors
42
- error_list = self.message.errors()
43
- if error_list:
44
- return str(error_list[0].get("msg", "Unknown error"))
45
- return str(self.message)
46
-
47
- # def __str__(self):
48
- # return f"""{repr(self)}
49
- # Data being validated: {self.data}
50
- # Pydnantic Model: {self.model}.
51
- # Reported error: {self.message}."""
30
+ return f"""{repr(self)}
31
+ Data being validated: {self.data}
32
+ Pydnantic Model: {self.model}.
33
+ Reported error: {self.message}."""
52
34
 
53
35
  def to_html_dict(self):
54
36
  return {
@@ -1,13 +1,6 @@
1
1
  import re
2
2
  import textwrap
3
3
 
4
- # from IPython.core.error import UsageError
5
-
6
-
7
- class AgentListError(Exception):
8
- def __init__(self, message):
9
- super().__init__(message)
10
-
11
4
 
12
5
  class ScenarioError(Exception):
13
6
  documentation = "https://docs.expectedparrot.com/en/latest/scenarios.html#module-edsl.scenarios.Scenario"
@@ -11,27 +11,21 @@ class AnthropicService(InferenceServiceABC):
11
11
 
12
12
  _inference_service_ = "anthropic"
13
13
  _env_key_name_ = "ANTHROPIC_API_KEY"
14
- key_sequence = ["content", 0, "text"]
14
+ key_sequence = ["content", 0, "text"] # ["content"][0]["text"]
15
15
  usage_sequence = ["usage"]
16
16
  input_token_name = "input_tokens"
17
17
  output_token_name = "output_tokens"
18
18
  model_exclude_list = []
19
19
 
20
- @classmethod
21
- def get_model_list(cls, api_key: str = None):
22
-
23
- import requests
24
-
25
- if api_key is None:
26
- api_key = os.environ.get("ANTHROPIC_API_KEY")
27
- headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
28
- response = requests.get("https://api.anthropic.com/v1/models", headers=headers)
29
- model_names = [m["id"] for m in response.json()["data"]]
30
- return model_names
31
-
32
20
  @classmethod
33
21
  def available(cls):
34
- return cls.get_model_list()
22
+ # TODO - replace with an API call
23
+ return [
24
+ "claude-3-5-sonnet-20240620",
25
+ "claude-3-opus-20240229",
26
+ "claude-3-sonnet-20240229",
27
+ "claude-3-haiku-20240307",
28
+ ]
35
29
 
36
30
  @classmethod
37
31
  def create_model(
@@ -62,42 +56,29 @@ class AnthropicService(InferenceServiceABC):
62
56
  "top_logprobs": 3,
63
57
  }
64
58
 
59
+ _tpm = cls.get_tpm(cls)
60
+ _rpm = cls.get_rpm(cls)
61
+
65
62
  async def async_execute_model_call(
66
63
  self,
67
64
  user_prompt: str,
68
65
  system_prompt: str = "",
69
66
  files_list: Optional[List["Files"]] = None,
70
67
  ) -> dict[str, Any]:
71
- """Calls the Anthropic API and returns the API response."""
68
+ """Calls the OpenAI API and returns the API response."""
72
69
 
73
- messages = [
74
- {
75
- "role": "user",
76
- "content": [{"type": "text", "text": user_prompt}],
77
- }
78
- ]
79
- if files_list:
80
- for file_entry in files_list:
81
- encoded_image = file_entry.base64_string
82
- messages[0]["content"].append(
83
- {
84
- "type": "image",
85
- "source": {
86
- "type": "base64",
87
- "media_type": file_entry.mime_type,
88
- "data": encoded_image,
89
- },
90
- }
91
- )
92
- # breakpoint()
93
- client = AsyncAnthropic(api_key=self.api_token)
70
+ api_key = os.environ.get("ANTHROPIC_API_KEY")
71
+ client = AsyncAnthropic(api_key=api_key)
94
72
 
95
73
  response = await client.messages.create(
96
74
  model=model_name,
97
75
  max_tokens=self.max_tokens,
98
76
  temperature=self.temperature,
99
- system=system_prompt, # note that the Anthropic API uses "system" parameter rather than put it in the message
100
- messages=messages,
77
+ system=system_prompt,
78
+ messages=[
79
+ # {"role": "system", "content": system_prompt},
80
+ {"role": "user", "content": user_prompt},
81
+ ],
101
82
  )
102
83
  return response.model_dump()
103
84
 
@@ -69,6 +69,8 @@ class AwsBedrockService(InferenceServiceABC):
69
69
  }
70
70
  input_token_name = cls.input_token_name
71
71
  output_token_name = cls.output_token_name
72
+ _rpm = cls.get_rpm(cls)
73
+ _tpm = cls.get_tpm(cls)
72
74
 
73
75
  async def async_execute_model_call(
74
76
  self,
@@ -118,6 +118,8 @@ class AzureAIService(InferenceServiceABC):
118
118
  "max_tokens": 512,
119
119
  "top_p": 0.9,
120
120
  }
121
+ _rpm = cls.get_rpm(cls)
122
+ _tpm = cls.get_tpm(cls)
121
123
 
122
124
  async def async_execute_model_call(
123
125
  self,
@@ -1,11 +1,11 @@
1
- # import os
1
+ import os
2
2
  from typing import Any, Dict, List, Optional
3
3
  import google
4
4
  import google.generativeai as genai
5
5
  from google.generativeai.types import GenerationConfig
6
6
  from google.api_core.exceptions import InvalidArgument
7
7
 
8
- # from edsl.exceptions.general import MissingAPIKeyError
8
+ from edsl.exceptions import MissingAPIKeyError
9
9
  from edsl.language_models.LanguageModel import LanguageModel
10
10
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
11
11
  from edsl.coop import Coop
@@ -39,18 +39,18 @@ class GoogleService(InferenceServiceABC):
39
39
 
40
40
  model_exclude_list = []
41
41
 
42
+ # @classmethod
43
+ # def available(cls) -> List[str]:
44
+ # return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
45
+
42
46
  @classmethod
43
- def get_model_list(cls):
47
+ def available(cls) -> List[str]:
44
48
  model_list = []
45
49
  for m in genai.list_models():
46
50
  if "generateContent" in m.supported_generation_methods:
47
51
  model_list.append(m.name.split("/")[-1])
48
52
  return model_list
49
53
 
50
- @classmethod
51
- def available(cls) -> List[str]:
52
- return cls.get_model_list()
53
-
54
54
  @classmethod
55
55
  def create_model(
56
56
  cls, model_name: str = "gemini-pro", model_class_name=None
@@ -66,6 +66,9 @@ class GoogleService(InferenceServiceABC):
66
66
  output_token_name = cls.output_token_name
67
67
  _inference_service_ = cls._inference_service_
68
68
 
69
+ _tpm = cls.get_tpm(cls)
70
+ _rpm = cls.get_rpm(cls)
71
+
69
72
  _parameters_ = {
70
73
  "temperature": 0.5,
71
74
  "topP": 1,
@@ -74,6 +77,7 @@ class GoogleService(InferenceServiceABC):
74
77
  "stopSequences": [],
75
78
  }
76
79
 
80
+ api_token = None
77
81
  model = None
78
82
 
79
83
  def __init__(self, *args, **kwargs):
@@ -98,6 +102,7 @@ class GoogleService(InferenceServiceABC):
98
102
 
99
103
  if files_list is None:
100
104
  files_list = []
105
+
101
106
  genai.configure(api_key=self.api_token)
102
107
  if (
103
108
  system_prompt is not None
@@ -1,4 +1,5 @@
1
1
  from abc import abstractmethod, ABC
2
+ import os
2
3
  import re
3
4
  from datetime import datetime, timedelta
4
5
  from edsl.config import CONFIG
@@ -7,32 +8,31 @@ from edsl.config import CONFIG
7
8
  class InferenceServiceABC(ABC):
8
9
  """
9
10
  Abstract class for inference services.
11
+ Anthropic: https://docs.anthropic.com/en/api/rate-limits
10
12
  """
11
13
 
12
14
  _coop_config_vars = None
13
15
 
16
+ default_levels = {
17
+ "google": {"tpm": 2_000_000, "rpm": 15},
18
+ "openai": {"tpm": 2_000_000, "rpm": 10_000},
19
+ "anthropic": {"tpm": 2_000_000, "rpm": 500},
20
+ }
21
+
14
22
  def __init_subclass__(cls):
15
23
  """
16
24
  Check that the subclass has the required attributes.
17
25
  - `key_sequence` attribute determines...
18
26
  - `model_exclude_list` attribute determines...
19
27
  """
20
- must_have_attributes = [
21
- "key_sequence",
22
- "model_exclude_list",
23
- "usage_sequence",
24
- "input_token_name",
25
- "output_token_name",
26
- ]
27
- for attr in must_have_attributes:
28
- if not hasattr(cls, attr):
29
- raise NotImplementedError(
30
- f"Class {cls.__name__} must have a '{attr}' attribute."
31
- )
32
-
33
- @property
34
- def service_name(self):
35
- return self._inference_service_
28
+ if not hasattr(cls, "key_sequence"):
29
+ raise NotImplementedError(
30
+ f"Class {cls.__name__} must have a 'key_sequence' attribute."
31
+ )
32
+ if not hasattr(cls, "model_exclude_list"):
33
+ raise NotImplementedError(
34
+ f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
35
+ )
36
36
 
37
37
  @classmethod
38
38
  def _should_refresh_coop_config_vars(cls):
@@ -44,6 +44,44 @@ class InferenceServiceABC(ABC):
44
44
  return True
45
45
  return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)
46
46
 
47
+ @classmethod
48
+ def _get_limt(cls, limit_type: str) -> int:
49
+ key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
50
+ if key in os.environ:
51
+ return int(os.getenv(key))
52
+
53
+ if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
54
+ try:
55
+ from edsl import Coop
56
+
57
+ c = Coop()
58
+ cls._coop_config_vars = c.fetch_rate_limit_config_vars()
59
+ cls._last_config_fetch = datetime.now()
60
+ if key in cls._coop_config_vars:
61
+ return cls._coop_config_vars[key]
62
+ except Exception:
63
+ cls._coop_config_vars = None
64
+ else:
65
+ if key in cls._coop_config_vars:
66
+ return cls._coop_config_vars[key]
67
+
68
+ if cls._inference_service_ in cls.default_levels:
69
+ return int(cls.default_levels[cls._inference_service_][limit_type])
70
+
71
+ return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
72
+
73
+ def get_tpm(cls) -> int:
74
+ """
75
+ Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
76
+ """
77
+ return cls._get_limt(limit_type="tpm")
78
+
79
+ def get_rpm(cls):
80
+ """
81
+ Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
82
+ """
83
+ return cls._get_limt(limit_type="rpm")
84
+
47
85
  @abstractmethod
48
86
  def available() -> list[str]:
49
87
  """
@@ -75,6 +113,35 @@ class InferenceServiceABC(ABC):
75
113
 
76
114
 
77
115
  if __name__ == "__main__":
78
- import doctest
116
+ pass
117
+ # deep_infra_service = DeepInfraService("deep_infra", "DEEP_INFRA_API_KEY")
118
+ # deep_infra_service.available()
119
+ # m = deep_infra_service.create_model("microsoft/WizardLM-2-7B")
120
+ # response = m().hello()
121
+ # print(response)
122
+
123
+ # anthropic_service = AnthropicService("anthropic", "ANTHROPIC_API_KEY")
124
+ # anthropic_service.available()
125
+ # m = anthropic_service.create_model("claude-3-opus-20240229")
126
+ # response = m().hello()
127
+ # print(response)
128
+ # factory = OpenAIService("openai", "OPENAI_API")
129
+ # factory.available()
130
+ # m = factory.create_model("gpt-3.5-turbo")
131
+ # response = m().hello()
132
+
133
+ # from edsl import QuestionFreeText
134
+ # results = QuestionFreeText.example().by(m()).run()
135
+
136
+ # collection = InferenceServicesCollection([
137
+ # OpenAIService,
138
+ # AnthropicService,
139
+ # DeepInfraService
140
+ # ])
79
141
 
80
- doctest.testmod()
142
+ # available = collection.available()
143
+ # factory = collection.create_model_factory(*available[0])
144
+ # m = factory()
145
+ # from edsl import QuestionFreeText
146
+ # results = QuestionFreeText.example().by(m).run()
147
+ # print(results)