edsl: edsl-0.1.31.dev4-py3-none-any.whl → edsl-0.1.32-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Invigilator.py +3 -4
  3. edsl/agents/PromptConstructionMixin.py +35 -15
  4. edsl/config.py +11 -1
  5. edsl/conjure/Conjure.py +6 -0
  6. edsl/data/CacheHandler.py +3 -4
  7. edsl/enums.py +4 -0
  8. edsl/exceptions/general.py +10 -8
  9. edsl/inference_services/AwsBedrock.py +110 -0
  10. edsl/inference_services/AzureAI.py +197 -0
  11. edsl/inference_services/DeepInfraService.py +4 -3
  12. edsl/inference_services/GroqService.py +3 -4
  13. edsl/inference_services/InferenceServicesCollection.py +13 -8
  14. edsl/inference_services/OllamaService.py +18 -0
  15. edsl/inference_services/OpenAIService.py +23 -18
  16. edsl/inference_services/models_available_cache.py +31 -0
  17. edsl/inference_services/registry.py +13 -1
  18. edsl/jobs/Jobs.py +100 -19
  19. edsl/jobs/buckets/TokenBucket.py +12 -4
  20. edsl/jobs/interviews/Interview.py +31 -9
  21. edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
  22. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -34
  23. edsl/jobs/interviews/interview_exception_tracking.py +68 -10
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +36 -15
  25. edsl/jobs/runners/JobsRunnerStatusMixin.py +81 -51
  26. edsl/jobs/tasks/TaskCreators.py +1 -1
  27. edsl/jobs/tasks/TaskHistory.py +145 -1
  28. edsl/language_models/LanguageModel.py +58 -43
  29. edsl/language_models/registry.py +2 -2
  30. edsl/questions/QuestionBudget.py +0 -1
  31. edsl/questions/QuestionCheckBox.py +0 -1
  32. edsl/questions/QuestionExtract.py +0 -1
  33. edsl/questions/QuestionFreeText.py +2 -9
  34. edsl/questions/QuestionList.py +0 -1
  35. edsl/questions/QuestionMultipleChoice.py +1 -2
  36. edsl/questions/QuestionNumerical.py +0 -1
  37. edsl/questions/QuestionRank.py +0 -1
  38. edsl/results/DatasetExportMixin.py +33 -3
  39. edsl/scenarios/Scenario.py +14 -0
  40. edsl/scenarios/ScenarioList.py +216 -13
  41. edsl/scenarios/ScenarioListExportMixin.py +15 -4
  42. edsl/scenarios/ScenarioListPdfMixin.py +3 -0
  43. edsl/surveys/Rule.py +5 -2
  44. edsl/surveys/Survey.py +84 -1
  45. edsl/surveys/SurveyQualtricsImport.py +213 -0
  46. edsl/utilities/utilities.py +31 -0
  47. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/METADATA +4 -1
  48. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/RECORD +50 -45
  49. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
  50. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.31.dev4"
1
+ __version__ = "0.1.32"
@@ -19,7 +19,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
19
19
 
20
20
  async def async_answer_question(self) -> AgentResponseDict:
21
21
  """Answer a question using the AI model.
22
-
22
+
23
23
  >>> i = InvigilatorAI.example()
24
24
  >>> i.answer_question()
25
25
  {'message': '{"answer": "SPAM!"}'}
@@ -34,7 +34,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
34
34
  "raw_model_response": raw_response["raw_model_response"],
35
35
  }
36
36
  response = self._format_raw_response(**data)
37
- #breakpoint()
37
+ # breakpoint()
38
38
  return AgentResponseDict(**response)
39
39
 
40
40
  async def async_get_response(
@@ -44,8 +44,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
44
44
  iteration: int = 0,
45
45
  encoded_image=None,
46
46
  ) -> dict:
47
- """Call the LLM and gets a response. Used in the `answer_question` method.
48
- """
47
+ """Call the LLM and gets a response. Used in the `answer_question` method."""
49
48
  try:
50
49
  params = {
51
50
  "user_prompt": user_prompt.text,
@@ -214,6 +214,17 @@ class PromptConstructorMixin:
214
214
 
215
215
  return self._agent_persona_prompt
216
216
 
217
+ def prior_answers_dict(self) -> dict:
218
+ d = self.survey.question_names_to_questions()
219
+ for question, answer in self.current_answers.items():
220
+ if question in d:
221
+ d[question].answer = answer
222
+ else:
223
+ # adds a comment to the question
224
+ if (new_question := question.split("_comment")[0]) in d:
225
+ d[new_question].comment = answer
226
+ return d
227
+
217
228
  @property
218
229
  def question_instructions_prompt(self) -> Prompt:
219
230
  """
