edsl 0.1.31.dev3__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
Files changed (52)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Invigilator.py +7 -2
  3. edsl/agents/PromptConstructionMixin.py +35 -15
  4. edsl/config.py +15 -1
  5. edsl/conjure/Conjure.py +6 -0
  6. edsl/coop/coop.py +4 -0
  7. edsl/data/CacheHandler.py +3 -4
  8. edsl/enums.py +5 -0
  9. edsl/exceptions/general.py +10 -8
  10. edsl/inference_services/AwsBedrock.py +110 -0
  11. edsl/inference_services/AzureAI.py +197 -0
  12. edsl/inference_services/DeepInfraService.py +6 -91
  13. edsl/inference_services/GroqService.py +18 -0
  14. edsl/inference_services/InferenceServicesCollection.py +13 -8
  15. edsl/inference_services/OllamaService.py +18 -0
  16. edsl/inference_services/OpenAIService.py +68 -21
  17. edsl/inference_services/models_available_cache.py +31 -0
  18. edsl/inference_services/registry.py +14 -1
  19. edsl/jobs/Jobs.py +103 -21
  20. edsl/jobs/buckets/TokenBucket.py +12 -4
  21. edsl/jobs/interviews/Interview.py +31 -9
  22. edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
  23. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -33
  24. edsl/jobs/interviews/interview_exception_tracking.py +68 -10
  25. edsl/jobs/runners/JobsRunnerAsyncio.py +112 -81
  26. edsl/jobs/runners/JobsRunnerStatusData.py +0 -237
  27. edsl/jobs/runners/JobsRunnerStatusMixin.py +291 -35
  28. edsl/jobs/tasks/TaskCreators.py +8 -2
  29. edsl/jobs/tasks/TaskHistory.py +145 -1
  30. edsl/language_models/LanguageModel.py +62 -41
  31. edsl/language_models/registry.py +4 -0
  32. edsl/questions/QuestionBudget.py +0 -1
  33. edsl/questions/QuestionCheckBox.py +0 -1
  34. edsl/questions/QuestionExtract.py +0 -1
  35. edsl/questions/QuestionFreeText.py +2 -9
  36. edsl/questions/QuestionList.py +0 -1
  37. edsl/questions/QuestionMultipleChoice.py +1 -2
  38. edsl/questions/QuestionNumerical.py +0 -1
  39. edsl/questions/QuestionRank.py +0 -1
  40. edsl/results/DatasetExportMixin.py +33 -3
  41. edsl/scenarios/Scenario.py +14 -0
  42. edsl/scenarios/ScenarioList.py +216 -13
  43. edsl/scenarios/ScenarioListExportMixin.py +15 -4
  44. edsl/scenarios/ScenarioListPdfMixin.py +3 -0
  45. edsl/surveys/Rule.py +5 -2
  46. edsl/surveys/Survey.py +84 -1
  47. edsl/surveys/SurveyQualtricsImport.py +213 -0
  48. edsl/utilities/utilities.py +31 -0
  49. {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/METADATA +5 -1
  50. {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/RECORD +52 -46
  51. {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
  52. {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
edsl/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.31.dev3"
+ __version__ = "0.1.32"
edsl/agents/Invigilator.py CHANGED
@@ -18,7 +18,12 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
  """An invigilator that uses an AI model to answer questions."""
 
  async def async_answer_question(self) -> AgentResponseDict:
- """Answer a question using the AI model."""
+ """Answer a question using the AI model.
+
+ >>> i = InvigilatorAI.example()
+ >>> i.answer_question()
+ {'message': '{"answer": "SPAM!"}'}
+ """
  params = self.get_prompts() | {"iteration": self.iteration}
  raw_response = await self.async_get_response(**params)
  data = {
@@ -29,6 +34,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
  "raw_model_response": raw_response["raw_model_response"],
  }
  response = self._format_raw_response(**data)
+ # breakpoint()
  return AgentResponseDict(**response)
 
  async def async_get_response(
@@ -97,7 +103,6 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
  answer = question._translate_answer_code_to_answer(
  response["answer"], combined_dict
  )
- # breakpoint()
  data = {
  "answer": answer,
  "comment": response.get(
edsl/agents/PromptConstructionMixin.py CHANGED
@@ -214,6 +214,17 @@ class PromptConstructorMixin:
 
  return self._agent_persona_prompt
 
+ def prior_answers_dict(self) -> dict:
+ d = self.survey.question_names_to_questions()
+ for question, answer in self.current_answers.items():
+ if question in d:
+ d[question].answer = answer
+ else:
+ # adds a comment to the question
+ if (new_question := question.split("_comment")[0]) in d:
+ d[new_question].comment = answer
+ return d
+
  @property
  def question_instructions_prompt(self) -> Prompt:
  """
@@ -266,29 +277,38 @@ class PromptConstructorMixin:
  question_prompt = self.question.get_instructions(model=self.model.model)
 
  # TODO: Try to populate the answers in the question object if they are available
- d = self.survey.question_names_to_questions()
- for question, answer in self.current_answers.items():
- if question in d:
- d[question].answer = answer
- else:
- # adds a comment to the question
- if (new_question := question.split("_comment")[0]) in d:
- d[new_question].comment = answer
+ # d = self.survey.question_names_to_questions()
+ # for question, answer in self.current_answers.items():
+ # if question in d:
+ # d[question].answer = answer
+ # else:
+ # # adds a comment to the question
+ # if (new_question := question.split("_comment")[0]) in d:
+ # d[new_question].comment = answer
 
  question_data = self.question.data.copy()
 
- # check to see if the questio_options is actuall a string
+ # check to see if the question_options is actually a string
+ # This is used when the user is using the question_options as a variable from a sceario
  if "question_options" in question_data:
  if isinstance(self.question.data["question_options"], str):
  from jinja2 import Environment, meta
+
  env = Environment()
- parsed_content = env.parse(self.question.data['question_options'])
- question_option_key = list(meta.find_undeclared_variables(parsed_content))[0]
- question_data["question_options"] = self.scenario.get(question_option_key)
+ parsed_content = env.parse(self.question.data["question_options"])
+ question_option_key = list(
+ meta.find_undeclared_variables(parsed_content)
+ )[0]
+ question_data["question_options"] = self.scenario.get(
+ question_option_key
+ )
 
- #breakpoint()
+ # breakpoint()
  rendered_instructions = question_prompt.render(
- question_data | self.scenario | d | {"agent": self.agent}
+ question_data
+ | self.scenario
+ | self.prior_answers_dict()
+ | {"agent": self.agent}
  )
 
  undefined_template_variables = (
@@ -321,7 +341,7 @@ class PromptConstructorMixin:
  if self.memory_plan is not None:
  memory_prompt += self.create_memory_prompt(
  self.question.question_name
- ).render(self.scenario)
+ ).render(self.scenario | self.prior_answers_dict())
  self._prior_question_memory_prompt = memory_prompt
  return self._prior_question_memory_prompt
 
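The new prior_answers_dict helper attaches each recorded answer to its question object and routes keys like "<question_name>_comment" onto the matching question's comment attribute; both the instruction template and the memory prompt now render against it. A minimal standalone sketch of that routing logic (the SimpleNamespace objects and sample answers are hypothetical stand-ins for EDSL's question objects):

    from types import SimpleNamespace

    # stand-ins for survey.question_names_to_questions()
    d = {"color": SimpleNamespace(), "size": SimpleNamespace()}
    current_answers = {"color": "red", "color_comment": "I like red."}

    for question, answer in current_answers.items():
        if question in d:
            d[question].answer = answer
        # keys like "color_comment" become a comment on question "color"
        elif (new_question := question.split("_comment")[0]) in d:
            d[new_question].comment = answer

    print(d["color"].answer)   # red
    print(d["color"].comment)  # I like red.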
edsl/config.py CHANGED
@@ -65,6 +65,19 @@ CONFIG_MAP = {
  # "default": None,
  # "info": "This env var holds your Anthropic API key (https://www.anthropic.com/).",
  # },
+ # "GROQ_API_KEY": {
+ # "default": None,
+ # "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
+ # },
+ # "AWS_ACCESS_KEY_ID" :
+ # "default": None,
+ # "info": "This env var holds your AWS access key ID.",
+ # "AWS_SECRET_ACCESS_KEY:
+ # "default": None,
+ # "info": "This env var holds your AWS secret access key.",
+ # "AZURE_ENDPOINT_URL_AND_KEY":
+ # "default": None,
+ # "info": "This env var holds your Azure endpoint URL and key (URL:key). You can have several comma-separated URL-key pairs (URL1:key1,URL2:key2).",
  }
 
 
@@ -109,8 +122,9 @@ class Config:
  """
  # for each env var in the CONFIG_MAP
  for env_var, config in CONFIG_MAP.items():
+ # we've set it already in _set_run_mode
  if env_var == "EDSL_RUN_MODE":
- continue # we've set it already in _set_run_mode
+ continue
  value = os.getenv(env_var)
  default_value = config.get("default")
  # if the env var is set, set it as a CONFIG attribute
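The surrounding loop simply reads each CONFIG_MAP entry from the environment, skipping EDSL_RUN_MODE (handled earlier) and falling back to the entry's default. A minimal sketch of that pattern (with toy CONFIG_MAP entries, not the real map):

    import os

    CONFIG_MAP = {
        "EDSL_RUN_MODE": {"default": "production"},  # toy entry
        "SOME_OTHER_VAR": {"default": None},         # toy entry
    }

    for env_var, config in CONFIG_MAP.items():
        if env_var == "EDSL_RUN_MODE":
            continue  # already handled in _set_run_mode
        value = os.getenv(env_var) or config.get("default")
        print(env_var, "->", value)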
edsl/conjure/Conjure.py CHANGED
@@ -35,6 +35,12 @@ class Conjure:
  # The __init__ method in Conjure won't be called because __new__ returns a different class instance.
  pass
 
+ @classmethod
+ def example(cls):
+ from edsl.conjure.InputData import InputDataABC
+
+ return InputDataABC.example()
+
 
  if __name__ == "__main__":
  pass
edsl/coop/coop.py CHANGED
@@ -465,6 +465,7 @@ class Coop:
  description: Optional[str] = None,
  status: RemoteJobStatus = "queued",
  visibility: Optional[VisibilityType] = "unlisted",
+ iterations: Optional[int] = 1,
  ) -> dict:
  """
  Send a remote inference job to the server.
@@ -473,6 +474,7 @@ class Coop:
  :param optional description: A description for this entry in the remote cache.
  :param status: The status of the job. Should be 'queued', unless you are debugging.
  :param visibility: The visibility of the cache entry.
+ :param iterations: The number of times to run each interview.
 
  >>> job = Jobs.example()
  >>> coop.remote_inference_create(job=job, description="My job")
@@ -488,6 +490,7 @@ class Coop:
  ),
  "description": description,
  "status": status,
+ "iterations": iterations,
  "visibility": visibility,
  "version": self._edsl_version,
  },
@@ -498,6 +501,7 @@ class Coop:
  "uuid": response_json.get("jobs_uuid"),
  "description": response_json.get("description"),
  "status": response_json.get("status"),
+ "iterations": response_json.get("iterations"),
  "visibility": response_json.get("visibility"),
  "version": self._edsl_version,
  }
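The new iterations parameter is passed straight through to the request payload and echoed back in the response dict. A hedged usage sketch (import paths and client construction assumed from EDSL's usual layout; an API key must already be configured):

    from edsl.jobs import Jobs
    from edsl.coop import Coop

    coop = Coop()
    job = Jobs.example()
    # ask the server to run each interview three times
    info = coop.remote_inference_create(job=job, description="My job", iterations=3)
    print(info["status"], info["iterations"])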
edsl/data/CacheHandler.py CHANGED
@@ -41,7 +41,7 @@ class CacheHandler:
  old_data = self.from_old_sqlite_cache()
  self.cache.add_from_dict(old_data)
 
- def create_cache_directory(self) -> None:
+ def create_cache_directory(self, notify=False) -> None:
  """
  Create the cache directory if one is required and it does not exist.
  """
@@ -49,9 +49,8 @@ class CacheHandler:
  dir_path = os.path.dirname(path)
  if dir_path and not os.path.exists(dir_path):
  os.makedirs(dir_path)
- import warnings
-
- warnings.warn(f"Created cache directory: {dir_path}")
+ if notify:
+ print(f"Created cache directory: {dir_path}")
 
  def gen_cache(self) -> Cache:
  """
edsl/enums.py CHANGED
@@ -59,6 +59,9 @@ class InferenceServiceType(EnumWithChecks):
  GOOGLE = "google"
  TEST = "test"
  ANTHROPIC = "anthropic"
+ GROQ = "groq"
+ AZURE = "azure"
+ OLLAMA = "ollama"
 
 
  service_to_api_keyname = {
@@ -69,6 +72,8 @@ service_to_api_keyname = {
  InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
  InferenceServiceType.TEST.value: "TBD",
  InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
+ InferenceServiceType.GROQ.value: "GROQ_API_KEY",
+ InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
  }
 
 
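Note that the BEDROCK entry maps to a list of two environment variables rather than a single key name, so consumers of service_to_api_keyname have to handle both shapes. A small self-contained sketch of such a lookup (an illustration, not EDSL's actual consuming code):

    import os

    # excerpt of the mapping above
    service_to_api_keyname = {
        "groq": "GROQ_API_KEY",
        "bedrock": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
    }

    for service, keyname in service_to_api_keyname.items():
        names = keyname if isinstance(keyname, list) else [keyname]
        missing = [n for n in names if not os.getenv(n)]
        print(service, "missing:", missing)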
edsl/exceptions/general.py CHANGED
@@ -21,12 +21,14 @@ class GeneralErrors(Exception):
 
 
  class MissingAPIKeyError(GeneralErrors):
- def __init__(self, model_name, inference_service):
- full_message = dedent(
- f"""
- An API Key for model `{model_name}` is missing from the .env file.
- This key is associated with the inference service `{inference_service}`.
- Please see https://docs.expectedparrot.com/en/latest/api_keys.html for more information.
- """
- )
+ def __init__(self, full_message=None, model_name=None, inference_service=None):
+ if model_name and inference_service:
+ full_message = dedent(
+ f"""
+ An API Key for model `{model_name}` is missing from the .env file.
+ This key is associated with the inference service `{inference_service}`.
+ Please see https://docs.expectedparrot.com/en/latest/api_keys.html for more information.
+ """
+ )
+
  super().__init__(full_message)
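The reworked constructor accepts either a prebuilt message or a model/service pair, which lets the new multi-credential services raise errors that don't fit the model-plus-service template. Both call styles, sketched (import path assumed from the file's location):

    from edsl.exceptions.general import MissingAPIKeyError

    # style 1: let the class compose the message
    e1 = MissingAPIKeyError(model_name="gpt-4o", inference_service="openai")

    # style 2: pass a fully formed message
    e2 = MissingAPIKeyError("AZURE_ENDPOINT_URL_AND_KEY is not defined.")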
edsl/inference_services/AwsBedrock.py ADDED
@@ -0,0 +1,110 @@
+ import os
+ from typing import Any
+ import re
+ import boto3
+ from botocore.exceptions import ClientError
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+ from edsl.language_models.LanguageModel import LanguageModel
+ import json
+ from edsl.utilities.utilities import fix_partial_correct_response
+
+
+ class AwsBedrockService(InferenceServiceABC):
+ """AWS Bedrock service class."""
+
+ _inference_service_ = "bedrock"
+ _env_key_name_ = (
+ "AWS_ACCESS_KEY_ID" # or any other environment key for AWS credentials
+ )
+
+ @classmethod
+ def available(cls):
+ """Fetch available models from AWS Bedrock."""
+ if not cls._models_list_cache:
+ client = boto3.client("bedrock", region_name="us-west-2")
+ all_models_ids = [
+ x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
+ ]
+ else:
+ all_models_ids = cls._models_list_cache
+
+ return all_models_ids
+
+ @classmethod
+ def create_model(
+ cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
+ ) -> LanguageModel:
+ if model_class_name is None:
+ model_class_name = cls.to_class_name(model_name)
+
+ class LLM(LanguageModel):
+ """
+ Child class of LanguageModel for interacting with AWS Bedrock models.
+ """
+
+ _inference_service_ = cls._inference_service_
+ _model_ = model_name
+ _parameters_ = {
+ "temperature": 0.5,
+ "max_tokens": 512,
+ "top_p": 0.9,
+ }
+
+ async def async_execute_model_call(
+ self, user_prompt: str, system_prompt: str = ""
+ ) -> dict[str, Any]:
+ """Calls the AWS Bedrock API and returns the API response."""
+
+ api_token = (
+ self.api_token
+ ) # call to check the if env variables are set.
+
+ client = boto3.client("bedrock-runtime", region_name="us-west-2")
+
+ conversation = [
+ {
+ "role": "user",
+ "content": [{"text": user_prompt}],
+ }
+ ]
+ system = [
+ {
+ "text": system_prompt,
+ }
+ ]
+ try:
+ response = client.converse(
+ modelId=self._model_,
+ messages=conversation,
+ inferenceConfig={
+ "maxTokens": self.max_tokens,
+ "temperature": self.temperature,
+ "topP": self.top_p,
+ },
+ # system=system,
+ additionalModelRequestFields={},
+ )
+ return response
+ except (ClientError, Exception) as e:
+ print(e)
+ return {"error": str(e)}
+
+ @staticmethod
+ def parse_response(raw_response: dict[str, Any]) -> str:
+ """Parses the API response and returns the response text."""
+ if "output" in raw_response and "message" in raw_response["output"]:
+ response = raw_response["output"]["message"]["content"][0]["text"]
+ pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+ match = re.match(pattern, response, re.DOTALL)
+ if match:
+ return match.group(1)
+ else:
+ out = fix_partial_correct_response(response)
+ if "error" not in out:
+ response = out["extracted_json"]
+ return response
+ return "Error parsing response"
+
+ LLM.__name__ = model_class_name
+
+ return LLM
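parse_response first tries to strip a ```json code fence from the model text before falling back to fix_partial_correct_response. A standalone check of that fence-stripping regex on a made-up model output:

    import re

    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
    sample = '```json\n{"answer": "SPAM!"}\n```'  # made-up model output
    match = re.match(pattern, sample, re.DOTALL)
    print(match.group(1))  # {"answer": "SPAM!"}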
edsl/inference_services/AzureAI.py ADDED
@@ -0,0 +1,197 @@
+ import os
+ from typing import Any
+ import re
+ from openai import AsyncAzureOpenAI
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+ from edsl.language_models.LanguageModel import LanguageModel
+
+ from azure.ai.inference.aio import ChatCompletionsClient
+ from azure.core.credentials import AzureKeyCredential
+ from azure.ai.inference.models import SystemMessage, UserMessage
+ import asyncio
+ import json
+ from edsl.utilities.utilities import fix_partial_correct_response
+
+
+ def json_handle_none(value: Any) -> Any:
+ """
+ Handle None values during JSON serialization.
+ - Return "null" if the value is None. Otherwise, don't return anything.
+ """
+ if value is None:
+ return "null"
+
+
+ class AzureAIService(InferenceServiceABC):
+ """Azure AI service class."""
+
+ _inference_service_ = "azure"
+ _env_key_name_ = (
+ "AZURE_ENDPOINT_URL_AND_KEY" # Environment variable for Azure API key
+ )
+ _model_id_to_endpoint_and_key = {}
+
+ @classmethod
+ def available(cls):
+ out = []
+ azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
+ if not azure_endpoints:
+ raise EnvironmentError(f"AZURE_ENDPOINT_URL_AND_KEY is not defined")
+ azure_endpoints = azure_endpoints.split(",")
+ for data in azure_endpoints:
+ try:
+ # data has this format for non openai models https://model_id.azure_endpoint:azure_key
+ _, endpoint, azure_endpoint_key = data.split(":")
+ if "openai" not in endpoint:
+ model_id = endpoint.split(".")[0].replace("/", "")
+ out.append(model_id)
+ cls._model_id_to_endpoint_and_key[model_id] = {
+ "endpoint": f"https:{endpoint}",
+ "azure_endpoint_key": azure_endpoint_key,
+ }
+ else:
+ # data has this format for openai models ,https://azure_project_id.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview:azure_key
+ if "/deployments/" in endpoint:
+ start_idx = endpoint.index("/deployments/") + len(
+ "/deployments/"
+ )
+ end_idx = (
+ endpoint.index("/", start_idx)
+ if "/" in endpoint[start_idx:]
+ else len(endpoint)
+ )
+ model_id = endpoint[start_idx:end_idx]
+ api_version_value = None
+ if "api-version=" in endpoint:
+ start_idx = endpoint.index("api-version=") + len(
+ "api-version="
+ )
+ end_idx = (
+ endpoint.index("&", start_idx)
+ if "&" in endpoint[start_idx:]
+ else len(endpoint)
+ )
+ api_version_value = endpoint[start_idx:end_idx]
+
+ cls._model_id_to_endpoint_and_key[f"azure:{model_id}"] = {
+ "endpoint": f"https:{endpoint}",
+ "azure_endpoint_key": azure_endpoint_key,
+ "api_version": api_version_value,
+ }
+ out.append(f"azure:{model_id}")
+
+ except Exception as e:
+ raise e
+ return out
+
+ @classmethod
+ def create_model(
+ cls, model_name: str = "azureai", model_class_name=None
+ ) -> LanguageModel:
+ if model_class_name is None:
+ model_class_name = cls.to_class_name(model_name)
+
+ class LLM(LanguageModel):
+ """
+ Child class of LanguageModel for interacting with Azure OpenAI models.
+ """
+
+ _inference_service_ = cls._inference_service_
+ _model_ = model_name
+ _parameters_ = {
+ "temperature": 0.5,
+ "max_tokens": 512,
+ "top_p": 0.9,
+ }
+
+ async def async_execute_model_call(
+ self, user_prompt: str, system_prompt: str = ""
+ ) -> dict[str, Any]:
+ """Calls the Azure OpenAI API and returns the API response."""
+
+ try:
+ api_key = cls._model_id_to_endpoint_and_key[model_name][
+ "azure_endpoint_key"
+ ]
+ except:
+ api_key = None
+
+ if not api_key:
+ raise EnvironmentError(
+ f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
+ )
+
+ try:
+ endpoint = cls._model_id_to_endpoint_and_key[model_name]["endpoint"]
+ except:
+ endpoint = None
+
+ if not endpoint:
+ raise EnvironmentError(
+ f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
+ )
+
+ if "openai" not in endpoint:
+ client = ChatCompletionsClient(
+ endpoint=endpoint,
+ credential=AzureKeyCredential(api_key),
+ temperature=self.temperature,
+ top_p=self.top_p,
+ max_tokens=self.max_tokens,
+ )
+ try:
+ response = await client.complete(
+ messages=[
+ SystemMessage(content=system_prompt),
+ UserMessage(content=user_prompt),
+ ],
+ # model_extras={"safe_mode": True},
+ )
+ await client.close()
+ return response.as_dict()
+ except Exception as e:
+ await client.close()
+ return {"error": str(e)}
+ else:
+ api_version = cls._model_id_to_endpoint_and_key[model_name][
+ "api_version"
+ ]
+ client = AsyncAzureOpenAI(
+ azure_endpoint=endpoint,
+ api_version=api_version,
+ api_key=api_key,
+ )
+ response = await client.chat.completions.create(
+ model=model_name,
+ messages=[
+ {
+ "role": "user",
+ "content": user_prompt, # Your question can go here
+ },
+ ],
+ )
+ return response.model_dump()
+
+ @staticmethod
+ def parse_response(raw_response: dict[str, Any]) -> str:
+ """Parses the API response and returns the response text."""
+ if (
+ raw_response
+ and "choices" in raw_response
+ and raw_response["choices"]
+ ):
+ response = raw_response["choices"][0]["message"]["content"]
+ pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+ match = re.match(pattern, response, re.DOTALL)
+ if match:
+ return match.group(1)
+ else:
+ out = fix_partial_correct_response(response)
+ if "error" not in out:
+ response = out["extracted_json"]
+ return response
+ return "Error parsing response"
+
+ LLM.__name__ = model_class_name
+
+ return LLM
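available() splits each comma-separated entry on ":", which cuts the "https://..." URL into a scheme, a "//host" part, and the trailing key; the URL is then rebuilt with an "https:" prefix. A standalone trace of the non-OpenAI branch (the endpoint and key are made up):

    data = "https://my-model.eastus2.example.net:abc123"  # made-up entry
    _, endpoint, azure_endpoint_key = data.split(":")
    model_id = endpoint.split(".")[0].replace("/", "")
    print(model_id)             # my-model
    print(f"https:{endpoint}")  # https://my-model.eastus2.example.net
    print(azure_endpoint_key)   # abc123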
edsl/inference_services/DeepInfraService.py CHANGED
@@ -2,102 +2,17 @@ import aiohttp
  import json
  import requests
  from typing import Any, List
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+
+ # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
  from edsl.language_models import LanguageModel
 
+ from edsl.inference_services.OpenAIService import OpenAIService
+
 
- class DeepInfraService(InferenceServiceABC):
+ class DeepInfraService(OpenAIService):
  """DeepInfra service class."""
 
  _inference_service_ = "deep_infra"
  _env_key_name_ = "DEEP_INFRA_API_KEY"
-
+ _base_url_ = "https://api.deepinfra.com/v1/openai"
  _models_list_cache: List[str] = []
-
- @classmethod
- def available(cls):
- text_models = cls.full_details_available()
- return [m["model_name"] for m in text_models]
-
- @classmethod
- def full_details_available(cls, verbose=False):
- if not cls._models_list_cache:
- url = "https://api.deepinfra.com/models/list"
- response = requests.get(url)
- if response.status_code == 200:
- text_generation_models = [
- r for r in response.json() if r["type"] == "text-generation"
- ]
- cls._models_list_cache = text_generation_models
-
- from rich import print_json
- import json
-
- if verbose:
- print_json(json.dumps(text_generation_models))
- return text_generation_models
- else:
- return f"Failed to fetch data: Status code {response.status_code}"
- else:
- return cls._models_list_cache
-
- @classmethod
- def create_model(cls, model_name: str, model_class_name=None) -> LanguageModel:
- base_url = "https://api.deepinfra.com/v1/inference/"
- if model_class_name is None:
- model_class_name = cls.to_class_name(model_name)
- url = f"{base_url}{model_name}"
-
- class LLM(LanguageModel):
- _inference_service_ = cls._inference_service_
- _model_ = model_name
- _parameters_ = {
- "temperature": 0.7,
- "top_p": 0.2,
- "top_k": 0.1,
- "max_new_tokens": 512,
- "stopSequences": [],
- }
-
- async def async_execute_model_call(
- self, user_prompt: str, system_prompt: str = ""
- ) -> dict[str, Any]:
- self.url = url
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"bearer {self.api_token}",
- }
- # don't mess w/ the newlines
- data = {
- "input": f"""
- [INST]<<SYS>>
- {system_prompt}
- <<SYS>>{user_prompt}[/INST]
- """,
- "stream": False,
- "temperature": self.temperature,
- "top_p": self.top_p,
- "top_k": self.top_k,
- "max_new_tokens": self.max_new_tokens,
- }
- async with aiohttp.ClientSession() as session:
- async with session.post(
- self.url, headers=headers, data=json.dumps(data)
- ) as response:
- raw_response_text = await response.text()
- return json.loads(raw_response_text)
-
- def parse_response(self, raw_response: dict[str, Any]) -> str:
- if "results" not in raw_response:
- raise Exception(
- f"Deep Infra response does not contain 'results' key: {raw_response}"
- )
- if "generated_text" not in raw_response["results"][0]:
- raise Exception(
- f"Deep Infra response does not contain 'generate_text' key: {raw_response['results'][0]}"
- )
- return raw_response["results"][0]["generated_text"]
-
- LLM.__name__ = model_class_name
-
- return LLM
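The rewrite drops DeepInfra's bespoke HTTP client entirely: DeepInfraService now subclasses OpenAIService and overrides only the service name, key name, and _base_url_, relying on DeepInfra's OpenAI-compatible endpoint. A minimal sketch of the same idea using the openai client directly (the model id is illustrative; this shows the underlying pattern, not EDSL's code path):

    import os
    from openai import OpenAI

    client = OpenAI(
        base_url="https://api.deepinfra.com/v1/openai",
        api_key=os.getenv("DEEP_INFRA_API_KEY"),
    )
    resp = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model id
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(resp.choices[0].message.content)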