edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (180)
  1. edsl/Base.py +3 -9
  2. edsl/__init__.py +3 -8
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -40
  5. edsl/agents/AgentList.py +0 -43
  6. edsl/agents/Invigilator.py +219 -135
  7. edsl/agents/InvigilatorBase.py +59 -148
  8. edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
  9. edsl/agents/__init__.py +0 -1
  10. edsl/config.py +56 -47
  11. edsl/coop/coop.py +7 -50
  12. edsl/data/Cache.py +1 -35
  13. edsl/data_transfer_models.py +38 -73
  14. edsl/enums.py +0 -4
  15. edsl/exceptions/language_models.py +1 -25
  16. edsl/exceptions/questions.py +5 -62
  17. edsl/exceptions/results.py +0 -4
  18. edsl/inference_services/AnthropicService.py +11 -13
  19. edsl/inference_services/AwsBedrock.py +17 -19
  20. edsl/inference_services/AzureAI.py +20 -37
  21. edsl/inference_services/GoogleService.py +12 -16
  22. edsl/inference_services/GroqService.py +0 -2
  23. edsl/inference_services/InferenceServiceABC.py +3 -58
  24. edsl/inference_services/OpenAIService.py +54 -48
  25. edsl/inference_services/models_available_cache.py +6 -0
  26. edsl/inference_services/registry.py +0 -6
  27. edsl/jobs/Answers.py +12 -10
  28. edsl/jobs/Jobs.py +21 -36
  29. edsl/jobs/buckets/BucketCollection.py +15 -24
  30. edsl/jobs/buckets/TokenBucket.py +14 -93
  31. edsl/jobs/interviews/Interview.py +78 -366
  32. edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
  33. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
  34. edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
  35. edsl/jobs/interviews/retry_management.py +37 -0
  36. edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
  37. edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
  38. edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
  39. edsl/jobs/tasks/TaskHistory.py +213 -148
  40. edsl/language_models/LanguageModel.py +156 -261
  41. edsl/language_models/ModelList.py +2 -2
  42. edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
  43. edsl/language_models/registry.py +6 -23
  44. edsl/language_models/repair.py +19 -0
  45. edsl/prompts/Prompt.py +2 -52
  46. edsl/questions/AnswerValidatorMixin.py +26 -23
  47. edsl/questions/QuestionBase.py +249 -329
  48. edsl/questions/QuestionBudget.py +41 -99
  49. edsl/questions/QuestionCheckBox.py +35 -227
  50. edsl/questions/QuestionExtract.py +27 -98
  51. edsl/questions/QuestionFreeText.py +29 -52
  52. edsl/questions/QuestionFunctional.py +0 -7
  53. edsl/questions/QuestionList.py +22 -141
  54. edsl/questions/QuestionMultipleChoice.py +65 -159
  55. edsl/questions/QuestionNumerical.py +46 -88
  56. edsl/questions/QuestionRank.py +24 -182
  57. edsl/questions/RegisterQuestionsMeta.py +12 -31
  58. edsl/questions/__init__.py +4 -3
  59. edsl/questions/derived/QuestionLikertFive.py +5 -10
  60. edsl/questions/derived/QuestionLinearScale.py +2 -15
  61. edsl/questions/derived/QuestionTopK.py +1 -10
  62. edsl/questions/derived/QuestionYesNo.py +3 -24
  63. edsl/questions/descriptors.py +7 -43
  64. edsl/questions/question_registry.py +2 -6
  65. edsl/results/Dataset.py +0 -20
  66. edsl/results/DatasetExportMixin.py +48 -46
  67. edsl/results/Result.py +5 -32
  68. edsl/results/Results.py +46 -135
  69. edsl/results/ResultsDBMixin.py +3 -3
  70. edsl/scenarios/FileStore.py +10 -71
  71. edsl/scenarios/Scenario.py +25 -96
  72. edsl/scenarios/ScenarioImageMixin.py +2 -2
  73. edsl/scenarios/ScenarioList.py +39 -361
  74. edsl/scenarios/ScenarioListExportMixin.py +0 -9
  75. edsl/scenarios/ScenarioListPdfMixin.py +4 -150
  76. edsl/study/SnapShot.py +1 -8
  77. edsl/study/Study.py +0 -32
  78. edsl/surveys/Rule.py +1 -10
  79. edsl/surveys/RuleCollection.py +5 -21
  80. edsl/surveys/Survey.py +310 -636
  81. edsl/surveys/SurveyExportMixin.py +9 -71
  82. edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
  83. edsl/surveys/SurveyQualtricsImport.py +4 -75
  84. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  85. edsl/utilities/utilities.py +1 -9
  86. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
  87. edsl-0.1.33.dev1.dist-info/RECORD +209 -0
  88. edsl/TemplateLoader.py +0 -24
  89. edsl/auto/AutoStudy.py +0 -117
  90. edsl/auto/StageBase.py +0 -230
  91. edsl/auto/StageGenerateSurvey.py +0 -178
  92. edsl/auto/StageLabelQuestions.py +0 -125
  93. edsl/auto/StagePersona.py +0 -61
  94. edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
  95. edsl/auto/StagePersonaDimensionValues.py +0 -74
  96. edsl/auto/StagePersonaDimensions.py +0 -69
  97. edsl/auto/StageQuestions.py +0 -73
  98. edsl/auto/SurveyCreatorPipeline.py +0 -21
  99. edsl/auto/utilities.py +0 -224
  100. edsl/coop/PriceFetcher.py +0 -58
  101. edsl/inference_services/MistralAIService.py +0 -120
  102. edsl/inference_services/TestService.py +0 -80
  103. edsl/inference_services/TogetherAIService.py +0 -170
  104. edsl/jobs/FailedQuestion.py +0 -78
  105. edsl/jobs/runners/JobsRunnerStatus.py +0 -331
  106. edsl/language_models/fake_openai_call.py +0 -15
  107. edsl/language_models/fake_openai_service.py +0 -61
  108. edsl/language_models/utilities.py +0 -61
  109. edsl/questions/QuestionBaseGenMixin.py +0 -133
  110. edsl/questions/QuestionBasePromptsMixin.py +0 -266
  111. edsl/questions/Quick.py +0 -41
  112. edsl/questions/ResponseValidatorABC.py +0 -170
  113. edsl/questions/decorators.py +0 -21
  114. edsl/questions/prompt_templates/question_budget.jinja +0 -13
  115. edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
  116. edsl/questions/prompt_templates/question_extract.jinja +0 -11
  117. edsl/questions/prompt_templates/question_free_text.jinja +0 -3
  118. edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
  119. edsl/questions/prompt_templates/question_list.jinja +0 -17
  120. edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
  121. edsl/questions/prompt_templates/question_numerical.jinja +0 -37
  122. edsl/questions/templates/__init__.py +0 -0
  123. edsl/questions/templates/budget/__init__.py +0 -0
  124. edsl/questions/templates/budget/answering_instructions.jinja +0 -7
  125. edsl/questions/templates/budget/question_presentation.jinja +0 -7
  126. edsl/questions/templates/checkbox/__init__.py +0 -0
  127. edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
  128. edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
  129. edsl/questions/templates/extract/__init__.py +0 -0
  130. edsl/questions/templates/extract/answering_instructions.jinja +0 -7
  131. edsl/questions/templates/extract/question_presentation.jinja +0 -1
  132. edsl/questions/templates/free_text/__init__.py +0 -0
  133. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  134. edsl/questions/templates/free_text/question_presentation.jinja +0 -1
  135. edsl/questions/templates/likert_five/__init__.py +0 -0
  136. edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
  137. edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
  138. edsl/questions/templates/linear_scale/__init__.py +0 -0
  139. edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
  140. edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
  141. edsl/questions/templates/list/__init__.py +0 -0
  142. edsl/questions/templates/list/answering_instructions.jinja +0 -4
  143. edsl/questions/templates/list/question_presentation.jinja +0 -5
  144. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  145. edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
  146. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  147. edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
  148. edsl/questions/templates/numerical/__init__.py +0 -0
  149. edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
  150. edsl/questions/templates/numerical/question_presentation.jinja +0 -7
  151. edsl/questions/templates/rank/__init__.py +0 -0
  152. edsl/questions/templates/rank/answering_instructions.jinja +0 -11
  153. edsl/questions/templates/rank/question_presentation.jinja +0 -15
  154. edsl/questions/templates/top_k/__init__.py +0 -0
  155. edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
  156. edsl/questions/templates/top_k/question_presentation.jinja +0 -22
  157. edsl/questions/templates/yes_no/__init__.py +0 -0
  158. edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
  159. edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
  160. edsl/results/DatasetTree.py +0 -145
  161. edsl/results/Selector.py +0 -118
  162. edsl/results/tree_explore.py +0 -115
  163. edsl/surveys/instructions/ChangeInstruction.py +0 -47
  164. edsl/surveys/instructions/Instruction.py +0 -34
  165. edsl/surveys/instructions/InstructionCollection.py +0 -77
  166. edsl/surveys/instructions/__init__.py +0 -0
  167. edsl/templates/error_reporting/base.html +0 -24
  168. edsl/templates/error_reporting/exceptions_by_model.html +0 -35
  169. edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
  170. edsl/templates/error_reporting/exceptions_by_type.html +0 -17
  171. edsl/templates/error_reporting/interview_details.html +0 -116
  172. edsl/templates/error_reporting/interviews.html +0 -10
  173. edsl/templates/error_reporting/overview.html +0 -5
  174. edsl/templates/error_reporting/performance_plot.html +0 -2
  175. edsl/templates/error_reporting/report.css +0 -74
  176. edsl/templates/error_reporting/report.html +0 -118
  177. edsl/templates/error_reporting/report.js +0 -25
  178. edsl-0.1.33.dist-info/RECORD +0 -295
  179. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
  180. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
edsl/inference_services/AwsBedrock.py CHANGED
@@ -16,18 +16,6 @@ class AwsBedrockService(InferenceServiceABC):
     _env_key_name_ = (
         "AWS_ACCESS_KEY_ID"  # or any other environment key for AWS credentials
     )
-    key_sequence = ["output", "message", "content", 0, "text"]
-    input_token_name = "inputTokens"
-    output_token_name = "outputTokens"
-    usage_sequence = ["usage"]
-    model_exclude_list = [
-        "ai21.j2-grande-instruct",
-        "ai21.j2-jumbo-instruct",
-        "ai21.j2-mid",
-        "ai21.j2-mid-v1",
-        "ai21.j2-ultra",
-        "ai21.j2-ultra-v1",
-    ]

     @classmethod
     def available(cls):
@@ -40,7 +28,7 @@ class AwsBedrockService(InferenceServiceABC):
         else:
             all_models_ids = cls._models_list_cache

-        return [m for m in all_models_ids if m not in cls.model_exclude_list]
+        return all_models_ids

     @classmethod
     def create_model(
@@ -54,8 +42,6 @@ class AwsBedrockService(InferenceServiceABC):
             Child class of LanguageModel for interacting with AWS Bedrock models.
             """

-            key_sequence = cls.key_sequence
-            usage_sequence = cls.usage_sequence
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -63,10 +49,6 @@ class AwsBedrockService(InferenceServiceABC):
                 "max_tokens": 512,
                 "top_p": 0.9,
             }
-            input_token_name = cls.input_token_name
-            output_token_name = cls.output_token_name
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)

             async def async_execute_model_call(
                 self, user_prompt: str, system_prompt: str = ""
@@ -107,6 +89,22 @@ class AwsBedrockService(InferenceServiceABC):
                     print(e)
                     return {"error": str(e)}

+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                if "output" in raw_response and "message" in raw_response["output"]:
+                    response = raw_response["output"]["message"]["content"][0]["text"]
+                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                    match = re.match(pattern, response, re.DOTALL)
+                    if match:
+                        return match.group(1)
+                    else:
+                        out = fix_partial_correct_response(response)
+                        if "error" not in out:
+                            response = out["extracted_json"]
+                        return response
+                return "Error parsing response"
+
         LLM.__name__ = model_class_name

         return LLM
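
Reviewer note: the removed key_sequence / usage_sequence attributes here (and in the services below) are the declarative extraction paths that 0.1.33's LanguageModel walks over each raw API payload; the dev1 side inlines that logic in per-service parse_response methods instead. A minimal sketch of the path-walk idea — follow_path is an illustrative helper, not edsl's actual API:

from typing import Any, Union

def follow_path(data: Any, path: list[Union[str, int]]) -> Any:
    """Walk a nested dict/list structure one key or index at a time."""
    for key in path:
        data = data[key]  # str keys index dicts, int keys index lists
    return data

# The Bedrock response text lives at output -> message -> content[0] -> text,
# which is exactly what the removed key_sequence encoded.
raw = {"output": {"message": {"content": [{"text": "hello"}]}}}
assert follow_path(raw, ["output", "message", "content", 0, "text"]) == "hello"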
edsl/inference_services/AzureAI.py CHANGED
@@ -25,22 +25,11 @@ def json_handle_none(value: Any) -> Any:
 class AzureAIService(InferenceServiceABC):
     """Azure AI service class."""

-    # key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
-    key_sequence = ["choices", 0, "message", "content"]
-    usage_sequence = ["usage"]
-    input_token_name = "prompt_tokens"
-    output_token_name = "completion_tokens"
-
     _inference_service_ = "azure"
     _env_key_name_ = (
         "AZURE_ENDPOINT_URL_AND_KEY"  # Environment variable for Azure API key
     )
     _model_id_to_endpoint_and_key = {}
-    model_exclude_list = [
-        "Cohere-command-r-plus-xncmg",
-        "Mistral-Nemo-klfsi",
-        "Mistral-large-2407-ojfld",
-    ]

     @classmethod
     def available(cls):
@@ -93,7 +82,7 @@ class AzureAIService(InferenceServiceABC):

         except Exception as e:
             raise e
-        return [m for m in out if m not in cls.model_exclude_list]
+        return out

     @classmethod
     def create_model(
@@ -107,10 +96,6 @@ class AzureAIService(InferenceServiceABC):
             Child class of LanguageModel for interacting with Azure OpenAI models.
             """

-            key_sequence = cls.key_sequence
-            usage_sequence = cls.usage_sequence
-            input_token_name = cls.input_token_name
-            output_token_name = cls.output_token_name
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -118,8 +103,6 @@ class AzureAIService(InferenceServiceABC):
                 "max_tokens": 512,
                 "top_p": 0.9,
             }
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)

             async def async_execute_model_call(
                 self, user_prompt: str, system_prompt: str = ""
@@ -189,25 +172,25 @@ class AzureAIService(InferenceServiceABC):
                 )
                 return response.model_dump()

-            # @staticmethod
-            # def parse_response(raw_response: dict[str, Any]) -> str:
-            #     """Parses the API response and returns the response text."""
-            #     if (
-            #         raw_response
-            #         and "choices" in raw_response
-            #         and raw_response["choices"]
-            #     ):
-            #         response = raw_response["choices"][0]["message"]["content"]
-            #         pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
-            #         match = re.match(pattern, response, re.DOTALL)
-            #         if match:
-            #             return match.group(1)
-            #         else:
-            #             out = fix_partial_correct_response(response)
-            #             if "error" not in out:
-            #                 response = out["extracted_json"]
-            #             return response
-            #     return "Error parsing response"
+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                if (
+                    raw_response
+                    and "choices" in raw_response
+                    and raw_response["choices"]
+                ):
+                    response = raw_response["choices"][0]["message"]["content"]
+                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                    match = re.match(pattern, response, re.DOTALL)
+                    if match:
+                        return match.group(1)
+                    else:
+                        out = fix_partial_correct_response(response)
+                        if "error" not in out:
+                            response = out["extracted_json"]
+                        return response
+                return "Error parsing response"

         LLM.__name__ = model_class_name
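
The parse_response bodies restored in this diff (in AwsBedrock above and OpenAIService below as well) unwrap answers that models emit inside a ```json fence. A standalone check of that exact regex against a fabricated fenced string:

import re

# Same pattern as in the diff; it accepts either a real newline or a
# literal backslash-n after the opening fence.
pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"

fenced = '```json\n{"answer": "yes"}\n```'
match = re.match(pattern, fenced, re.DOTALL)
assert match is not None and match.group(1) == '{"answer": "yes"}'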
 
edsl/inference_services/GoogleService.py CHANGED
@@ -10,16 +10,10 @@ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC

 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
-    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
-    usage_sequence = ["usageMetadata"]
-    input_token_name = "promptTokenCount"
-    output_token_name = "candidatesTokenCount"
-
-    model_exclude_list = []

     @classmethod
     def available(cls):
-        return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+        return ["gemini-pro"]

     @classmethod
     def create_model(
@@ -30,15 +24,7 @@ class GoogleService(InferenceServiceABC):

         class LLM(LanguageModel):
             _model_ = model_name
-            key_sequence = cls.key_sequence
-            usage_sequence = cls.usage_sequence
-            input_token_name = cls.input_token_name
-            output_token_name = cls.output_token_name
             _inference_service_ = cls._inference_service_
-
-            _tpm = cls.get_tpm(cls)
-            _rpm = cls.get_rpm(cls)
-
             _parameters_ = {
                 "temperature": 0.5,
                 "topP": 1,
@@ -64,7 +50,7 @@ class GoogleService(InferenceServiceABC):
                         "stopSequences": self.stopSequences,
                     },
                 }
-                # print(combined_prompt)
+
                 async with aiohttp.ClientSession() as session:
                     async with session.post(
                         url, headers=headers, data=json.dumps(data)
@@ -72,6 +58,16 @@ class GoogleService(InferenceServiceABC):
                         raw_response_text = await response.text()
                         return json.loads(raw_response_text)

+            def parse_response(self, raw_response: dict[str, Any]) -> str:
+                data = raw_response
+                try:
+                    return data["candidates"][0]["content"]["parts"][0]["text"]
+                except KeyError as e:
+                    print(
+                        f"The data return was {data}, which was missing the key 'candidates'"
+                    )
+                    raise e
+
         LLM.__name__ = model_name

         return LLM
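
For reference, the parse_response added here reads the Gemini REST response shape directly. A fabricated payload (all values invented) showing the single path it traverses, plus the usage keys the deleted attributes used to point at:

# Made-up Gemini-style payload; only the keys the Google code touches.
raw_response = {
    "candidates": [{"content": {"parts": [{"text": "Paris"}]}}],
    "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 1},
}

# Mirrors the happy path of the added parse_response.
assert raw_response["candidates"][0]["content"]["parts"][0]["text"] == "Paris"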
edsl/inference_services/GroqService.py CHANGED
@@ -13,8 +13,6 @@ class GroqService(OpenAIService):
     _sync_client_ = groq.Groq
     _async_client_ = groq.AsyncGroq

-    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]
-
     # _base_url_ = "https://api.deepinfra.com/v1/openai"
     _base_url_ = None
     _models_list_cache: List[str] = []
edsl/inference_services/InferenceServiceABC.py CHANGED
@@ -1,77 +1,22 @@
 from abc import abstractmethod, ABC
-import os
+from typing import Any
 import re
-from edsl.config import CONFIG


 class InferenceServiceABC(ABC):
-    """
-    Abstract class for inference services.
-    Anthropic: https://docs.anthropic.com/en/api/rate-limits
-    """
-
-    default_levels = {
-        "google": {"tpm": 2_000_000, "rpm": 15},
-        "openai": {"tpm": 2_000_000, "rpm": 10_000},
-        "anthropic": {"tpm": 2_000_000, "rpm": 500},
-    }
-
-    def __init_subclass__(cls):
-        """
-        Check that the subclass has the required attributes.
-        - `key_sequence` attribute determines...
-        - `model_exclude_list` attribute determines...
-        """
-        if not hasattr(cls, "key_sequence"):
-            raise NotImplementedError(
-                f"Class {cls.__name__} must have a 'key_sequence' attribute."
-            )
-        if not hasattr(cls, "model_exclude_list"):
-            raise NotImplementedError(
-                f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
-            )
-
-    @classmethod
-    def _get_limt(cls, limit_type: str) -> int:
-        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
-        if key in os.environ:
-            return int(os.getenv(key))
-
-        if cls._inference_service_ in cls.default_levels:
-            return int(cls.default_levels[cls._inference_service_][limit_type])
-
-        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
-
-    def get_tpm(cls) -> int:
-        """
-        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
-        """
-        return cls._get_limt(limit_type="tpm")
-
-    def get_rpm(cls):
-        """
-        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
-        """
-        return cls._get_limt(limit_type="rpm")
+    """Abstract class for inference services."""

     @abstractmethod
     def available() -> list[str]:
-        """
-        Returns a list of available models for the service.
-        """
         pass

     @abstractmethod
     def create_model():
-        """
-        Returns a LanguageModel object.
-        """
         pass

     @staticmethod
     def to_class_name(s):
-        """
-        Converts a string to a valid class name.
+        """Convert a string to a valid class name.

         >>> InferenceServiceABC.to_class_name("hello world")
         'HelloWorld'
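
The deleted block above is what gave 0.1.33 its per-service rate limits: an EDSL_SERVICE_{TPM|RPM}_{SERVICE} environment variable wins, then the default_levels table, then a config baseline. A self-contained sketch of that resolution order — the BASELINE dict is a placeholder for the real EDSL_SERVICE_*_BASELINE lookup in CONFIG:

import os

DEFAULT_LEVELS = {
    "google": {"tpm": 2_000_000, "rpm": 15},
    "openai": {"tpm": 2_000_000, "rpm": 10_000},
    "anthropic": {"tpm": 2_000_000, "rpm": 500},
}
BASELINE = {"tpm": 2_000_000, "rpm": 100}  # placeholder values

def resolve_limit(service: str, limit_type: str) -> int:
    """Resolve a tpm/rpm cap: env override, then per-service default, then baseline."""
    env_key = f"EDSL_SERVICE_{limit_type.upper()}_{service.upper()}"
    if env_key in os.environ:
        return int(os.environ[env_key])
    if service in DEFAULT_LEVELS:
        return int(DEFAULT_LEVELS[service][limit_type])
    return BASELINE[limit_type]

os.environ["EDSL_SERVICE_RPM_OPENAI"] = "50"     # user override wins
assert resolve_limit("openai", "rpm") == 50
assert resolve_limit("anthropic", "rpm") == 500  # falls back to the table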
edsl/inference_services/OpenAIService.py CHANGED
@@ -1,7 +1,8 @@
-from __future__ import annotations
-from typing import Any, List, Optional
+from typing import Any, List
+import re
 import os

+# from openai import AsyncOpenAI
 import openai

 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -9,8 +10,6 @@ from edsl.language_models import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
 from edsl.utilities.utilities import fix_partial_correct_response

-from edsl.config import CONFIG
-

 class OpenAIService(InferenceServiceABC):
     """OpenAI service class."""
@@ -22,36 +21,19 @@ class OpenAIService(InferenceServiceABC):
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI

-    _sync_client_instance = None
-    _async_client_instance = None
-
-    key_sequence = ["choices", 0, "message", "content"]
-    usage_sequence = ["usage"]
-    input_token_name = "prompt_tokens"
-    output_token_name = "completion_tokens"
-
-    def __init_subclass__(cls, **kwargs):
-        super().__init_subclass__(**kwargs)
-        # so subclasses have to create their own instances of the clients
-        cls._sync_client_instance = None
-        cls._async_client_instance = None
-
     @classmethod
     def sync_client(cls):
-        if cls._sync_client_instance is None:
-            cls._sync_client_instance = cls._sync_client_(
-                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-            )
-        return cls._sync_client_instance
+        return cls._sync_client_(
+            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+        )

     @classmethod
     def async_client(cls):
-        if cls._async_client_instance is None:
-            cls._async_client_instance = cls._async_client_(
-                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-            )
-        return cls._async_client_instance
+        return cls._async_client_(
+            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+        )

+    # TODO: Make this a coop call
     model_exclude_list = [
         "whisper-1",
         "davinci-002",
@@ -66,8 +48,6 @@ class OpenAIService(InferenceServiceABC):
         "text-embedding-3-small",
         "text-embedding-ada-002",
         "ft:davinci-002:mit-horton-lab::8OfuHgoo",
-        "gpt-3.5-turbo-instruct-0914",
-        "gpt-3.5-turbo-instruct",
     ]
     _models_list_cache: List[str] = []
@@ -81,8 +61,11 @@ class OpenAIService(InferenceServiceABC):

     @classmethod
     def available(cls) -> List[str]:
+        # from openai import OpenAI
+
         if not cls._models_list_cache:
             try:
+                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                 cls._models_list_cache = [
                     m.id
                     for m in cls.get_model_list()
@@ -90,6 +73,15 @@ class OpenAIService(InferenceServiceABC):
                 ]
             except Exception as e:
                 raise
+                # print(
+                #     f"""Error retrieving models: {e}.
+                #     See instructions about storing your API keys: https://docs.expectedparrot.com/en/latest/api_keys.html"""
+                # )
+                # cls._models_list_cache = [
+                #     "gpt-3.5-turbo",
+                #     "gpt-4-1106-preview",
+                #     "gpt-4",
+                # ]  # Fallback list
         return cls._models_list_cache

     @classmethod
@@ -102,14 +94,6 @@ class OpenAIService(InferenceServiceABC):
             Child class of LanguageModel for interacting with OpenAI models
             """

-            key_sequence = cls.key_sequence
-            usage_sequence = cls.usage_sequence
-            input_token_name = cls.input_token_name
-            output_token_name = cls.output_token_name
-
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
-
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -130,9 +114,15 @@ class OpenAIService(InferenceServiceABC):

             @classmethod
             def available(cls) -> list[str]:
+                # import openai
+                # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                # return client.models.list()
                 return cls.sync_client().models.list()

             def get_headers(self) -> dict[str, Any]:
+                # from openai import OpenAI
+
+                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                 client = self.sync_client()
                 response = client.chat.completions.with_raw_response.create(
                     messages=[
@@ -169,9 +159,6 @@ class OpenAIService(InferenceServiceABC):
                 user_prompt: str,
                 system_prompt: str = "",
                 encoded_image=None,
-                invigilator: Optional[
-                    "InvigilatorAI"
-                ] = None,  # TBD - can eventually be used for function-calling
             ) -> dict[str, Any]:
                 """Calls the OpenAI API and returns the API response."""
                 if encoded_image:
@@ -186,16 +173,17 @@ class OpenAIService(InferenceServiceABC):
                     )
                 else:
                     content = user_prompt
+                # self.client = AsyncOpenAI(
+                #     api_key = os.getenv(cls._env_key_name_),
+                #     base_url = cls._base_url_
+                # )
                 client = self.async_client()
-                messages = [
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": content},
-                ]
-                if system_prompt == "" and self.omit_system_prompt_if_empty:
-                    messages = messages[1:]
                 params = {
                     "model": self.model,
-                    "messages": messages,
+                    "messages": [
+                        {"role": "system", "content": system_prompt},
+                        {"role": "user", "content": content},
+                    ],
                     "temperature": self.temperature,
                     "max_tokens": self.max_tokens,
                     "top_p": self.top_p,
@@ -207,6 +195,24 @@ class OpenAIService(InferenceServiceABC):
                 response = await client.chat.completions.create(**params)
                 return response.model_dump()

+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                try:
+                    response = raw_response["choices"][0]["message"]["content"]
+                except KeyError:
+                    print("Tried to parse response but failed:")
+                    print(raw_response)
+                pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                match = re.match(pattern, response, re.DOTALL)
+                if match:
+                    return match.group(1)
+                else:
+                    out = fix_partial_correct_response(response)
+                    if "error" not in out:
+                        response = out["extracted_json"]
+                    return response
+
         LLM.__name__ = "LanguageModel"

         return LLM
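
Besides the token bookkeeping, the 0.1.33 (-) side memoizes one API client per service class and resets the cache in __init_subclass__ so a subclass such as GroqService builds its own client rather than inheriting the parent's; dev1 constructs a fresh client on every call. A condensed, runnable sketch of the removed pattern, with a dummy client standing in for openai.OpenAI:

class FakeClient:
    """Stand-in for openai.OpenAI so the sketch runs without credentials."""
    def __init__(self, base_url=None):
        self.base_url = base_url

class Service:
    _sync_client_ = FakeClient
    _base_url_ = None
    _sync_client_instance = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Subclasses must build their own client, not reuse the parent's.
        cls._sync_client_instance = None

    @classmethod
    def sync_client(cls):
        if cls._sync_client_instance is None:
            cls._sync_client_instance = cls._sync_client_(base_url=cls._base_url_)
        return cls._sync_client_instance

class GroqLike(Service):
    _base_url_ = "https://api.groq.com/openai/v1"

assert Service.sync_client() is Service.sync_client()       # memoized per class
assert GroqLike.sync_client() is not Service.sync_client()  # own cache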
edsl/inference_services/models_available_cache.py CHANGED
@@ -70,6 +70,12 @@ models_available = {
         "amazon.titan-tg1-large",
         "amazon.titan-text-lite-v1",
         "amazon.titan-text-express-v1",
+        "ai21.j2-grande-instruct",
+        "ai21.j2-jumbo-instruct",
+        "ai21.j2-mid",
+        "ai21.j2-mid-v1",
+        "ai21.j2-ultra",
+        "ai21.j2-ultra-v1",
         "anthropic.claude-instant-v1",
         "anthropic.claude-v2:1",
         "anthropic.claude-v2",
edsl/inference_services/registry.py CHANGED
@@ -10,9 +10,6 @@ from edsl.inference_services.GroqService import GroqService
 from edsl.inference_services.AwsBedrock import AwsBedrockService
 from edsl.inference_services.AzureAI import AzureAIService
 from edsl.inference_services.OllamaService import OllamaService
-from edsl.inference_services.TestService import TestService
-from edsl.inference_services.MistralAIService import MistralAIService
-from edsl.inference_services.TogetherAIService import TogetherAIService

 default = InferenceServicesCollection(
     [
@@ -24,8 +21,5 @@ default = InferenceServicesCollection(
         AwsBedrockService,
         AzureAIService,
         OllamaService,
-        TestService,
-        MistralAIService,
-        TogetherAIService,
     ]
 )
edsl/jobs/Answers.py CHANGED
@@ -2,22 +2,24 @@

 from collections import UserDict
 from rich.table import Table
-from edsl.data_transfer_models import EDSLResultObjectInput


 class Answers(UserDict):
     """Helper class to hold the answers to a survey."""

-    def add_answer(
-        self, response: EDSLResultObjectInput, question: "QuestionBase"
-    ) -> None:
-        """Add a response to the answers dictionary."""
-        answer = response.answer
-        comment = response.comment
-        generated_tokens = response.generated_tokens
+    def add_answer(self, response, question) -> None:
+        """Add a response to the answers dictionary.
+
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText.example()
+        >>> answers = Answers()
+        >>> answers.add_answer({"answer": "yes"}, q)
+        >>> answers[q.question_name]
+        'yes'
+        """
+        answer = response.get("answer")
+        comment = response.pop("comment", None)
         # record the answer
-        if generated_tokens:
-            self[question.question_name + "_generated_tokens"] = generated_tokens
         self[question.question_name] = answer
         if comment:
             self[question.question_name + "_comment"] = comment
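
The dev1 (+) side documents its dict-based calling convention in the doctest above; the 0.1.33 (-) side instead reads attributes off an EDSLResultObjectInput. A rough sketch of the shape that version expects — FakeResponse is an illustrative stand-in for the real dataclass, and the question argument is reduced to its name:

from collections import UserDict
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeResponse:
    """Illustrative stand-in for EDSLResultObjectInput; fields match what add_answer reads."""
    answer: str
    comment: Optional[str] = None
    generated_tokens: Optional[str] = None

class Answers(UserDict):
    """0.1.33-style add_answer, simplified to take the question's name directly."""
    def add_answer(self, response: FakeResponse, question_name: str) -> None:
        if response.generated_tokens:
            self[question_name + "_generated_tokens"] = response.generated_tokens
        self[question_name] = response.answer
        if response.comment:
            self[question_name + "_comment"] = response.comment

answers = Answers()
answers.add_answer(FakeResponse(answer="yes", comment="Sure."), "q0")
assert answers["q0"] == "yes" and answers["q0_comment"] == "Sure."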