@@ -266,29 +277,38 @@ class PromptConstructorMixin:
266
277
  question_prompt = self.question.get_instructions(model=self.model.model)
267
278
 
268
279
  # TODO: Try to populate the answers in the question object if they are available
269
- d = self.survey.question_names_to_questions()
270
- for question, answer in self.current_answers.items():
271
- if question in d:
272
- d[question].answer = answer
273
- else:
274
- # adds a comment to the question
275
- if (new_question := question.split("_comment")[0]) in d:
276
- d[new_question].comment = answer
280
+ # d = self.survey.question_names_to_questions()
281
+ # for question, answer in self.current_answers.items():
282
+ # if question in d:
283
+ # d[question].answer = answer
284
+ # else:
285
+ # # adds a comment to the question
286
+ # if (new_question := question.split("_comment")[0]) in d:
287
+ # d[new_question].comment = answer
277
288
 
278
289
  question_data = self.question.data.copy()
279
290
 
280
- # check to see if the questio_options is actuall a string
291
+ # check to see if the question_options is actually a string
292
+ # This is used when the user is using the question_options as a variable from a sceario
281
293
  if "question_options" in question_data:
282
294
  if isinstance(self.question.data["question_options"], str):
283
295
  from jinja2 import Environment, meta
296
+
284
297
  env = Environment()
285
- parsed_content = env.parse(self.question.data['question_options'])
286
- question_option_key = list(meta.find_undeclared_variables(parsed_content))[0]
287
- question_data["question_options"] = self.scenario.get(question_option_key)
298
+ parsed_content = env.parse(self.question.data["question_options"])
299
+ question_option_key = list(
300
+ meta.find_undeclared_variables(parsed_content)
301
+ )[0]
302
+ question_data["question_options"] = self.scenario.get(
303
+ question_option_key
304
+ )
288
305
 
289
- #breakpoint()
306
+ # breakpoint()
290
307
  rendered_instructions = question_prompt.render(
291
- question_data | self.scenario | d | {"agent": self.agent}
308
+ question_data
309
+ | self.scenario
310
+ | self.prior_answers_dict()
311
+ | {"agent": self.agent}
292
312
  )
293
313
 
294
314
  undefined_template_variables = (
@@ -321,7 +341,7 @@ class PromptConstructorMixin:
321
341
  if self.memory_plan is not None:
322
342
  memory_prompt += self.create_memory_prompt(
323
343
  self.question.question_name
324
- ).render(self.scenario)
344
+ ).render(self.scenario | self.prior_answers_dict())
325
345
  self._prior_question_memory_prompt = memory_prompt
326
346
  return self._prior_question_memory_prompt
327
347
 
edsl/config.py CHANGED
@@ -69,6 +69,15 @@ CONFIG_MAP = {
69
69
  # "default": None,
70
70
  # "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
71
71
  # },
72
+ # "AWS_ACCESS_KEY_ID" :
73
+ # "default": None,
74
+ # "info": "This env var holds your AWS access key ID.",
75
+ # "AWS_SECRET_ACCESS_KEY:
76
+ # "default": None,
77
+ # "info": "This env var holds your AWS secret access key.",
78
+ # "AZURE_ENDPOINT_URL_AND_KEY":
79
+ # "default": None,
80
+ # "info": "This env var holds your Azure endpoint URL and key (URL:key). You can have several comma-separated URL-key pairs (URL1:key1,URL2:key2).",
72
81
  }
73
82
 
74
83
 
