edsl 0.1.32__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. edsl/Base.py +9 -3
  2. edsl/TemplateLoader.py +24 -0
  3. edsl/__init__.py +8 -3
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +40 -8
  6. edsl/agents/AgentList.py +43 -0
  7. edsl/agents/Invigilator.py +135 -219
  8. edsl/agents/InvigilatorBase.py +148 -59
  9. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +138 -89
  10. edsl/agents/__init__.py +1 -0
  11. edsl/auto/AutoStudy.py +117 -0
  12. edsl/auto/StageBase.py +230 -0
  13. edsl/auto/StageGenerateSurvey.py +178 -0
  14. edsl/auto/StageLabelQuestions.py +125 -0
  15. edsl/auto/StagePersona.py +61 -0
  16. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  17. edsl/auto/StagePersonaDimensionValues.py +74 -0
  18. edsl/auto/StagePersonaDimensions.py +69 -0
  19. edsl/auto/StageQuestions.py +73 -0
  20. edsl/auto/SurveyCreatorPipeline.py +21 -0
  21. edsl/auto/utilities.py +224 -0
  22. edsl/config.py +47 -56
  23. edsl/coop/PriceFetcher.py +58 -0
  24. edsl/coop/coop.py +50 -7
  25. edsl/data/Cache.py +35 -1
  26. edsl/data_transfer_models.py +73 -38
  27. edsl/enums.py +4 -0
  28. edsl/exceptions/language_models.py +25 -1
  29. edsl/exceptions/questions.py +62 -5
  30. edsl/exceptions/results.py +4 -0
  31. edsl/inference_services/AnthropicService.py +13 -11
  32. edsl/inference_services/AwsBedrock.py +19 -17
  33. edsl/inference_services/AzureAI.py +37 -20
  34. edsl/inference_services/GoogleService.py +16 -12
  35. edsl/inference_services/GroqService.py +2 -0
  36. edsl/inference_services/InferenceServiceABC.py +58 -3
  37. edsl/inference_services/MistralAIService.py +120 -0
  38. edsl/inference_services/OpenAIService.py +48 -54
  39. edsl/inference_services/TestService.py +80 -0
  40. edsl/inference_services/TogetherAIService.py +170 -0
  41. edsl/inference_services/models_available_cache.py +0 -6
  42. edsl/inference_services/registry.py +6 -0
  43. edsl/jobs/Answers.py +10 -12
  44. edsl/jobs/FailedQuestion.py +78 -0
  45. edsl/jobs/Jobs.py +37 -22
  46. edsl/jobs/buckets/BucketCollection.py +24 -15
  47. edsl/jobs/buckets/TokenBucket.py +93 -14
  48. edsl/jobs/interviews/Interview.py +366 -78
  49. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +14 -68
  50. edsl/jobs/interviews/InterviewExceptionEntry.py +85 -19
  51. edsl/jobs/runners/JobsRunnerAsyncio.py +146 -175
  52. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  53. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  54. edsl/jobs/tasks/TaskHistory.py +148 -213
  55. edsl/language_models/LanguageModel.py +261 -156
  56. edsl/language_models/ModelList.py +2 -2
  57. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  58. edsl/language_models/fake_openai_call.py +15 -0
  59. edsl/language_models/fake_openai_service.py +61 -0
  60. edsl/language_models/registry.py +23 -6
  61. edsl/language_models/repair.py +0 -19
  62. edsl/language_models/utilities.py +61 -0
  63. edsl/notebooks/Notebook.py +20 -2
  64. edsl/prompts/Prompt.py +52 -2
  65. edsl/questions/AnswerValidatorMixin.py +23 -26
  66. edsl/questions/QuestionBase.py +330 -249
  67. edsl/questions/QuestionBaseGenMixin.py +133 -0
  68. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  69. edsl/questions/QuestionBudget.py +99 -41
  70. edsl/questions/QuestionCheckBox.py +227 -35
  71. edsl/questions/QuestionExtract.py +98 -27
  72. edsl/questions/QuestionFreeText.py +52 -29
  73. edsl/questions/QuestionFunctional.py +7 -0
  74. edsl/questions/QuestionList.py +141 -22
  75. edsl/questions/QuestionMultipleChoice.py +159 -65
  76. edsl/questions/QuestionNumerical.py +88 -46
  77. edsl/questions/QuestionRank.py +182 -24
  78. edsl/questions/Quick.py +41 -0
  79. edsl/questions/RegisterQuestionsMeta.py +31 -12
  80. edsl/questions/ResponseValidatorABC.py +170 -0
  81. edsl/questions/__init__.py +3 -4
  82. edsl/questions/decorators.py +21 -0
  83. edsl/questions/derived/QuestionLikertFive.py +10 -5
  84. edsl/questions/derived/QuestionLinearScale.py +15 -2
  85. edsl/questions/derived/QuestionTopK.py +10 -1
  86. edsl/questions/derived/QuestionYesNo.py +24 -3
  87. edsl/questions/descriptors.py +43 -7
  88. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  89. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  90. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  91. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  92. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  93. edsl/questions/prompt_templates/question_list.jinja +17 -0
  94. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  95. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  96. edsl/questions/question_registry.py +6 -2
  97. edsl/questions/templates/__init__.py +0 -0
  98. edsl/questions/templates/budget/__init__.py +0 -0
  99. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  100. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  101. edsl/questions/templates/checkbox/__init__.py +0 -0
  102. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  103. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  104. edsl/questions/templates/extract/__init__.py +0 -0
  105. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  106. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  107. edsl/questions/templates/free_text/__init__.py +0 -0
  108. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  109. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  110. edsl/questions/templates/likert_five/__init__.py +0 -0
  111. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  112. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  113. edsl/questions/templates/linear_scale/__init__.py +0 -0
  114. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  115. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  116. edsl/questions/templates/list/__init__.py +0 -0
  117. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  118. edsl/questions/templates/list/question_presentation.jinja +5 -0
  119. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  120. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  121. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  122. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  123. edsl/questions/templates/numerical/__init__.py +0 -0
  124. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  125. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  126. edsl/questions/templates/rank/__init__.py +0 -0
  127. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  128. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  129. edsl/questions/templates/top_k/__init__.py +0 -0
  130. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  131. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  132. edsl/questions/templates/yes_no/__init__.py +0 -0
  133. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  134. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  135. edsl/results/Dataset.py +20 -0
  136. edsl/results/DatasetExportMixin.py +46 -48
  137. edsl/results/DatasetTree.py +145 -0
  138. edsl/results/Result.py +32 -5
  139. edsl/results/Results.py +135 -46
  140. edsl/results/ResultsDBMixin.py +3 -3
  141. edsl/results/Selector.py +118 -0
  142. edsl/results/tree_explore.py +115 -0
  143. edsl/scenarios/FileStore.py +71 -10
  144. edsl/scenarios/Scenario.py +96 -25
  145. edsl/scenarios/ScenarioImageMixin.py +2 -2
  146. edsl/scenarios/ScenarioList.py +361 -39
  147. edsl/scenarios/ScenarioListExportMixin.py +9 -0
  148. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  149. edsl/study/SnapShot.py +8 -1
  150. edsl/study/Study.py +32 -0
  151. edsl/surveys/Rule.py +10 -1
  152. edsl/surveys/RuleCollection.py +21 -5
  153. edsl/surveys/Survey.py +637 -311
  154. edsl/surveys/SurveyExportMixin.py +71 -9
  155. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  156. edsl/surveys/SurveyQualtricsImport.py +75 -4
  157. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  158. edsl/surveys/instructions/Instruction.py +34 -0
  159. edsl/surveys/instructions/InstructionCollection.py +77 -0
  160. edsl/surveys/instructions/__init__.py +0 -0
  161. edsl/templates/error_reporting/base.html +24 -0
  162. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  163. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  164. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  165. edsl/templates/error_reporting/interview_details.html +116 -0
  166. edsl/templates/error_reporting/interviews.html +10 -0
  167. edsl/templates/error_reporting/overview.html +5 -0
  168. edsl/templates/error_reporting/performance_plot.html +2 -0
  169. edsl/templates/error_reporting/report.css +74 -0
  170. edsl/templates/error_reporting/report.html +118 -0
  171. edsl/templates/error_reporting/report.js +25 -0
  172. edsl/utilities/utilities.py +9 -1
  173. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/METADATA +5 -2
  174. edsl-0.1.33.dist-info/RECORD +295 -0
  175. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
  176. edsl/jobs/interviews/retry_management.py +0 -37
  177. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  178. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  179. edsl-0.1.32.dist-info/RECORD +0 -209
  180. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
  181. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
