edsl 0.1.31.dev4__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. edsl/Base.py +9 -3
  2. edsl/TemplateLoader.py +24 -0
  3. edsl/__init__.py +8 -3
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +40 -8
  6. edsl/agents/AgentList.py +43 -0
  7. edsl/agents/Invigilator.py +136 -221
  8. edsl/agents/InvigilatorBase.py +148 -59
  9. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +154 -85
  10. edsl/agents/__init__.py +1 -0
  11. edsl/auto/AutoStudy.py +117 -0
  12. edsl/auto/StageBase.py +230 -0
  13. edsl/auto/StageGenerateSurvey.py +178 -0
  14. edsl/auto/StageLabelQuestions.py +125 -0
  15. edsl/auto/StagePersona.py +61 -0
  16. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  17. edsl/auto/StagePersonaDimensionValues.py +74 -0
  18. edsl/auto/StagePersonaDimensions.py +69 -0
  19. edsl/auto/StageQuestions.py +73 -0
  20. edsl/auto/SurveyCreatorPipeline.py +21 -0
  21. edsl/auto/utilities.py +224 -0
  22. edsl/config.py +48 -47
  23. edsl/conjure/Conjure.py +6 -0
  24. edsl/coop/PriceFetcher.py +58 -0
  25. edsl/coop/coop.py +50 -7
  26. edsl/data/Cache.py +35 -1
  27. edsl/data/CacheHandler.py +3 -4
  28. edsl/data_transfer_models.py +73 -38
  29. edsl/enums.py +8 -0
  30. edsl/exceptions/general.py +10 -8
  31. edsl/exceptions/language_models.py +25 -1
  32. edsl/exceptions/questions.py +62 -5
  33. edsl/exceptions/results.py +4 -0
  34. edsl/inference_services/AnthropicService.py +13 -11
  35. edsl/inference_services/AwsBedrock.py +112 -0
  36. edsl/inference_services/AzureAI.py +214 -0
  37. edsl/inference_services/DeepInfraService.py +4 -3
  38. edsl/inference_services/GoogleService.py +16 -12
  39. edsl/inference_services/GroqService.py +5 -4
  40. edsl/inference_services/InferenceServiceABC.py +58 -3
  41. edsl/inference_services/InferenceServicesCollection.py +13 -8
  42. edsl/inference_services/MistralAIService.py +120 -0
  43. edsl/inference_services/OllamaService.py +18 -0
  44. edsl/inference_services/OpenAIService.py +55 -56
  45. edsl/inference_services/TestService.py +80 -0
  46. edsl/inference_services/TogetherAIService.py +170 -0
  47. edsl/inference_services/models_available_cache.py +25 -0
  48. edsl/inference_services/registry.py +19 -1
  49. edsl/jobs/Answers.py +10 -12
  50. edsl/jobs/FailedQuestion.py +78 -0
  51. edsl/jobs/Jobs.py +137 -41
  52. edsl/jobs/buckets/BucketCollection.py +24 -15
  53. edsl/jobs/buckets/TokenBucket.py +105 -18
  54. edsl/jobs/interviews/Interview.py +393 -83
  55. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +22 -18
  56. edsl/jobs/interviews/InterviewExceptionEntry.py +167 -0
  57. edsl/jobs/runners/JobsRunnerAsyncio.py +152 -160
  58. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  59. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  60. edsl/jobs/tasks/TaskCreators.py +1 -1
  61. edsl/jobs/tasks/TaskHistory.py +205 -126
  62. edsl/language_models/LanguageModel.py +297 -177
  63. edsl/language_models/ModelList.py +2 -2
  64. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  65. edsl/language_models/fake_openai_call.py +15 -0
  66. edsl/language_models/fake_openai_service.py +61 -0
  67. edsl/language_models/registry.py +25 -8
  68. edsl/language_models/repair.py +0 -19
  69. edsl/language_models/utilities.py +61 -0
  70. edsl/notebooks/Notebook.py +20 -2
  71. edsl/prompts/Prompt.py +52 -2
  72. edsl/questions/AnswerValidatorMixin.py +23 -26
  73. edsl/questions/QuestionBase.py +330 -249
  74. edsl/questions/QuestionBaseGenMixin.py +133 -0
  75. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  76. edsl/questions/QuestionBudget.py +99 -42
  77. edsl/questions/QuestionCheckBox.py +227 -36
  78. edsl/questions/QuestionExtract.py +98 -28
  79. edsl/questions/QuestionFreeText.py +47 -31
  80. edsl/questions/QuestionFunctional.py +7 -0
  81. edsl/questions/QuestionList.py +141 -23
  82. edsl/questions/QuestionMultipleChoice.py +159 -66
  83. edsl/questions/QuestionNumerical.py +88 -47
  84. edsl/questions/QuestionRank.py +182 -25
  85. edsl/questions/Quick.py +41 -0
  86. edsl/questions/RegisterQuestionsMeta.py +31 -12
  87. edsl/questions/ResponseValidatorABC.py +170 -0
  88. edsl/questions/__init__.py +3 -4
  89. edsl/questions/decorators.py +21 -0
  90. edsl/questions/derived/QuestionLikertFive.py +10 -5
  91. edsl/questions/derived/QuestionLinearScale.py +15 -2
  92. edsl/questions/derived/QuestionTopK.py +10 -1
  93. edsl/questions/derived/QuestionYesNo.py +24 -3
  94. edsl/questions/descriptors.py +43 -7
  95. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  96. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  97. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  98. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  99. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  100. edsl/questions/prompt_templates/question_list.jinja +17 -0
  101. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  102. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  103. edsl/questions/question_registry.py +6 -2
  104. edsl/questions/templates/__init__.py +0 -0
  105. edsl/questions/templates/budget/__init__.py +0 -0
  106. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  107. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  108. edsl/questions/templates/checkbox/__init__.py +0 -0
  109. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  110. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  111. edsl/questions/templates/extract/__init__.py +0 -0
  112. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  113. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  114. edsl/questions/templates/free_text/__init__.py +0 -0
  115. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  116. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  117. edsl/questions/templates/likert_five/__init__.py +0 -0
  118. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  119. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  120. edsl/questions/templates/linear_scale/__init__.py +0 -0
  121. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  122. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  123. edsl/questions/templates/list/__init__.py +0 -0
  124. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  125. edsl/questions/templates/list/question_presentation.jinja +5 -0
  126. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  127. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  128. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  129. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  130. edsl/questions/templates/numerical/__init__.py +0 -0
  131. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  132. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  133. edsl/questions/templates/rank/__init__.py +0 -0
  134. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  135. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  136. edsl/questions/templates/top_k/__init__.py +0 -0
  137. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  138. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  139. edsl/questions/templates/yes_no/__init__.py +0 -0
  140. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  141. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  142. edsl/results/Dataset.py +20 -0
  143. edsl/results/DatasetExportMixin.py +58 -30
  144. edsl/results/DatasetTree.py +145 -0
  145. edsl/results/Result.py +32 -5
  146. edsl/results/Results.py +135 -46
  147. edsl/results/ResultsDBMixin.py +3 -3
  148. edsl/results/Selector.py +118 -0
  149. edsl/results/tree_explore.py +115 -0
  150. edsl/scenarios/FileStore.py +71 -10
  151. edsl/scenarios/Scenario.py +109 -24
  152. edsl/scenarios/ScenarioImageMixin.py +2 -2
  153. edsl/scenarios/ScenarioList.py +546 -21
  154. edsl/scenarios/ScenarioListExportMixin.py +24 -4
  155. edsl/scenarios/ScenarioListPdfMixin.py +153 -4
  156. edsl/study/SnapShot.py +8 -1
  157. edsl/study/Study.py +32 -0
  158. edsl/surveys/Rule.py +15 -3
  159. edsl/surveys/RuleCollection.py +21 -5
  160. edsl/surveys/Survey.py +707 -298
  161. edsl/surveys/SurveyExportMixin.py +71 -9
  162. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  163. edsl/surveys/SurveyQualtricsImport.py +284 -0
  164. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  165. edsl/surveys/instructions/Instruction.py +34 -0
  166. edsl/surveys/instructions/InstructionCollection.py +77 -0
  167. edsl/surveys/instructions/__init__.py +0 -0
  168. edsl/templates/error_reporting/base.html +24 -0
  169. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  170. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  171. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  172. edsl/templates/error_reporting/interview_details.html +116 -0
  173. edsl/templates/error_reporting/interviews.html +10 -0
  174. edsl/templates/error_reporting/overview.html +5 -0
  175. edsl/templates/error_reporting/performance_plot.html +2 -0
  176. edsl/templates/error_reporting/report.css +74 -0
  177. edsl/templates/error_reporting/report.html +118 -0
  178. edsl/templates/error_reporting/report.js +25 -0
  179. edsl/utilities/utilities.py +40 -1
  180. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/METADATA +8 -2
  181. edsl-0.1.33.dist-info/RECORD +295 -0
  182. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -271
  183. edsl/jobs/interviews/retry_management.py +0 -37
  184. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -303
  185. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  186. edsl-0.1.31.dev4.dist-info/RECORD +0 -204
  187. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
  188. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
edsl/exceptions/questions.py
@@ -1,16 +1,73 @@
+from typing import Any, SupportsIndex
+from jinja2 import Template
+import json
+
+
 class QuestionErrors(Exception):
-    pass
+    """
+    Base exception class for question-related errors.
+    """
 
+    def __init__(self, message="An error occurred with the question"):
+        self.message = message
+        super().__init__(self.message)
 
-class QuestionCreationValidationError(QuestionErrors):
-    pass
 
+class QuestionAnswerValidationError(QuestionErrors):
+    documentation = "https://docs.expectedparrot.com/en/latest/exceptions.html"
+
+    explanation = """This occurs when the answer coming from the Language Model does not conform to the expectation for that question type.
+    For example, if the question is a multiple choice question, the answer should be drawn from the list of options provided.
+    """
+
+    def __init__(self, message="Invalid answer.", data=None, model=None):
+        self.message = message
+        self.data = data
+        self.model = model
+        super().__init__(self.message)
+
+    def __str__(self):
+        return f"""{repr(self)}
+        Data being validated: {self.data}
+        Pydantic Model: {self.model}.
+        Reported error: {self.message}."""
+
+    def to_html_dict(self):
+        return {
+            "error_type": ("Name of the exception", "p", "/p", self.__class__.__name__),
+            "explanation": ("Explanation", "p", "/p", self.explanation),
+            "edsl answer": (
+                "What model returned",
+                "pre",
+                "/pre",
+                json.dumps(self.data, indent=2),
+            ),
+            "validating_model": (
+                "Pydantic model for answers",
+                "pre",
+                "/pre",
+                json.dumps(self.model.model_json_schema(), indent=2),
+            ),
+            "error_message": (
+                "Error message Pydantic returned",
+                "p",
+                "/p",
+                self.message,
+            ),
+            "documentation_url": (
+                "URL to EDSL docs",
+                f"a href='{self.documentation}'",
+                "/a",
+                self.documentation,
+            ),
+        }
 
-class QuestionResponseValidationError(QuestionErrors):
+
+class QuestionCreationValidationError(QuestionErrors):
     pass
 
 
-class QuestionAnswerValidationError(QuestionErrors):
+class QuestionResponseValidationError(QuestionErrors):
     pass
 
 
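A minimal usage sketch (my illustration, not part of the diff) of how the enriched exception might be raised and consumed. The MultipleChoiceAnswer pydantic model below is a made-up stand-in for the per-question validator models that 0.1.33 introduces (see ResponseValidatorABC.py in the file list):

from pydantic import BaseModel
from edsl.exceptions.questions import QuestionAnswerValidationError

class MultipleChoiceAnswer(BaseModel):  # hypothetical stand-in validator
    answer: str

try:
    raise QuestionAnswerValidationError(
        message="'banana' is not among the offered options",
        data={"answer": "banana"},  # what the model returned
        model=MultipleChoiceAnswer,
    )
except QuestionAnswerValidationError as e:
    print(e)  # __str__ reports the data, the pydantic model, and the message
    rows = e.to_html_dict()  # presumably feeds the new templates/error_reporting pages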
edsl/exceptions/results.py
@@ -2,6 +2,10 @@ class ResultsErrors(Exception):
     pass
 
 
+class ResultsDeserializationError(ResultsErrors):
+    pass
+
+
 class ResultsBadMutationstringError(ResultsErrors):
     pass
 
edsl/inference_services/AnthropicService.py
@@ -11,6 +11,11 @@ class AnthropicService(InferenceServiceABC):
 
     _inference_service_ = "anthropic"
     _env_key_name_ = "ANTHROPIC_API_KEY"
+    key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
+    usage_sequence = ["usage"]
+    input_token_name = "input_tokens"
+    output_token_name = "output_tokens"
+    model_exclude_list = []
 
     @classmethod
     def available(cls):
@@ -34,6 +39,11 @@ class AnthropicService(InferenceServiceABC):
             Child class of LanguageModel for interacting with OpenAI models
             """
 
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -46,6 +56,9 @@ class AnthropicService(InferenceServiceABC):
                 "top_logprobs": 3,
             }
 
+            _tpm = cls.get_tpm(cls)
+            _rpm = cls.get_rpm(cls)
+
             async def async_execute_model_call(
                 self, user_prompt: str, system_prompt: str = ""
             ) -> dict[str, Any]:
@@ -66,17 +79,6 @@ class AnthropicService(InferenceServiceABC):
                 )
                 return response.model_dump()
 
-            @staticmethod
-            def parse_response(raw_response: dict[str, Any]) -> str:
-                """Parses the API response and returns the response text."""
-                response = raw_response["content"][0]["text"]
-                pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
-                match = re.match(pattern, response, re.DOTALL)
-                if match:
-                    return match.group(1)
-                else:
-                    return response
-
         LLM.__name__ = model_class_name
 
         return LLM
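The deleted parse_response reflects the shift to declarative extraction: each service now only declares where the text and token usage live in the raw response. A rough sketch of the idea (my illustration; the actual traversal helper lives in LanguageModel.py and may differ):

from functools import reduce

def follow(response: dict, sequence: list):
    # Each element of the sequence is a dict key or a list index.
    return reduce(lambda node, key: node[key], sequence, response)

raw = {"content": [{"text": "Hello!"}], "usage": {"input_tokens": 12, "output_tokens": 3}}
assert follow(raw, ["content", 0, "text"]) == "Hello!"  # key_sequence
usage = follow(raw, ["usage"])                          # usage_sequence
assert usage["input_tokens"] == 12                      # input_token_name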
edsl/inference_services/AwsBedrock.py (new file)
@@ -0,0 +1,112 @@
+import os
+from typing import Any
+import re
+import boto3
+from botocore.exceptions import ClientError
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+import json
+from edsl.utilities.utilities import fix_partial_correct_response
+
+
+class AwsBedrockService(InferenceServiceABC):
+    """AWS Bedrock service class."""
+
+    _inference_service_ = "bedrock"
+    _env_key_name_ = (
+        "AWS_ACCESS_KEY_ID"  # or any other environment key for AWS credentials
+    )
+    key_sequence = ["output", "message", "content", 0, "text"]
+    input_token_name = "inputTokens"
+    output_token_name = "outputTokens"
+    usage_sequence = ["usage"]
+    model_exclude_list = [
+        "ai21.j2-grande-instruct",
+        "ai21.j2-jumbo-instruct",
+        "ai21.j2-mid",
+        "ai21.j2-mid-v1",
+        "ai21.j2-ultra",
+        "ai21.j2-ultra-v1",
+    ]
+
+    @classmethod
+    def available(cls):
+        """Fetch available models from AWS Bedrock."""
+        if not cls._models_list_cache:
+            client = boto3.client("bedrock", region_name="us-west-2")
+            all_models_ids = [
+                x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
+            ]
+        else:
+            all_models_ids = cls._models_list_cache
+
+        return [m for m in all_models_ids if m not in cls.model_exclude_list]
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with AWS Bedrock models.
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+            _rpm = cls.get_rpm(cls)
+            _tpm = cls.get_tpm(cls)
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                """Calls the AWS Bedrock API and returns the API response."""
+
+                api_token = (
+                    self.api_token
+                )  # call to check if the env variables are set.
+
+                client = boto3.client("bedrock-runtime", region_name="us-west-2")
+
+                conversation = [
+                    {
+                        "role": "user",
+                        "content": [{"text": user_prompt}],
+                    }
+                ]
+                system = [
+                    {
+                        "text": system_prompt,
+                    }
+                ]
+                try:
+                    response = client.converse(
+                        modelId=self._model_,
+                        messages=conversation,
+                        inferenceConfig={
+                            "maxTokens": self.max_tokens,
+                            "temperature": self.temperature,
+                            "topP": self.top_p,
+                        },
+                        # system=system,
+                        additionalModelRequestFields={},
+                    )
+                    return response
+                except (ClientError, Exception) as e:
+                    print(e)
+                    return {"error": str(e)}
+
+        LLM.__name__ = model_class_name
+
+        return LLM
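A hedged usage sketch (assumes AWS credentials are configured in the environment; EDSL's job runner normally awaits the model call for you):

from edsl.inference_services.AwsBedrock import AwsBedrockService

# Build a LanguageModel subclass bound to a Bedrock model id.
TitanLLM = AwsBedrockService.create_model("amazon.titan-tg1-large")
# Instances inherit _rpm/_tpm from the service defaults and extract text
# from client.converse() output via ["output", "message", "content", 0, "text"].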
edsl/inference_services/AzureAI.py (new file)
@@ -0,0 +1,214 @@
+import os
+from typing import Any
+import re
+from openai import AsyncAzureOpenAI
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+
+from azure.ai.inference.aio import ChatCompletionsClient
+from azure.core.credentials import AzureKeyCredential
+from azure.ai.inference.models import SystemMessage, UserMessage
+import asyncio
+import json
+from edsl.utilities.utilities import fix_partial_correct_response
+
+
+def json_handle_none(value: Any) -> Any:
+    """
+    Handle None values during JSON serialization.
+    - Return "null" if the value is None. Otherwise, don't return anything.
+    """
+    if value is None:
+        return "null"
+
+
+class AzureAIService(InferenceServiceABC):
+    """Azure AI service class."""
+
+    # key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    _inference_service_ = "azure"
+    _env_key_name_ = (
+        "AZURE_ENDPOINT_URL_AND_KEY"  # Environment variable for Azure API key
+    )
+    _model_id_to_endpoint_and_key = {}
+    model_exclude_list = [
+        "Cohere-command-r-plus-xncmg",
+        "Mistral-Nemo-klfsi",
+        "Mistral-large-2407-ojfld",
+    ]
+
+    @classmethod
+    def available(cls):
+        out = []
+        azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
+        if not azure_endpoints:
+            raise EnvironmentError(f"AZURE_ENDPOINT_URL_AND_KEY is not defined")
+        azure_endpoints = azure_endpoints.split(",")
+        for data in azure_endpoints:
+            try:
+                # data has this format for non openai models https://model_id.azure_endpoint:azure_key
+                _, endpoint, azure_endpoint_key = data.split(":")
+                if "openai" not in endpoint:
+                    model_id = endpoint.split(".")[0].replace("/", "")
+                    out.append(model_id)
+                    cls._model_id_to_endpoint_and_key[model_id] = {
+                        "endpoint": f"https:{endpoint}",
+                        "azure_endpoint_key": azure_endpoint_key,
+                    }
+                else:
+                    # data has this format for openai models ,https://azure_project_id.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview:azure_key
+                    if "/deployments/" in endpoint:
+                        start_idx = endpoint.index("/deployments/") + len(
+                            "/deployments/"
+                        )
+                        end_idx = (
+                            endpoint.index("/", start_idx)
+                            if "/" in endpoint[start_idx:]
+                            else len(endpoint)
+                        )
+                        model_id = endpoint[start_idx:end_idx]
+                        api_version_value = None
+                        if "api-version=" in endpoint:
+                            start_idx = endpoint.index("api-version=") + len(
+                                "api-version="
+                            )
+                            end_idx = (
+                                endpoint.index("&", start_idx)
+                                if "&" in endpoint[start_idx:]
+                                else len(endpoint)
+                            )
+                            api_version_value = endpoint[start_idx:end_idx]
+
+                        cls._model_id_to_endpoint_and_key[f"azure:{model_id}"] = {
+                            "endpoint": f"https:{endpoint}",
+                            "azure_endpoint_key": azure_endpoint_key,
+                            "api_version": api_version_value,
+                        }
+                        out.append(f"azure:{model_id}")
+
+            except Exception as e:
+                raise e
+        return [m for m in out if m not in cls.model_exclude_list]
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "azureai", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Azure OpenAI models.
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+            _rpm = cls.get_rpm(cls)
+            _tpm = cls.get_tpm(cls)
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                """Calls the Azure OpenAI API and returns the API response."""
+
+                try:
+                    api_key = cls._model_id_to_endpoint_and_key[model_name][
+                        "azure_endpoint_key"
+                    ]
+                except:
+                    api_key = None
+
+                if not api_key:
+                    raise EnvironmentError(
+                        f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
+                    )
+
+                try:
+                    endpoint = cls._model_id_to_endpoint_and_key[model_name]["endpoint"]
+                except:
+                    endpoint = None
+
+                if not endpoint:
+                    raise EnvironmentError(
+                        f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
+                    )
+
+                if "openai" not in endpoint:
+                    client = ChatCompletionsClient(
+                        endpoint=endpoint,
+                        credential=AzureKeyCredential(api_key),
+                        temperature=self.temperature,
+                        top_p=self.top_p,
+                        max_tokens=self.max_tokens,
+                    )
+                    try:
+                        response = await client.complete(
+                            messages=[
+                                SystemMessage(content=system_prompt),
+                                UserMessage(content=user_prompt),
+                            ],
+                            # model_extras={"safe_mode": True},
+                        )
+                        await client.close()
+                        return response.as_dict()
+                    except Exception as e:
+                        await client.close()
+                        return {"error": str(e)}
+                else:
+                    api_version = cls._model_id_to_endpoint_and_key[model_name][
+                        "api_version"
+                    ]
+                    client = AsyncAzureOpenAI(
+                        azure_endpoint=endpoint,
+                        api_version=api_version,
+                        api_key=api_key,
+                    )
+                    response = await client.chat.completions.create(
+                        model=model_name,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": user_prompt,  # Your question can go here
+                            },
+                        ],
+                    )
+                    return response.model_dump()
+
+            # @staticmethod
+            # def parse_response(raw_response: dict[str, Any]) -> str:
+            #     """Parses the API response and returns the response text."""
+            #     if (
+            #         raw_response
+            #         and "choices" in raw_response
+            #         and raw_response["choices"]
+            #     ):
+            #         response = raw_response["choices"][0]["message"]["content"]
+            #         pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+            #         match = re.match(pattern, response, re.DOTALL)
+            #         if match:
+            #             return match.group(1)
+            #         else:
+            #             out = fix_partial_correct_response(response)
+            #             if "error" not in out:
+            #                 response = out["extracted_json"]
+            #             return response
+            #     return "Error parsing response"
+
+        LLM.__name__ = model_class_name
+
+        return LLM
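Per the parsing in available(), AZURE_ENDPOINT_URL_AND_KEY holds a comma-separated list of endpoint:key pairs. A sketch with placeholder values (format inferred from the comments above; not real endpoints or keys):

import os

os.environ["AZURE_ENDPOINT_URL_AND_KEY"] = ",".join([
    # non-OpenAI models: https://<model_id>.<azure_endpoint>:<azure_key>
    "https://Mistral-large-2407.eastus.models.ai.azure.com:<azure_key>",
    # OpenAI deployments: path carries /deployments/<model>/ and an api-version
    "https://myproject.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview:<azure_key>",
])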
edsl/inference_services/DeepInfraService.py
@@ -2,16 +2,17 @@ import aiohttp
 import json
 import requests
 from typing import Any, List
-#from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
 
 from edsl.inference_services.OpenAIService import OpenAIService
 
+
 class DeepInfraService(OpenAIService):
     """DeepInfra service class."""
 
     _inference_service_ = "deep_infra"
     _env_key_name_ = "DEEP_INFRA_API_KEY"
-    _base_url_ = "https://api.deepinfra.com/v1/openai"
+    _base_url_ = "https://api.deepinfra.com/v1/openai"
     _models_list_cache: List[str] = []
-
edsl/inference_services/GoogleService.py
@@ -10,10 +10,16 @@ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 
 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
+    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
+    usage_sequence = ["usageMetadata"]
+    input_token_name = "promptTokenCount"
+    output_token_name = "candidatesTokenCount"
+
+    model_exclude_list = []
 
     @classmethod
     def available(cls):
-        return ["gemini-pro"]
+        return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
 
     @classmethod
     def create_model(
@@ -24,7 +30,15 @@ class GoogleService(InferenceServiceABC):
 
         class LLM(LanguageModel):
             _model_ = model_name
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
             _inference_service_ = cls._inference_service_
+
+            _tpm = cls.get_tpm(cls)
+            _rpm = cls.get_rpm(cls)
+
             _parameters_ = {
                 "temperature": 0.5,
                 "topP": 1,
@@ -50,7 +64,7 @@ class GoogleService(InferenceServiceABC):
                         "stopSequences": self.stopSequences,
                     },
                 }
-
+                # print(combined_prompt)
                 async with aiohttp.ClientSession() as session:
                     async with session.post(
                         url, headers=headers, data=json.dumps(data)
@@ -58,16 +72,6 @@ class GoogleService(InferenceServiceABC):
                         raw_response_text = await response.text()
                         return json.loads(raw_response_text)
 
-            def parse_response(self, raw_response: dict[str, Any]) -> str:
-                data = raw_response
-                try:
-                    return data["candidates"][0]["content"]["parts"][0]["text"]
-                except KeyError as e:
-                    print(
-                        f"The data return was {data}, which was missing the key 'candidates'"
-                    )
-                    raise e
-
         LLM.__name__ = model_name
 
         return LLM
edsl/inference_services/GroqService.py
@@ -10,10 +10,11 @@ class GroqService(OpenAIService):
     _inference_service_ = "groq"
     _env_key_name_ = "GROQ_API_KEY"
 
-    _sync_client_ = groq.Groq
-    _async_client_ = groq.AsyncGroq
+    _sync_client_ = groq.Groq
+    _async_client_ = groq.AsyncGroq
 
-    #_base_url_ = "https://api.deepinfra.com/v1/openai"
+    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]
+
+    # _base_url_ = "https://api.deepinfra.com/v1/openai"
     _base_url_ = None
     _models_list_cache: List[str] = []
-
edsl/inference_services/InferenceServiceABC.py
@@ -1,22 +1,77 @@
 from abc import abstractmethod, ABC
-from typing import Any
+import os
 import re
+from edsl.config import CONFIG
 
 
 class InferenceServiceABC(ABC):
-    """Abstract class for inference services."""
+    """
+    Abstract class for inference services.
+    Anthropic: https://docs.anthropic.com/en/api/rate-limits
+    """
+
+    default_levels = {
+        "google": {"tpm": 2_000_000, "rpm": 15},
+        "openai": {"tpm": 2_000_000, "rpm": 10_000},
+        "anthropic": {"tpm": 2_000_000, "rpm": 500},
+    }
+
+    def __init_subclass__(cls):
+        """
+        Check that the subclass has the required attributes.
+        - `key_sequence` attribute determines...
+        - `model_exclude_list` attribute determines...
+        """
+        if not hasattr(cls, "key_sequence"):
+            raise NotImplementedError(
+                f"Class {cls.__name__} must have a 'key_sequence' attribute."
+            )
+        if not hasattr(cls, "model_exclude_list"):
+            raise NotImplementedError(
+                f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
+            )
+
+    @classmethod
+    def _get_limt(cls, limit_type: str) -> int:
+        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
+        if key in os.environ:
+            return int(os.getenv(key))
+
+        if cls._inference_service_ in cls.default_levels:
+            return int(cls.default_levels[cls._inference_service_][limit_type])
+
+        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
+
+    def get_tpm(cls) -> int:
+        """
+        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
+        """
+        return cls._get_limt(limit_type="tpm")
+
+    def get_rpm(cls):
+        """
+        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
+        """
+        return cls._get_limt(limit_type="rpm")
 
     @abstractmethod
     def available() -> list[str]:
+        """
+        Returns a list of available models for the service.
+        """
        pass
 
     @abstractmethod
     def create_model():
+        """
+        Returns a LanguageModel object.
+        """
         pass
 
     @staticmethod
     def to_class_name(s):
-        """Convert a string to a valid class name.
+        """
+        Converts a string to a valid class name.
 
         >>> InferenceServiceABC.to_class_name("hello world")
         'HelloWorld'
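The _get_limt lookup makes per-service rate limits overridable from the environment. A small sketch (variable names taken from the code above; the values are arbitrary):

import os

# Must be set before model classes are created, since _tpm/_rpm are
# resolved at class-definition time via cls.get_tpm(cls) / cls.get_rpm(cls).
os.environ["EDSL_SERVICE_RPM_ANTHROPIC"] = "100"
os.environ["EDSL_SERVICE_TPM_ANTHROPIC"] = "1000000"
# Services with no override and no default_levels entry fall back to
# EDSL_SERVICE_RPM_BASELINE / EDSL_SERVICE_TPM_BASELINE from edsl.config.CONFIG.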
edsl/inference_services/InferenceServicesCollection.py
@@ -15,18 +15,19 @@ class InferenceServicesCollection:
         cls.added_models[service_name].append(model_name)
 
     @staticmethod
-    def _get_service_available(service) -> list[str]:
+    def _get_service_available(service, warn: bool = False) -> list[str]:
         from_api = True
         try:
             service_models = service.available()
         except Exception as e:
-            warnings.warn(
-                f"""Error getting models for {service._inference_service_}.
-                Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
-                See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
-                Relying on cache.""",
-                UserWarning,
-            )
+            if warn:
+                warnings.warn(
+                    f"""Error getting models for {service._inference_service_}.
+                    Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
+                    See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
+                    Relying on cache.""",
+                    UserWarning,
+                )
             from edsl.inference_services.models_available_cache import models_available
 
             service_models = models_available.get(service._inference_service_, [])
@@ -60,4 +61,8 @@ class InferenceServicesCollection:
             if service_name is None or service_name == service._inference_service_:
                 return service.create_model(model_name)
 
+        # if model_name == "test":
+        #     from edsl.language_models import LanguageModel
+        #     return LanguageModel(test = True)
+
         raise Exception(f"Model {model_name} not found in any of the services")
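A resolution sketch (hedged: the `default` collection name comes from edsl/inference_services/registry.py in the file list, and the exact signature is an assumption based on the loop above):

from edsl.inference_services.registry import default  # assumed InferenceServicesCollection instance

# Walks the registered services and returns the first match; otherwise the
# "Model ... not found in any of the services" exception above is raised.
ModelClass = default.create_model("gpt-4o", service_name="openai")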