@@ -113,8 +122,9 @@ class Config:
113
122
  """
114
123
  # for each env var in the CONFIG_MAP
115
124
  for env_var, config in CONFIG_MAP.items():
125
+ # we've set it already in _set_run_mode
116
126
  if env_var == "EDSL_RUN_MODE":
117
- continue # we've set it already in _set_run_mode
127
+ continue
118
128
  value = os.getenv(env_var)
119
129
  default_value = config.get("default")
120
130
  # if the env var is set, set it as a CONFIG attribute
edsl/conjure/Conjure.py CHANGED
@@ -35,6 +35,12 @@ class Conjure:
35
35
  # The __init__ method in Conjure won't be called because __new__ returns a different class instance.
36
36
  pass
37
37
 
38
+ @classmethod
39
+ def example(cls):
40
+ from edsl.conjure.InputData import InputDataABC
41
+
42
+ return InputDataABC.example()
43
+
38
44
 
39
45
  if __name__ == "__main__":
40
46
  pass
edsl/data/CacheHandler.py CHANGED
@@ -41,7 +41,7 @@ class CacheHandler:
41
41
  old_data = self.from_old_sqlite_cache()
42
42
  self.cache.add_from_dict(old_data)
43
43
 
44
- def create_cache_directory(self) -> None:
44
+ def create_cache_directory(self, notify=False) -> None:
45
45
  """
46
46
  Create the cache directory if one is required and it does not exist.
47
47
  """
@@ -49,9 +49,8 @@ class CacheHandler:
49
49
  dir_path = os.path.dirname(path)
50
50
  if dir_path and not os.path.exists(dir_path):
51
51
  os.makedirs(dir_path)
52
- import warnings
53
-
54
- warnings.warn(f"Created cache directory: {dir_path}")
52
+ if notify:
53
+ print(f"Created cache directory: {dir_path}")
55
54
 
56
55
  def gen_cache(self) -> Cache:
57
56
  """
edsl/enums.py CHANGED
@@ -60,6 +60,9 @@ class InferenceServiceType(EnumWithChecks):
60
60
  TEST = "test"
61
61
  ANTHROPIC = "anthropic"
62
62
  GROQ = "groq"
63
+ AZURE = "azure"
64
+ OLLAMA = "ollama"
65
+
63
66
 