@@ -11,6 +11,11 @@ class AnthropicService(InferenceServiceABC):
11
11
 
12
12
  _inference_service_ = "anthropic"
13
13
  _env_key_name_ = "ANTHROPIC_API_KEY"
14
+ key_sequence = ["content", 0, "text"] # ["content"][0]["text"]
15
+ usage_sequence = ["usage"]
16
+ input_token_name = "input_tokens"
17
+ output_token_name = "output_tokens"
18
+ model_exclude_list = []
14
19
 
15
20
  @classmethod
16
21
  def available(cls):
@@ -34,6 +39,11 @@ class AnthropicService(InferenceServiceABC):
34
39
  Child class of LanguageModel for interacting with OpenAI models
35
40
  """
36
41
 
42
+ key_sequence = cls.key_sequence
43
+ usage_sequence = cls.usage_sequence
44
+ input_token_name = cls.input_token_name
45
+ output_token_name = cls.output_token_name
46
+
37
47
  _inference_service_ = cls._inference_service_
38
48
  _model_ = model_name
39
49
  _parameters_ = {
@@ -46,6 +56,9 @@ class AnthropicService(InferenceServiceABC):
46
56
  "top_logprobs": 3,
47
57
  }
48
58
 
59
+ _tpm = cls.get_tpm(cls)
60
+ _rpm = cls.get_rpm(cls)
61
+
49
62
  async def async_execute_model_call(
50
63
  self, user_prompt: str, system_prompt: str = ""
51
64
  ) -> dict[str, Any]:
@@ -66,17 +79,6 @@ class AnthropicService(InferenceServiceABC):
66
79
  )
67
80
  return response.model_dump()
68
81
 
69
- @staticmethod
70
- def parse_response(raw_response: dict[str, Any]) -> str:
71
- """Parses the API response and returns the response text."""
72
- response = raw_response["content"][0]["text"]
73
- pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
74
- match = re.match(pattern, response, re.DOTALL)
75
- if match:
76
- return match.group(1)
77
- else:
78
- return response
79
-
80
82
  LLM.__name__ = model_class_name
81
83
 
82
84
  return LLM
@@ -16,6 +16,18 @@ class AwsBedrockService(InferenceServiceABC):
16
16
  _env_key_name_ = (
17
17
  "AWS_ACCESS_KEY_ID" # or any other environment key for AWS credentials
18
18
  )
19
+ key_sequence = ["output", "message", "content", 0, "text"]
20
+ input_token_name = "inputTokens"
21
+ output_token_name = "outputTokens"
22
+ usage_sequence = ["usage"]
23
+ model_exclude_list = [
24
+ "ai21.j2-grande-instruct",
25
+ "ai21.j2-jumbo-instruct",
26
+ "ai21.j2-mid",
27
+ "ai21.j2-mid-v1",
28
+ "ai21.j2-ultra",
29
+ "ai21.j2-ultra-v1",
30
+ ]
19
31
 
20
32
  @classmethod
21
33
  def available(cls):
@@ -28,7 +40,7 @@ class AwsBedrockService(InferenceServiceABC):
28
40
  else:
29
41
  all_models_ids = cls._models_list_cache
30
42
 
31
- return all_models_ids
43
+ return [m for m in all_models_ids if m not in cls.model_exclude_list]
32
44
 
33
45
  @classmethod
34
46
  def create_model(
@@ -42,6 +54,8 @@ class AwsBedrockService(InferenceServiceABC):
42
54
  Child class of LanguageModel for interacting with AWS Bedrock models.
43
55
  """