64
67
  service_to_api_keyname = {
65
68
  InferenceServiceType.BEDROCK.value: "TBD",
@@ -70,6 +73,7 @@ service_to_api_keyname = {
70
73
  InferenceServiceType.TEST.value: "TBD",
71
74
  InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
72
75
  InferenceServiceType.GROQ.value: "GROQ_API_KEY",
76
+ InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
73
77
  }
74
78
 
75
79
 
@@ -21,12 +21,14 @@ class GeneralErrors(Exception):
21
21
 
22
22
 
23
23
  class MissingAPIKeyError(GeneralErrors):
24
- def __init__(self, model_name, inference_service):
25
- full_message = dedent(
26
- f"""
27
- An API Key for model `{model_name}` is missing from the .env file.
28
- This key is associated with the inference service `{inference_service}`.
29
- Please see https://docs.expectedparrot.com/en/latest/api_keys.html for more information.
30
- """
31
- )
24
+ def __init__(self, full_message=None, model_name=None, inference_service=None):
25
+ if model_name and inference_service:
26
+ full_message = dedent(
27
+ f"""
28
+ An API Key for model `{model_name}` is missing from the .env file.
29
+ This key is associated with the inference service `{inference_service}`.
30
+ Please see https://docs.expectedparrot.com/en/latest/api_keys.html for more information.
31
+ """
32
+ )
33
+
32
34
  super().__init__(full_message)
@@ -0,0 +1,110 @@
1
+ import os
2
+ from typing import Any
3
+ import re
4
+ import boto3
5
+ from botocore.exceptions import ClientError
6
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
+ from edsl.language_models.LanguageModel import LanguageModel
8
+ import json
9
+ from edsl.utilities.utilities import fix_partial_correct_response
10
+
11
+
12
+ class AwsBedrockService(InferenceServiceABC):
13
+ """AWS Bedrock service class."""
14
+
15
+ _inference_service_ = "bedrock"
16
+ _env_key_name_ = (
17
+ "AWS_ACCESS_KEY_ID" # or any other environment key for AWS credentials
18
+ )
19
+
20
+ @classmethod
21
+ def available(cls):
22
+ """Fetch available models from AWS Bedrock."""
23
+ if not cls._models_list_cache:
24
+ client = boto3.client("bedrock", region_name="us-west-2")
25
+ all_models_ids = [
26
+ x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
27
+ ]
28
+ else:
29
+ all_models_ids = cls._models_list_cache
30
+
31
+ return all_models_ids
32
+
33
+ @classmethod
34
+ def create_model(
35
+ cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
36
+ ) -> LanguageModel:
37
+ if model_class_name is None:
38
+ model_class_name = cls.to_class_name(model_name)
39
+
40
+ class LLM(LanguageModel):
41
+ """
42
+ Child class of LanguageModel for interacting with AWS Bedrock models.
43
+ """
44
+
45
+ _inference_service_ = cls._inference_service_
46
+ _model_ = model_name
47
+ _parameters_ = {
48
+ "temperature": 0.5,
49
+ "max_tokens": 512,
50
+ "top_p": 0.9,
51
+ }
52
+
53
+ async def async_execute_model_call(
54
+ self, user_prompt: str, system_prompt: str = ""
55
+ ) -> dict[str, Any]:
56
+ """Calls the AWS Bedrock API and returns the API response."""
57
+
58
+ api_token = (
59
+ self.api_token
60
+ ) # call to check the if env variables are set.
61
+
62
+ client = boto3.client("bedrock-runtime", region_name="us-west-2")
63
+
64
+ conversation = [
65
+ {
66
+ "role": "user",
67
+ "content": [{"text": user_prompt}],
68
+ }
69
+ ]
70
+ system = [
71
+ {
72
+ "text": system_prompt,
73
+ }
74
+ ]
75
+ try:
76
+ response = client.converse(
77
+ modelId=self._model_,
78
+ messages=conversation,
79
+ inferenceConfig={
80
+ "maxTokens": self.max_tokens,
81
+ "temperature": self.temperature,
82
+ "topP": self.top_p,
83
+ },
84
+ # system=system,
85
+ additionalModelRequestFields={},
86
+ )
87
+ return response
88
+ except (ClientError, Exception) as e:
89
+ print(e)
90
+ return {"error": str(e)}
91
+
92
+ @staticmethod
93
+ def parse_response(raw_response: dict[str, Any]) -> str:
94
+ """Parses the API response and returns the response text."""
95
+ if "output" in raw_response and "message" in raw_response["output"]:
96
+ response = raw_response["output"]["message"]["content"][0]["text"]
97
+ pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
98
+ match = re.match(pattern, response, re.DOTALL)
99
+ if match:
100
+ return match.group(1)
101
+ else:
102
+ out = fix_partial_correct_response(response)
103
+ if "error" not in out:
104
+ response = out["extracted_json"]
105
+ return response
106
+ return "Error parsing response"
107
+
108
+ LLM.__name__ = model_class_name
109
+
110
+ return LLM
@@ -0,0 +1,197 @@
1
+ import os
2
+ from typing import Any
3
+ import re
4
+ from openai import AsyncAzureOpenAI
5
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
6
+ from edsl.language_models.LanguageModel import LanguageModel
7
+
8
+ from azure.ai.inference.aio import ChatCompletionsClient
9
+ from azure.core.credentials import AzureKeyCredential
10
+ from azure.ai.inference.models import SystemMessage, UserMessage
11
+ import asyncio
12
+ import json
13
+ from edsl.utilities.utilities import fix_partial_correct_response
14
+
15
+
16
+ def json_handle_none(value: Any) -> Any:
17
+ """
18
+ Handle None values during JSON serialization.
19
+ - Return "null" if the value is None. Otherwise, don't return anything.
20
+ """
21
+ if value is None:
22
+ return "null"
23
+
24
+
25
+ class AzureAIService(InferenceServiceABC):
26
+ """Azure AI service class."""
27
+
28
+ _inference_service_ = "azure"
29
+ _env_key_name_ = (
30
+ "AZURE_ENDPOINT_URL_AND_KEY" # Environment variable for Azure API key
31
+ )
32
+ _model_id_to_endpoint_and_key = {}
33
+
34
+ @classmethod
35
+ def available(cls):
36
+ out = []
37
+ azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
38
+ if not azure_endpoints:
39
+ raise EnvironmentError(f"AZURE_ENDPOINT_URL_AND_KEY is not defined")
40
+ azure_endpoints = azure_endpoints.split(",")
41
+ for data in azure_endpoints:
42
+ try:
43
+ # data has this format for non openai models https://model_id.azure_endpoint:azure_key
44
+ _, endpoint, azure_endpoint_key = data.split(":")
45
+ if "openai" not in endpoint:
46
+ model_id = endpoint.split(".")[0].replace("/", "")
47
+ out.append(model_id)
48
+ cls._model_id_to_endpoint_and_key[model_id] = {
49
+ "endpoint": f"https:{endpoint}",
50
+ "azure_endpoint_key": azure_endpoint_key,
51
+ }
52
+ else:
53
+ # data has this format for openai models ,https://azure_project_id.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview:azure_key
54
+ if "/deployments/" in endpoint:
55
+ start_idx = endpoint.index("/deployments/") + len(
56
+ "/deployments/"
57
+ )
58
+ end_idx = (
59
+ endpoint.index("/", start_idx)
60
+ if "/" in endpoint[start_idx:]
61
+ else len(endpoint)
62
+ )
63
+ model_id = endpoint[start_idx:end_idx]
64
+ api_version_value = None
65
+ if "api-version=" in endpoint:
66
+ start_idx = endpoint.index("api-version=") + len(
67
+ "api-version="
68
+ )
69
+ end_idx = (
70
+ endpoint.index("&", start_idx)
71
+ if "&" in endpoint[start_idx:]
72
+ else len(endpoint)
73
+ )
74
+ api_version_value = endpoint[start_idx:end_idx]
75
+
76
+ cls._model_id_to_endpoint_and_key[f"azure:{model_id}"] = {
77
+ "endpoint": f"https:{endpoint}",
78
+ "azure_endpoint_key": azure_endpoint_key,
79
+ "api_version": api_version_value,
80
+ }
81
+ out.append(f"azure:{model_id}")
82
+
83
+ except Exception as e:
84
+ raise e
85
+ return out
86
+
87
+ @classmethod
88
+ def create_model(
89
+ cls, model_name: str = "azureai", model_class_name=None
90
+ ) -> LanguageModel:
91
+ if model_class_name is None:
92
+ model_class_name = cls.to_class_name(model_name)
93
+
94
+ class LLM(LanguageModel):
95
+ """
96
+ Child class of LanguageModel for interacting with Azure OpenAI models.
97
+ """
98
+
99
+ _inference_service_ = cls._inference_service_
100
+ _model_ = model_name
101
+ _parameters_ = {
102
+ "temperature": 0.5,
103
+ "max_tokens": 512,
104
+ "top_p": 0.9,
105
+ }
106
+
107
+ async def async_execute_model_call(
108
+ self, user_prompt: str, system_prompt: str = ""
109
+ ) -> dict[str, Any]:
110
+ """Calls the Azure OpenAI API and returns the API response."""
111
+
112
+ try:
113
+ api_key = cls._model_id_to_endpoint_and_key[model_name][
114
+ "azure_endpoint_key"
115
+ ]
116
+ except:
117
+ api_key = None
118
+
119
+ if not api_key:
120
+ raise EnvironmentError(
121
+ f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
122
+ )
123
+
124
+ try:
125
+ endpoint = cls._model_id_to_endpoint_and_key[model_name]["endpoint"]
126
+ except:
127
+ endpoint = None
128
+
129
+ if not endpoint:
130
+ raise EnvironmentError(
131
+ f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
132
+ )
133
+
134
+ if "openai" not in endpoint:
135
+ client = ChatCompletionsClient(
136
+ endpoint=endpoint,
137
+ credential=AzureKeyCredential(api_key),
138
+ temperature=self.temperature,
139
+ top_p=self.top_p,
140
+ max_tokens=self.max_tokens,
141
+ )
142
+ try:
143
+ response = await client.complete(
144
+ messages=[
145
+ SystemMessage(content=system_prompt),
146
+ UserMessage(content=user_prompt),
147
+ ],
148
+ # model_extras={"safe_mode": True},
149
+ )
150
+ await client.close()
151
+ return response.as_dict()
152
+ except Exception as e:
153
+ await client.close()
154
+ return {"error": str(e)}
155
+ else:
156
+ api_version = cls._model_id_to_endpoint_and_key[model_name][
157
+ "api_version"
158
+ ]
159
+ client = AsyncAzureOpenAI(
160
+ azure_endpoint=endpoint,
161
+ api_version=api_version,
162
+ api_key=api_key,
163
+ )
164
+ response = await client.chat.completions.create(
165
+ model=model_name,
166
+ messages=[
167
+ {
168
+ "role": "user",
169
+ "content": user_prompt, # Your question can go here
170
+ },
171
+ ],
172
+ )
173
+ return response.model_dump()
174
+
175
+ @staticmethod
176
+ def parse_response(raw_response: dict[str, Any]) -> str:
177
+ """Parses the API response and returns the response text."""
178
+ if (
179
+ raw_response
180
+ and "choices" in raw_response
181
+ and raw_response["choices"]
182
+ ):
183
+ response = raw_response["choices"][0]["message"]["content"]
184
+ pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
185
+ match = re.match(pattern, response, re.DOTALL)
186
+ if match:
187
+ return match.group(1)
188
+ else:
189
+ out = fix_partial_correct_response(response)
190
+ if "error" not in out:
191
+ response = out["extracted_json"]
192
+ return response
193
+ return "Error parsing response"
194
+
195
+ LLM.__name__ = model_class_name
196
+
197
+ return LLM
@@ -2,16 +2,17 @@ import aiohttp
2
2
  import json
3
3
  import requests
4
4
  from typing import Any, List
5
- #from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
5
+
6
+ # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
6
7
  from edsl.language_models import LanguageModel
7
8
 
8
9
  from edsl.inference_services.OpenAIService import OpenAIService
9
10
 
11
+
10
12
  class DeepInfraService(OpenAIService):
11
13
  """DeepInfra service class."""
12
14
 
13
15
  _inference_service_ = "deep_infra"
14
16
  _env_key_name_ = "DEEP_INFRA_API_KEY"
15
- _base_url_ = "https://api.deepinfra.com/v1/openai"
17
+ _base_url_ = "https://api.deepinfra.com/v1/openai"
16
18
  _models_list_cache: List[str] = []
17
-
@@ -10,10 +10,9 @@ class GroqService(OpenAIService):
10
10
  _inference_service_ = "groq"
11
11
  _env_key_name_ = "GROQ_API_KEY"
12
12
 
13
- _sync_client_ = groq.Groq
14
- _async_client_ = groq.AsyncGroq
13
+ _sync_client_ = groq.Groq
14
+ _async_client_ = groq.AsyncGroq
15
15
 
16
- #_base_url_ = "https://api.deepinfra.com/v1/openai"
16
+ # _base_url_ = "https://api.deepinfra.com/v1/openai"
17
17
  _base_url_ = None
18
18
  _models_list_cache: List[str] = []
19
-
@@ -15,18 +15,19 @@ class InferenceServicesCollection:
15
15
  cls.added_models[service_name].append(model_name)
16
16
 
17
17
  @staticmethod
18
- def _get_service_available(service) -> list[str]:
18
+ def _get_service_available(service, warn: bool = False) -> list[str]:
19
19
  from_api = True
20
20
  try:
21
21
  service_models = service.available()
22
22
  except Exception as e:
23
- warnings.warn(
24
- f"""Error getting models for {service._inference_service_}.
25
- Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
26
- See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
27
- Relying on cache.""",
28
- UserWarning,
29
- )
23
+ if warn:
24
+ warnings.warn(
25
+ f"""Error getting models for {service._inference_service_}.
26
+ Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
27
+ See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
28
+ Relying on cache.""",
29
+ UserWarning,
30
+ )
30
31
  from edsl.inference_services.models_available_cache import models_available
31
32
 
32
33
  service_models = models_available.get(service._inference_service_, [])
@@ -60,4 +61,8 @@ class InferenceServicesCollection:
60
61
  if service_name is None or service_name == service._inference_service_:
61
62
  return service.create_model(model_name)
62
63
 
64
+ # if model_name == "test":
65
+ # from edsl.language_models import LanguageModel
66
+ # return LanguageModel(test = True)
67
+
63
68
  raise Exception(f"Model {model_name} not found in any of the services")
@@ -0,0 +1,18 @@
1
+ import aiohttp
2
+ import json
3
+ import requests
4
+ from typing import Any, List
5
+
6
+ # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
+ from edsl.language_models import LanguageModel
8
+
9
+ from edsl.inference_services.OpenAIService import OpenAIService
10
+
11
+
12
+ class OllamaService(OpenAIService):
13
+ """DeepInfra service class."""
14
+
15
+ _inference_service_ = "ollama"
16
+ _env_key_name_ = "DEEP_INFRA_API_KEY"
17
+ _base_url_ = "http://localhost:11434/v1"
18
+ _models_list_cache: List[str] = []