44
56
 
57
+ key_sequence = cls.key_sequence
58
+ usage_sequence = cls.usage_sequence
45
59
  _inference_service_ = cls._inference_service_
46
60
  _model_ = model_name
47
61
  _parameters_ = {
@@ -49,6 +63,10 @@ class AwsBedrockService(InferenceServiceABC):
49
63
  "max_tokens": 512,
50
64
  "top_p": 0.9,
51
65
  }
66
+ input_token_name = cls.input_token_name
67
+ output_token_name = cls.output_token_name
68
+ _rpm = cls.get_rpm(cls)
69
+ _tpm = cls.get_tpm(cls)
52
70
 
53
71
  async def async_execute_model_call(
54
72
  self, user_prompt: str, system_prompt: str = ""
@@ -89,22 +107,6 @@ class AwsBedrockService(InferenceServiceABC):
89
107
  print(e)
90
108
  return {"error": str(e)}
91
109
 
92
- @staticmethod
93
- def parse_response(raw_response: dict[str, Any]) -> str:
94
- """Parses the API response and returns the response text."""
95
- if "output" in raw_response and "message" in raw_response["output"]:
96
- response = raw_response["output"]["message"]["content"][0]["text"]
97
- pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
98
- match = re.match(pattern, response, re.DOTALL)
99
- if match:
100
- return match.group(1)
101
- else:
102
- out = fix_partial_correct_response(response)
103
- if "error" not in out:
104
- response = out["extracted_json"]
105
- return response
106
- return "Error parsing response"
107
-
108
110
  LLM.__name__ = model_class_name
109
111
 
110
112
  return LLM
@@ -25,11 +25,22 @@ def json_handle_none(value: Any) -> Any:
25
25
  class AzureAIService(InferenceServiceABC):
26
26
  """Azure AI service class."""
27
27
 
28
+ # key_sequence = ["content", 0, "text"] # ["content"][0]["text"]
29
+ key_sequence = ["choices", 0, "message", "content"]
30
+ usage_sequence = ["usage"]
31
+ input_token_name = "prompt_tokens"
32
+ output_token_name = "completion_tokens"
33
+
28
34
  _inference_service_ = "azure"
29
35
  _env_key_name_ = (
30
36
  "AZURE_ENDPOINT_URL_AND_KEY" # Environment variable for Azure API key
31
37
  )
32
38
  _model_id_to_endpoint_and_key = {}
39
+ model_exclude_list = [
40
+ "Cohere-command-r-plus-xncmg",
41
+ "Mistral-Nemo-klfsi",
42
+ "Mistral-large-2407-ojfld",
43
+ ]
33
44
 
34
45
  @classmethod
35
46
  def available(cls):
@@ -82,7 +93,7 @@ class AzureAIService(InferenceServiceABC):
82
93
 
83
94
  except Exception as e:
84
95
  raise e
85
- return out
96
+ return [m for m in out if m not in cls.model_exclude_list]
86
97
 
87
98
  @classmethod
88
99
  def create_model(
@@ -96,6 +107,10 @@ class AzureAIService(InferenceServiceABC):
96
107
  Child class of LanguageModel for interacting with Azure OpenAI models.
97
108
  """
98
109
 
110
+ key_sequence = cls.key_sequence
111
+ usage_sequence = cls.usage_sequence
112
+ input_token_name = cls.input_token_name
113
+ output_token_name = cls.output_token_name
99
114
  _inference_service_ = cls._inference_service_
100
115
  _model_ = model_name
101
116
  _parameters_ = {
@@ -103,6 +118,8 @@ class AzureAIService(InferenceServiceABC):
103
118
  "max_tokens": 512,
104
119
  "top_p": 0.9,
105
120
  }
121
+ _rpm = cls.get_rpm(cls)
122
+ _tpm = cls.get_tpm(cls)
106
123
 
107
124
  async def async_execute_model_call(
108
125
  self, user_prompt: str, system_prompt: str = ""
@@ -172,25 +189,25 @@ class AzureAIService(InferenceServiceABC):
172
189
  )
173
190
  return response.model_dump()
174
191
 
175
- @staticmethod
176
- def parse_response(raw_response: dict[str, Any]) -> str:
177
- """Parses the API response and returns the response text."""
178
- if (
179
- raw_response
180
- and "choices" in raw_response
181
- and raw_response["choices"]
182
- ):
183
- response = raw_response["choices"][0]["message"]["content"]
184
- pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
185
- match = re.match(pattern, response, re.DOTALL)
186
- if match:
187
- return match.group(1)
188
- else:
189
- out = fix_partial_correct_response(response)
190
- if "error" not in out:
191
- response = out["extracted_json"]
192
- return response
193
- return "Error parsing response"
192
+ # @staticmethod
193
+ # def parse_response(raw_response: dict[str, Any]) -> str:
194
+ # """Parses the API response and returns the response text."""
195
+ # if (
196
+ # raw_response
197
+ # and "choices" in raw_response
198
+ # and raw_response["choices"]
199
+ # ):
200
+ # response = raw_response["choices"][0]["message"]["content"]
201
+ # pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
202
+ # match = re.match(pattern, response, re.DOTALL)
203
+ # if match:
204
+ # return match.group(1)
205
+ # else:
206
+ # out = fix_partial_correct_response(response)
207
+ # if "error" not in out:
208
+ # response = out["extracted_json"]
209
+ # return response
210
+ # return "Error parsing response"
194
211
 
195
212
  LLM.__name__ = model_class_name
196
213
 
@@ -10,10 +10,16 @@ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
10
10
 
11
11
  class GoogleService(InferenceServiceABC):
12
12
  _inference_service_ = "google"
13
+ key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
14
+ usage_sequence = ["usageMetadata"]
15
+ input_token_name = "promptTokenCount"
16
+ output_token_name = "candidatesTokenCount"
17
+
18
+ model_exclude_list = []
13
19
 
14
20
  @classmethod
15
21
  def available(cls):
16
- return ["gemini-pro"]
22
+ return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
17
23
 
18
24
  @classmethod
19
25
  def create_model(
@@ -24,7 +30,15 @@ class GoogleService(InferenceServiceABC):
24
30
 
25
31
  class LLM(LanguageModel):
26
32
  _model_ = model_name
33
+ key_sequence = cls.key_sequence
34
+ usage_sequence = cls.usage_sequence
35
+ input_token_name = cls.input_token_name
36
+ output_token_name = cls.output_token_name
27
37
  _inference_service_ = cls._inference_service_
38
+
39
+ _tpm = cls.get_tpm(cls)
40
+ _rpm = cls.get_rpm(cls)
41
+
28
42
  _parameters_ = {
29
43
  "temperature": 0.5,
30
44
  "topP": 1,
@@ -50,7 +64,7 @@ class GoogleService(InferenceServiceABC):
50
64
  "stopSequences": self.stopSequences,
51
65
  },
52
66
  }
53
-
67
+ # print(combined_prompt)
54
68
  async with aiohttp.ClientSession() as session:
55
69
  async with session.post(
56
70
  url, headers=headers, data=json.dumps(data)
@@ -58,16 +72,6 @@ class GoogleService(InferenceServiceABC):
58
72
  raw_response_text = await response.text()
59
73
  return json.loads(raw_response_text)
60
74
 
61
- def parse_response(self, raw_response: dict[str, Any]) -> str:
62
- data = raw_response
63
- try:
64
- return data["candidates"][0]["content"]["parts"][0]["text"]
65
- except KeyError as e:
66
- print(
67
- f"The data return was {data}, which was missing the key 'candidates'"
68
- )
69
- raise e
70
-
71
75
  LLM.__name__ = model_name
72
76
 
73
77
  return LLM
@@ -13,6 +13,8 @@ class GroqService(OpenAIService):
13
13
  _sync_client_ = groq.Groq
14
14
  _async_client_ = groq.AsyncGroq
15
15
 
16
+ model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]
17
+
16
18
  # _base_url_ = "https://api.deepinfra.com/v1/openai"
17
19
  _base_url_ = None
18
20
  _models_list_cache: List[str] = []
@@ -1,22 +1,77 @@
1
1
  from abc import abstractmethod, ABC
2
- from typing import Any
2
+ import os
3
3
  import re
4
+ from edsl.config import CONFIG
4
5
 
5
6
 
6
7
  class InferenceServiceABC(ABC):
7
- """Abstract class for inference services."""
8
+ """
9
+ Abstract class for inference services.
10
+ Anthropic: https://docs.anthropic.com/en/api/rate-limits
11
+ """
12
+
13
+ default_levels = {
14
+ "google": {"tpm": 2_000_000, "rpm": 15},
15
+ "openai": {"tpm": 2_000_000, "rpm": 10_000},
16
+ "anthropic": {"tpm": 2_000_000, "rpm": 500},
17
+ }
18
+
19
+ def __init_subclass__(cls):
20
+ """
21
+ Check that the subclass has the required attributes.
22
+ - `key_sequence` attribute determines...
23
+ - `model_exclude_list` attribute determines...
24
+ """
25
+ if not hasattr(cls, "key_sequence"):
26
+ raise NotImplementedError(
27
+ f"Class {cls.__name__} must have a 'key_sequence' attribute."
28
+ )
29
+ if not hasattr(cls, "model_exclude_list"):
30
+ raise NotImplementedError(
31
+ f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
32
+ )
33
+
34
+ @classmethod
35
+ def _get_limt(cls, limit_type: str) -> int:
36
+ key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
37
+ if key in os.environ:
38
+ return int(os.getenv(key))
39
+
40
+ if cls._inference_service_ in cls.default_levels:
41
+ return int(cls.default_levels[cls._inference_service_][limit_type])
42
+
43
+ return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
44
+
45
+ def get_tpm(cls) -> int:
46
+ """
47
+ Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
48
+ """
49
+ return cls._get_limt(limit_type="tpm")
50
+
51
+ def get_rpm(cls):
52
+ """
53
+ Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
54
+ """
55
+ return cls._get_limt(limit_type="rpm")
8
56
 
9
57
  @abstractmethod
10
58
  def available() -> list[str]:
59
+ """
60
+ Returns a list of available models for the service.
61
+ """
11
62
  pass
12
63
 
13
64
  @abstractmethod
14
65
  def create_model():
66
+ """
67
+ Returns a LanguageModel object.
68
+ """
15
69
  pass
16
70
 
17
71
  @staticmethod
18
72
  def to_class_name(s):
19
- """Convert a string to a valid class name.
73
+ """
74
+ Converts a string to a valid class name.
20
75
 
21
76
  >>> InferenceServiceABC.to_class_name("hello world")
22
77
  'HelloWorld'
@@ -0,0 +1,120 @@
1
+ import os
2
+ from typing import Any, List
3
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
4
+ from edsl.language_models.LanguageModel import LanguageModel
5
+ import asyncio
6
+ from mistralai import Mistral
7
+
8
+ from edsl.exceptions.language_models import LanguageModelBadResponseError
9
+
10
+
11
+ class MistralAIService(InferenceServiceABC):
12
+ """Mistral AI service class."""
13
+
14
+ key_sequence = ["choices", 0, "message", "content"]
15
+ usage_sequence = ["usage"]
16
+
17
+ _inference_service_ = "mistral"
18
+ _env_key_name_ = "MISTRAL_API_KEY" # Environment variable for Mistral API key
19
+ input_token_name = "prompt_tokens"
20
+ output_token_name = "completion_tokens"
21
+
22
+ _sync_client_instance = None
23
+ _async_client_instance = None
24
+
25
+ _sync_client = Mistral
26
+ _async_client = Mistral
27
+
28
+ _models_list_cache: List[str] = []
29
+ model_exclude_list = []
30
+
31
+ def __init_subclass__(cls, **kwargs):
32
+ super().__init_subclass__(**kwargs)
33
+ # so subclasses have to create their own instances of the clients
34
+ cls._sync_client_instance = None
35
+ cls._async_client_instance = None
36
+
37
+ @classmethod
38
+ def sync_client(cls):
39
+ if cls._sync_client_instance is None:
40
+ cls._sync_client_instance = cls._sync_client(
41
+ api_key=os.getenv(cls._env_key_name_)
42
+ )
43
+ return cls._sync_client_instance
44
+
45
+ @classmethod
46
+ def async_client(cls):
47
+ if cls._async_client_instance is None:
48
+ cls._async_client_instance = cls._async_client(
49
+ api_key=os.getenv(cls._env_key_name_)
50
+ )
51
+ return cls._async_client_instance
52
+
53
+ @classmethod
54
+ def available(cls) -> list[str]:
55
+ if not cls._models_list_cache:
56
+ cls._models_list_cache = [
57
+ m.id for m in cls.sync_client().models.list().data
58
+ ]
59
+
60
+ return cls._models_list_cache
61
+
62
+ @classmethod
63
+ def create_model(
64
+ cls, model_name: str = "mistral", model_class_name=None
65
+ ) -> LanguageModel:
66
+ if model_class_name is None:
67
+ model_class_name = cls.to_class_name(model_name)
68
+
69
+ class LLM(LanguageModel):
70
+ """
71
+ Child class of LanguageModel for interacting with Mistral models.
72
+ """
73
+
74
+ key_sequence = cls.key_sequence
75
+ usage_sequence = cls.usage_sequence
76
+
77
+ input_token_name = cls.input_token_name
78
+ output_token_name = cls.output_token_name
79
+
80
+ _inference_service_ = cls._inference_service_
81
+ _model_ = model_name
82
+ _parameters_ = {
83
+ "temperature": 0.5,
84
+ "max_tokens": 512,
85
+ "top_p": 0.9,
86
+ }
87
+
88
+ _tpm = cls.get_tpm(cls)
89
+ _rpm = cls.get_rpm(cls)
90
+
91
+ def sync_client(self):
92
+ return cls.sync_client()
93
+
94
+ def async_client(self):
95
+ return cls.async_client()
96
+
97
+ async def async_execute_model_call(
98
+ self, user_prompt: str, system_prompt: str = ""
99
+ ) -> dict[str, Any]:
100
+ """Calls the Mistral API and returns the API response."""
101
+ s = self.async_client()
102
+
103
+ try:
104
+ res = await s.chat.complete_async(
105
+ model=model_name,
106
+ messages=[
107
+ {
108
+ "content": user_prompt,
109
+ "role": "user",
110
+ },
111
+ ],
112
+ )
113
+ except Exception as e:
114
+ raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")
115
+
116
+ return res.model_dump()
117
+
118
+ LLM.__name__ = model_class_name
119
+
120
+ return LLM
@@ -1,8 +1,7 @@
1
- from typing import Any, List
2
- import re
1
+ from __future__ import annotations
2
+ from typing import Any, List, Optional
3
3
  import os
4
4
 
5
- # from openai import AsyncOpenAI
6
5
  import openai
7
6
 
8
7
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -10,6 +9,8 @@ from edsl.language_models import LanguageModel
10
9
  from edsl.inference_services.rate_limits_cache import rate_limits
11
10
  from edsl.utilities.utilities import fix_partial_correct_response
12
11
 
12
+ from edsl.config import CONFIG
13
+
13
14
 
14
15
  class OpenAIService(InferenceServiceABC):
15
16
  """OpenAI service class."""
@@ -21,19 +22,36 @@ class OpenAIService(InferenceServiceABC):
21
22
  _sync_client_ = openai.OpenAI
22
23
  _async_client_ = openai.AsyncOpenAI
23
24
 
25
+ _sync_client_instance = None
26
+ _async_client_instance = None
27
+
28
+ key_sequence = ["choices", 0, "message", "content"]
29
+ usage_sequence = ["usage"]
30
+ input_token_name = "prompt_tokens"
31
+ output_token_name = "completion_tokens"
32
+
33
+ def __init_subclass__(cls, **kwargs):
34
+ super().__init_subclass__(**kwargs)
35
+ # so subclasses have to create their own instances of the clients
36
+ cls._sync_client_instance = None
37
+ cls._async_client_instance = None
38
+
24
39
  @classmethod
25
40
  def sync_client(cls):
26
- return cls._sync_client_(
27
- api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
28
- )
41
+ if cls._sync_client_instance is None:
42
+ cls._sync_client_instance = cls._sync_client_(
43
+ api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
44
+ )
45
+ return cls._sync_client_instance
29
46
 
30
47
  @classmethod
31
48
  def async_client(cls):
32
- return cls._async_client_(
33
- api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
34
- )
49
+ if cls._async_client_instance is None:
50
+ cls._async_client_instance = cls._async_client_(
51
+ api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
52
+ )
53
+ return cls._async_client_instance
35
54
 
36
- # TODO: Make this a coop call
37
55
  model_exclude_list = [
38
56
  "whisper-1",
39
57
  "davinci-002",
@@ -48,6 +66,8 @@ class OpenAIService(InferenceServiceABC):
48
66
  "text-embedding-3-small",
49
67
  "text-embedding-ada-002",
50
68
  "ft:davinci-002:mit-horton-lab::8OfuHgoo",
69
+ "gpt-3.5-turbo-instruct-0914",
70
+ "gpt-3.5-turbo-instruct",
51
71
  ]
52
72
  _models_list_cache: List[str] = []
53
73
 
@@ -61,11 +81,8 @@ class OpenAIService(InferenceServiceABC):
61
81
 
62
82
  @classmethod
63
83
  def available(cls) -> List[str]:
64
- # from openai import OpenAI
65
-
66
84
  if not cls._models_list_cache:
67
85
  try:
68
- # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
69
86
  cls._models_list_cache = [
70
87
  m.id
71
88
  for m in cls.get_model_list()
@@ -73,15 +90,6 @@ class OpenAIService(InferenceServiceABC):
73
90
  ]
74
91
  except Exception as e:
75
92
  raise
76
- # print(
77
- # f"""Error retrieving models: {e}.
78
- # See instructions about storing your API keys: https://docs.expectedparrot.com/en/latest/api_keys.html"""
79
- # )
80
- # cls._models_list_cache = [
81
- # "gpt-3.5-turbo",
82
- # "gpt-4-1106-preview",
83
- # "gpt-4",
84
- # ] # Fallback list
85
93
  return cls._models_list_cache
86
94
 
87
95
  @classmethod
@@ -94,6 +102,14 @@ class OpenAIService(InferenceServiceABC):
94
102
  Child class of LanguageModel for interacting with OpenAI models
95
103
  """
96
104
 
105
+ key_sequence = cls.key_sequence
106
+ usage_sequence = cls.usage_sequence
107
+ input_token_name = cls.input_token_name
108
+ output_token_name = cls.output_token_name
109
+
110
+ _rpm = cls.get_rpm(cls)
111
+ _tpm = cls.get_tpm(cls)
112
+
97
113
  _inference_service_ = cls._inference_service_
98
114
  _model_ = model_name
99
115
  _parameters_ = {
@@ -114,15 +130,9 @@ class OpenAIService(InferenceServiceABC):
114
130
 
115
131
  @classmethod
116
132
  def available(cls) -> list[str]:
117
- # import openai
118
- # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
119
- # return client.models.list()
120
133
  return cls.sync_client().models.list()
121
134
 
122
135
  def get_headers(self) -> dict[str, Any]:
123
- # from openai import OpenAI
124
-
125
- # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
126
136
  client = self.sync_client()
127
137
  response = client.chat.completions.with_raw_response.create(
128
138
  messages=[
@@ -159,6 +169,9 @@ class OpenAIService(InferenceServiceABC):
159
169
  user_prompt: str,
160
170
  system_prompt: str = "",
161
171
  encoded_image=None,
172
+ invigilator: Optional[
173
+ "InvigilatorAI"
174
+ ] = None, # TBD - can eventually be used for function-calling
162
175
  ) -> dict[str, Any]:
163
176
  """Calls the OpenAI API and returns the API response."""
164
177
  if encoded_image:
@@ -173,17 +186,16 @@ class OpenAIService(InferenceServiceABC):
173
186
  )
174
187
  else:
175
188
  content = user_prompt
176
- # self.client = AsyncOpenAI(
177
- # api_key = os.getenv(cls._env_key_name_),
178
- # base_url = cls._base_url_
179
- # )
180
189
  client = self.async_client()
190
+ messages = [
191
+ {"role": "system", "content": system_prompt},
192
+ {"role": "user", "content": content},
193
+ ]
194
+ if system_prompt == "" and self.omit_system_prompt_if_empty:
195
+ messages = messages[1:]
181
196
  params = {
182
197
  "model": self.model,
183
- "messages": [
184
- {"role": "system", "content": system_prompt},
185
- {"role": "user", "content": content},
186
- ],
198
+ "messages": messages,
187
199
  "temperature": self.temperature,
188
200
  "max_tokens": self.max_tokens,
189
201
  "top_p": self.top_p,
@@ -195,24 +207,6 @@ class OpenAIService(InferenceServiceABC):
195
207
  response = await client.chat.completions.create(**params)
196
208
  return response.model_dump()
197
209
 
198
- @staticmethod
199
- def parse_response(raw_response: dict[str, Any]) -> str:
200
- """Parses the API response and returns the response text."""
201
- try:
202
- response = raw_response["choices"][0]["message"]["content"]
203
- except KeyError:
204
- print("Tried to parse response but failed:")
205
- print(raw_response)
206
- pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
207
- match = re.match(pattern, response, re.DOTALL)
208
- if match:
209
- return match.group(1)
210
- else:
211
- out = fix_partial_correct_response(response)
212
- if "error" not in out:
213
- response = out["extracted_json"]
214
- return response
215
-
216
210
  LLM.__name__ = "LanguageModel"
217
211
 
218
212
  return LLM