edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. edsl/Base.py +24 -14
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +28 -6
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
  8. edsl/agents/prompt_helpers.py +129 -0
  9. edsl/config.py +26 -34
  10. edsl/coop/coop.py +14 -4
  11. edsl/data_transfer_models.py +26 -73
  12. edsl/enums.py +2 -0
  13. edsl/inference_services/AnthropicService.py +5 -2
  14. edsl/inference_services/AwsBedrock.py +5 -2
  15. edsl/inference_services/AzureAI.py +5 -2
  16. edsl/inference_services/GoogleService.py +108 -33
  17. edsl/inference_services/InferenceServiceABC.py +44 -13
  18. edsl/inference_services/MistralAIService.py +5 -2
  19. edsl/inference_services/OpenAIService.py +10 -6
  20. edsl/inference_services/TestService.py +34 -16
  21. edsl/inference_services/TogetherAIService.py +170 -0
  22. edsl/inference_services/registry.py +2 -0
  23. edsl/jobs/Jobs.py +109 -18
  24. edsl/jobs/buckets/BucketCollection.py +24 -15
  25. edsl/jobs/buckets/TokenBucket.py +64 -10
  26. edsl/jobs/interviews/Interview.py +130 -49
  27. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  28. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  29. edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
  30. edsl/jobs/runners/JobsRunnerStatus.py +332 -0
  31. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  32. edsl/jobs/tasks/TaskHistory.py +17 -0
  33. edsl/language_models/LanguageModel.py +36 -38
  34. edsl/language_models/registry.py +13 -9
  35. edsl/language_models/utilities.py +5 -2
  36. edsl/questions/QuestionBase.py +74 -16
  37. edsl/questions/QuestionBaseGenMixin.py +28 -0
  38. edsl/questions/QuestionBudget.py +93 -41
  39. edsl/questions/QuestionCheckBox.py +1 -1
  40. edsl/questions/QuestionFreeText.py +6 -0
  41. edsl/questions/QuestionMultipleChoice.py +13 -24
  42. edsl/questions/QuestionNumerical.py +5 -4
  43. edsl/questions/Quick.py +41 -0
  44. edsl/questions/ResponseValidatorABC.py +11 -6
  45. edsl/questions/derived/QuestionLinearScale.py +4 -1
  46. edsl/questions/derived/QuestionTopK.py +4 -1
  47. edsl/questions/derived/QuestionYesNo.py +8 -2
  48. edsl/questions/descriptors.py +12 -11
  49. edsl/questions/templates/budget/__init__.py +0 -0
  50. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  51. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  52. edsl/questions/templates/extract/__init__.py +0 -0
  53. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  54. edsl/questions/templates/rank/__init__.py +0 -0
  55. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  56. edsl/results/DatasetExportMixin.py +5 -1
  57. edsl/results/Result.py +1 -1
  58. edsl/results/Results.py +4 -1
  59. edsl/scenarios/FileStore.py +178 -34
  60. edsl/scenarios/Scenario.py +76 -37
  61. edsl/scenarios/ScenarioList.py +19 -2
  62. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  63. edsl/study/Study.py +32 -0
  64. edsl/surveys/DAG.py +62 -0
  65. edsl/surveys/MemoryPlan.py +26 -0
  66. edsl/surveys/Rule.py +34 -1
  67. edsl/surveys/RuleCollection.py +55 -5
  68. edsl/surveys/Survey.py +189 -10
  69. edsl/surveys/base.py +4 -0
  70. edsl/templates/error_reporting/interview_details.html +6 -1
  71. edsl/utilities/utilities.py +9 -1
  72. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
  73. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
  74. edsl/jobs/interviews/retry_management.py +0 -39
  75. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  76. edsl/scenarios/ScenarioImageMixin.py +0 -100
  77. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  78. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/config.py CHANGED
@@ -1,73 +1,65 @@
  """This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""

  import os
+ from dotenv import load_dotenv, find_dotenv
  from edsl.exceptions import (
      InvalidEnvironmentVariableError,
      MissingEnvironmentVariableError,
  )
- from dotenv import load_dotenv, find_dotenv

  # valid values for EDSL_RUN_MODE
- EDSL_RUN_MODES = ["development", "development-testrun", "production"]
+ EDSL_RUN_MODES = [
+     "development",
+     "development-testrun",
+     "production",
+ ]

  # `default` is used to impute values only in "production" mode
  # `info` gives a brief description of the env var
  CONFIG_MAP = {
      "EDSL_RUN_MODE": {
          "default": "production",
-         "info": "This env var determines the run mode of the application.",
-     },
-     "EDSL_DATABASE_PATH": {
-         "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
-         "info": "This env var determines the path to the cache file.",
-     },
-     "EDSL_LOGGING_PATH": {
-         "default": f"{os.path.join(os.getcwd(), 'interview.log')}",
-         "info": "This env var determines the path to the log file.",
+         "info": "This config var determines the run mode of the application.",
      },
      "EDSL_API_TIMEOUT": {
          "default": "60",
-         "info": "This env var determines the maximum number of seconds to wait for an API call to return.",
+         "info": "This config var determines the maximum number of seconds to wait for an API call to return.",
      },
      "EDSL_BACKOFF_START_SEC": {
          "default": "1",
-         "info": "This env var determines the number of seconds to wait before retrying a failed API call.",
+         "info": "This config var determines the number of seconds to wait before retrying a failed API call.",
      },
-     "EDSL_MAX_BACKOFF_SEC": {
+     "EDSL_BACKOFF_MAX_SEC": {
          "default": "60",
-         "info": "This env var determines the maximum number of seconds to wait before retrying a failed API call.",
+         "info": "This config var determines the maximum number of seconds to wait before retrying a failed API call.",
      },
-     "EDSL_MAX_ATTEMPTS": {
-         "default": "5",
-         "info": "This env var determines the maximum number of times to retry a failed API call.",
+     "EDSL_DATABASE_PATH": {
+         "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
+         "info": "This config var determines the path to the cache file.",
      },
      "EDSL_DEFAULT_MODEL": {
          "default": "gpt-4o",
-         "info": "This env var holds the default model name.",
+         "info": "This config var holds the default model that will be used if a model is not explicitly passed.",
      },
-     "EDSL_SERVICE_TPM_BASELINE": {
-         "default": "2000000",
-         "info": "This env var holds the maximum number of tokens per minute for all models. Model-specific values such as EDSL_SERVICE_TPM_OPENAI will override this.",
+     "EDSL_FETCH_TOKEN_PRICES": {
+         "default": "True",
+         "info": "This config var determines whether to fetch prices for tokens used in remote inference",
+     },
+     "EDSL_MAX_ATTEMPTS": {
+         "default": "5",
+         "info": "This config var determines the maximum number of times to retry a failed API call.",
      },
      "EDSL_SERVICE_RPM_BASELINE": {
          "default": "100",
-         "info": "This env var holds the maximum number of requests per minute for OpenAI. Model-specific values such as EDSL_SERVICE_RPM_OPENAI will override this.",
+         "info": "This config var holds the maximum number of requests per minute. Model-specific values provided in env vars such as EDSL_SERVICE_RPM_OPENAI will override this. value for the corresponding model",
      },
-     "EDSL_SERVICE_TPM_OPENAI": {
+     "EDSL_SERVICE_TPM_BASELINE": {
          "default": "2000000",
-         "info": "This env var holds the maximum number of tokens per minute for OpenAI.",
-     },
-     "EDSL_SERVICE_RPM_OPENAI": {
-         "default": "100",
-         "info": "This env var holds the maximum number of requests per minute for OpenAI.",
-     },
-     "EDSL_FETCH_TOKEN_PRICES": {
-         "default": "True",
-         "info": "Whether to fetch the prices for tokens",
+         "info": "This config var holds the maximum number of tokens per minute for all models. Model-specific values provided in env vars such as EDSL_SERVICE_TPM_OPENAI will override this value for the corresponding model.",
      },
      "EXPECTED_PARROT_URL": {
          "default": "https://www.expectedparrot.com",
-         "info": "This env var holds the URL of the Expected Parrot API.",
+         "info": "This config var holds the URL of the Expected Parrot API.",
      },
  }

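Practical upshot of the hunk above: the max-backoff key is renamed from EDSL_MAX_BACKOFF_SEC to EDSL_BACKOFF_MAX_SEC, the OpenAI-specific TPM/RPM entries are dropped in favor of the baseline keys plus per-service env-var overrides, and EDSL_LOGGING_PATH is removed entirely. A minimal sketch of reading the renamed key (assumes the existing CONFIG.get accessor and production-mode defaults; not part of the diff):

    from edsl.config import CONFIG

    # In "production" run mode, defaults from CONFIG_MAP are imputed.
    CONFIG.get("EDSL_BACKOFF_MAX_SEC")  # "60" by default; note the new spelling
    CONFIG.get("EDSL_DEFAULT_MODEL")    # "gpt-4o" unless set in .env or the environment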
edsl/coop/coop.py CHANGED
@@ -59,8 +59,16 @@ class Coop:
          Send a request to the server and return the response.
          """
          url = f"{self.url}/{uri}"
+         method = method.upper()
+         if payload is None:
+             timeout = 20
+         elif (
+             method.upper() == "POST"
+             and "json_string" in payload
+             and payload.get("json_string") is not None
+         ):
+             timeout = max(20, (len(payload.get("json_string", "")) // (1024 * 1024)))
          try:
-             method = method.upper()
              if method in ["GET", "DELETE"]:
                  response = requests.request(
                      method, url, params=params, headers=self.headers, timeout=timeout
@@ -77,7 +85,7 @@ class Coop:
              else:
                  raise Exception(f"Invalid {method=}.")
          except requests.ConnectionError:
-             raise requests.ConnectionError("Could not connect to the server.")
+             raise requests.ConnectionError(f"Could not connect to the server at {url}.")

          return response

@@ -87,6 +95,7 @@ class Coop:
          """
          if response.status_code >= 400:
              message = response.json().get("detail")
+             # print(response.text)
              if "Authorization" in message:
                  print(message)
                  message = "Please provide an Expected Parrot API key."
@@ -794,8 +803,9 @@ def main():
      ##############
      job = Jobs.example()
      coop.remote_inference_cost(job)
-     results = coop.remote_inference_create(job)
-     coop.remote_inference_get(results.get("uuid"))
+     job_coop_object = coop.remote_inference_create(job)
+     job_coop_results = coop.remote_inference_get(job_coop_object.get("uuid"))
+     coop.get(uuid=job_coop_results.get("results_uuid"))

      ##############
      # E. Errors
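The new branch scales the request timeout with payload size: a 20-second floor plus roughly one second per MiB of json_string on POSTs. The same arithmetic as a standalone check (illustrative values, not the Coop class itself):

    # max(20, size_in_bytes // 1 MiB): small payloads keep the 20 s floor.
    small = "x" * 1024                  # 1 KiB  -> max(20, 0)  == 20
    large = "x" * (50 * 1024 * 1024)    # 50 MiB -> max(20, 50) == 50
    assert max(20, len(small) // (1024 * 1024)) == 20
    assert max(20, len(large) // (1024 * 1024)) == 50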
edsl/data_transfer_models.py CHANGED
@@ -1,4 +1,6 @@
  from typing import NamedTuple, Dict, List, Optional, Any
+ from dataclasses import dataclass, fields
+ import reprlib


  class ModelInputs(NamedTuple):
@@ -45,76 +47,27 @@ class EDSLResultObjectInput(NamedTuple):
      cost: Optional[float] = None


- # from collections import UserDict
-
-
- # class AgentResponseDict(UserDict):
- #     """A dictionary to store the response of the agent to a question."""
-
- #     def __init__(
- #         self,
- #         *,
- #         question_name,
- #         answer,
- #         prompts,
- #         generated_tokens: str,
- #         usage=None,
- #         comment=None,
- #         cached_response=None,
- #         raw_model_response=None,
- #         simple_model_raw_response=None,
- #         cache_used=None,
- #         cache_key=None,
- #     ):
- #         """Initialize the AgentResponseDict object."""
- #         usage = usage or {"prompt_tokens": 0, "completion_tokens": 0}
- #         if generated_tokens is None:
- #             raise ValueError("generated_tokens must be provided")
- #         self.data = {
- #             "answer": answer,
- #             "comment": comment,
- #             "question_name": question_name,
- #             "prompts": prompts,
- #             "usage": usage,
- #             "cached_response": cached_response,
- #             "raw_model_response": raw_model_response,
- #             "simple_model_raw_response": simple_model_raw_response,
- #             "cache_used": cache_used,
- #             "cache_key": cache_key,
- #             "generated_tokens": generated_tokens,
- #         }
-
- #     @property
- #     def data(self):
- #         return self._data
-
- #     @data.setter
- #     def data(self, value):
- #         self._data = value
-
- #     def __getitem__(self, key):
- #         return self.data.get(key, None)
-
- #     def __setitem__(self, key, value):
- #         self.data[key] = value
-
- #     def __delitem__(self, key):
- #         del self.data[key]
-
- #     def __iter__(self):
- #         return iter(self.data)
-
- #     def __len__(self):
- #         return len(self.data)
-
- #     def keys(self):
- #         return self.data.keys()
-
- #     def values(self):
- #         return self.data.values()
-
- #     def items(self):
- #         return self.data.items()
-
- #     def is_this_same_model(self):
- #         return True
+ @dataclass
+ class ImageInfo:
+     file_path: str
+     file_name: str
+     image_format: str
+     file_size: int
+     encoded_image: str
+
+     def __repr__(self):
+         reprlib_instance = reprlib.Repr()
+         reprlib_instance.maxstring = 30  # Limit the string length for the encoded image
+
+         # Get all fields except encoded_image
+         field_reprs = [
+             f"{f.name}={getattr(self, f.name)!r}"
+             for f in fields(self)
+             if f.name != "encoded_image"
+         ]
+
+         # Add the reprlib-restricted encoded_image field
+         field_reprs.append(f"encoded_image={reprlib_instance.repr(self.encoded_image)}")
+
+         # Join everything to create the repr
+         return f"{self.__class__.__name__}({', '.join(field_reprs)})"
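The reprlib-based __repr__ exists so that multi-megabyte base64 payloads do not flood logs: every field prints in full except encoded_image, which is clipped to roughly 30 characters. A usage sketch with made-up values:

    from edsl.data_transfer_models import ImageInfo

    info = ImageInfo(
        file_path="/tmp/cat.png",   # hypothetical example values
        file_name="cat.png",
        image_format="png",
        file_size=123456,
        encoded_image="iVBORw0KGgoAAAANSUhEUg" * 1000,
    )
    print(info)  # ...encoded_image='iVBORw0KGgoA...AAAANSUhEUg' (clipped)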
edsl/enums.py CHANGED
@@ -63,6 +63,7 @@ class InferenceServiceType(EnumWithChecks):
      AZURE = "azure"
      OLLAMA = "ollama"
      MISTRAL = "mistral"
+     TOGETHER = "together"


  service_to_api_keyname = {
@@ -76,6 +77,7 @@ service_to_api_keyname = {
      InferenceServiceType.GROQ.value: "GROQ_API_KEY",
      InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
      InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
+     InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
  }

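With both entries in place, key lookup for the new Together service resolves the same way as for the other providers:

    from edsl.enums import InferenceServiceType, service_to_api_keyname

    keyname = service_to_api_keyname[InferenceServiceType.TOGETHER.value]
    print(keyname)  # "TOGETHER_API_KEY", the env var holding the credential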
edsl/inference_services/AnthropicService.py CHANGED
@@ -1,5 +1,5 @@
  import os
- from typing import Any
+ from typing import Any, Optional, List
  import re
  from anthropic import AsyncAnthropic
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -60,7 +60,10 @@ class AnthropicService(InferenceServiceABC):
          _rpm = cls.get_rpm(cls)

          async def async_execute_model_call(
-             self, user_prompt: str, system_prompt: str = ""
+             self,
+             user_prompt: str,
+             system_prompt: str = "",
+             files_list: Optional[List["Files"]] = None,
          ) -> dict[str, Any]:
              """Calls the OpenAI API and returns the API response."""

edsl/inference_services/AwsBedrock.py CHANGED
@@ -1,5 +1,5 @@
  import os
- from typing import Any
+ from typing import Any, List, Optional
  import re
  import boto3
  from botocore.exceptions import ClientError
@@ -69,7 +69,10 @@ class AwsBedrockService(InferenceServiceABC):
          _tpm = cls.get_tpm(cls)

          async def async_execute_model_call(
-             self, user_prompt: str, system_prompt: str = ""
+             self,
+             user_prompt: str,
+             system_prompt: str = "",
+             files_list: Optional[List["FileStore"]] = None,
          ) -> dict[str, Any]:
              """Calls the AWS Bedrock API and returns the API response."""

edsl/inference_services/AzureAI.py CHANGED
@@ -1,5 +1,5 @@
  import os
- from typing import Any
+ from typing import Any, Optional, List
  import re
  from openai import AsyncAzureOpenAI
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -122,7 +122,10 @@ class AzureAIService(InferenceServiceABC):
          _tpm = cls.get_tpm(cls)

          async def async_execute_model_call(
-             self, user_prompt: str, system_prompt: str = ""
+             self,
+             user_prompt: str,
+             system_prompt: str = "",
+             files_list: Optional[List["FileStore"]] = None,
          ) -> dict[str, Any]:
              """Calls the Azure OpenAI API and returns the API response."""

edsl/inference_services/GoogleService.py CHANGED
@@ -1,25 +1,54 @@
  import os
- import aiohttp
- import json
- from typing import Any
+ from typing import Any, Dict, List, Optional
+ import google
+ import google.generativeai as genai
+ from google.generativeai.types import GenerationConfig
+ from google.api_core.exceptions import InvalidArgument
+
  from edsl.exceptions import MissingAPIKeyError
  from edsl.language_models.LanguageModel import LanguageModel
-
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC

+ safety_settings = [
+     {
+         "category": "HARM_CATEGORY_HARASSMENT",
+         "threshold": "BLOCK_NONE",
+     },
+     {
+         "category": "HARM_CATEGORY_HATE_SPEECH",
+         "threshold": "BLOCK_NONE",
+     },
+     {
+         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+         "threshold": "BLOCK_NONE",
+     },
+     {
+         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+         "threshold": "BLOCK_NONE",
+     },
+ ]
+

  class GoogleService(InferenceServiceABC):
      _inference_service_ = "google"
      key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
-     usage_sequence = ["usageMetadata"]
-     input_token_name = "promptTokenCount"
-     output_token_name = "candidatesTokenCount"
+     usage_sequence = ["usage_metadata"]
+     input_token_name = "prompt_token_count"
+     output_token_name = "candidates_token_count"

      model_exclude_list = []

+     # @classmethod
+     # def available(cls) -> List[str]:
+     #     return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+
      @classmethod
-     def available(cls):
-         return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+     def available(cls) -> List[str]:
+         model_list = []
+         for m in genai.list_models():
+             if "generateContent" in m.supported_generation_methods:
+                 model_list.append(m.name.split("/")[-1])
+         return model_list

      @classmethod
      def create_model(
@@ -47,33 +76,79 @@ class GoogleService(InferenceServiceABC):
              "stopSequences": [],
          }

+         api_token = None
+         model = None
+
+         @classmethod
+         def initialize(cls):
+             if cls.api_token is None:
+                 cls.api_token = os.getenv("GOOGLE_API_KEY")
+                 if not cls.api_token:
+                     raise MissingAPIKeyError(
+                         "GOOGLE_API_KEY environment variable is not set"
+                     )
+                 genai.configure(api_key=cls.api_token)
+                 cls.generative_model = genai.GenerativeModel(
+                     cls._model_, safety_settings=safety_settings
+                 )
+
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+             self.initialize()
+
+         def get_generation_config(self) -> GenerationConfig:
+             return GenerationConfig(
+                 temperature=self.temperature,
+                 top_p=self.topP,
+                 top_k=self.topK,
+                 max_output_tokens=self.maxOutputTokens,
+                 stop_sequences=self.stopSequences,
+             )
+
          async def async_execute_model_call(
-             self, user_prompt: str, system_prompt: str = ""
-         ) -> dict[str, Any]:
-             # self.api_token = os.getenv("GOOGLE_API_KEY")
-             combined_prompt = user_prompt + system_prompt
-             url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent?key={self.api_token}"
-             headers = {"Content-Type": "application/json"}
-             data = {
-                 "contents": [{"parts": [{"text": combined_prompt}]}],
-                 "generationConfig": {
-                     "temperature": self.temperature,
-                     "topK": self.topK,
-                     "topP": self.topP,
-                     "maxOutputTokens": self.maxOutputTokens,
-                     "stopSequences": self.stopSequences,
-                 },
-             }
-             print(combined_prompt)
-             async with aiohttp.ClientSession() as session:
-                 async with session.post(
-                     url, headers=headers, data=json.dumps(data)
-                 ) as response:
-                     raw_response_text = await response.text()
-                     return json.loads(raw_response_text)
+             self,
+             user_prompt: str,
+             system_prompt: str = "",
+             files_list: Optional["Files"] = None,
+         ) -> Dict[str, Any]:
+             generation_config = self.get_generation_config()

-         LLM.__name__ = model_name
+             if files_list is None:
+                 files_list = []
+
+             if (
+                 system_prompt is not None
+                 and system_prompt != ""
+                 and self._model_ != "gemini-pro"
+             ):
+                 try:
+                     self.generative_model = genai.GenerativeModel(
+                         self._model_,
+                         safety_settings=safety_settings,
+                         system_instruction=system_prompt,
+                     )
+                 except InvalidArgument as e:
+                     print(
+                         f"This model, {self._model_}, does not support system_instruction"
+                     )
+                     print("Will add system_prompt to user_prompt")
+                     user_prompt = f"{system_prompt}\n{user_prompt}"

+             combined_prompt = [user_prompt]
+             for file in files_list:
+                 if "google" not in file.external_locations:
+                     _ = file.upload_google()
+                 gen_ai_file = google.generativeai.types.file_types.File(
+                     file.external_locations["google"]
+                 )
+                 combined_prompt.append(gen_ai_file)
+
+             response = await self.generative_model.generate_content_async(
+                 combined_prompt, generation_config=generation_config
+             )
+             return response.to_dict()
+
+         LLM.__name__ = model_name
          return LLM

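The rewrite replaces hand-rolled aiohttp calls to the REST endpoint with the official google.generativeai SDK: usage keys switch to the snake_case names that response.to_dict() emits, available() now discovers models via genai.list_models(), and models that reject system_instruction fall back to prepending the system prompt. A standalone sketch of that fallback, outside edsl's classes (model name and key are placeholders; SDK calls as publicly documented):

    import google.generativeai as genai
    from google.api_core.exceptions import InvalidArgument

    genai.configure(api_key="YOUR_GOOGLE_API_KEY")  # placeholder key
    user_prompt, system_prompt = "What is EDSL?", "Answer in one sentence."
    try:
        model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=system_prompt)
    except InvalidArgument:
        # Model rejects system instructions: fold them into the user prompt.
        model = genai.GenerativeModel("gemini-1.5-flash")
        user_prompt = f"{system_prompt}\n{user_prompt}"
    response = model.generate_content(user_prompt)
    print(response.to_dict()["candidates"][0]["content"]["parts"][0]["text"])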
edsl/inference_services/InferenceServiceABC.py CHANGED
@@ -1,14 +1,27 @@
  from abc import abstractmethod, ABC
- from typing import Any
+ import os
  import re
  from edsl.config import CONFIG


  class InferenceServiceABC(ABC):
-     """Abstract class for inference services."""
+     """
+     Abstract class for inference services.
+     Anthropic: https://docs.anthropic.com/en/api/rate-limits
+     """
+
+     default_levels = {
+         "google": {"tpm": 2_000_000, "rpm": 15},
+         "openai": {"tpm": 2_000_000, "rpm": 10_000},
+         "anthropic": {"tpm": 2_000_000, "rpm": 500},
+     }

-     # check if child class has cls attribute "key_sequence"
      def __init_subclass__(cls):
+         """
+         Check that the subclass has the required attributes.
+         - `key_sequence` attribute determines...
+         - `model_exclude_list` attribute determines...
+         """
          if not hasattr(cls, "key_sequence"):
              raise NotImplementedError(
                  f"Class {cls.__name__} must have a 'key_sequence' attribute."
@@ -18,29 +31,47 @@ class InferenceServiceABC(ABC):
                  f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
              )

-     def get_tpm(cls):
-         key = f"EDSL_SERVICE_TPM_{cls._inference_service_.upper()}"
-         if key not in CONFIG:
-             key = "EDSL_SERVICE_TPM_BASELINE"
-         return int(CONFIG.get(key))
+     @classmethod
+     def _get_limt(cls, limit_type: str) -> int:
+         key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
+         if key in os.environ:
+             return int(os.getenv(key))
+
+         if cls._inference_service_ in cls.default_levels:
+             return int(cls.default_levels[cls._inference_service_][limit_type])
+
+         return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
+
+     def get_tpm(cls) -> int:
+         """
+         Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
+         """
+         return cls._get_limt(limit_type="tpm")

      def get_rpm(cls):
-         key = f"EDSL_SERVICE_RPM_{cls._inference_service_.upper()}"
-         if key not in CONFIG:
-             key = "EDSL_SERVICE_RPM_BASELINE"
-         return int(CONFIG.get(key))
+         """
+         Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
+         """
+         return cls._get_limt(limit_type="rpm")

      @abstractmethod
      def available() -> list[str]:
+         """
+         Returns a list of available models for the service.
+         """
          pass

      @abstractmethod
      def create_model():
+         """
+         Returns a LanguageModel object.
+         """
          pass

      @staticmethod
      def to_class_name(s):
-         """Convert a string to a valid class name.
+         """
+         Converts a string to a valid class name.

          >>> InferenceServiceABC.to_class_name("hello world")
          'HelloWorld'
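Rate limits now resolve in three steps rather than two: a service-specific env var such as EDSL_SERVICE_RPM_OPENAI wins, then the hard-coded default_levels entry, then the EDSL_SERVICE_*_BASELINE config value. The same lookup order as a standalone sketch (baselines inlined as constants; mirrors _get_limt without the CONFIG dependency):

    import os

    DEFAULT_LEVELS = {"openai": {"tpm": 2_000_000, "rpm": 10_000}}
    BASELINES = {"tpm": 2_000_000, "rpm": 100}

    def get_limit(service: str, limit_type: str) -> int:
        env_key = f"EDSL_SERVICE_{limit_type.upper()}_{service.upper()}"
        if env_key in os.environ:          # 1. explicit env-var override
            return int(os.environ[env_key])
        if service in DEFAULT_LEVELS:      # 2. per-service default
            return DEFAULT_LEVELS[service][limit_type]
        return BASELINES[limit_type]       # 3. global baseline

    print(get_limit("openai", "rpm"))   # 10000
    print(get_limit("mistral", "rpm"))  # 100, falls through to the baseline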
edsl/inference_services/MistralAIService.py CHANGED
@@ -1,5 +1,5 @@
  import os
- from typing import Any, List
+ from typing import Any, List, Optional
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
  from edsl.language_models.LanguageModel import LanguageModel
  import asyncio
@@ -95,7 +95,10 @@ class MistralAIService(InferenceServiceABC):
          return cls.async_client()

      async def async_execute_model_call(
-         self, user_prompt: str, system_prompt: str = ""
+         self,
+         user_prompt: str,
+         system_prompt: str = "",
+         files_list: Optional[List["FileStore"]] = None,
      ) -> dict[str, Any]:
          """Calls the Mistral API and returns the API response."""
          s = self.async_client()
edsl/inference_services/OpenAIService.py CHANGED
@@ -168,13 +168,14 @@
              self,
              user_prompt: str,
              system_prompt: str = "",
-             encoded_image=None,
+             files_list: Optional[List["Files"]] = None,
              invigilator: Optional[
                  "InvigilatorAI"
              ] = None,  # TBD - can eventually be used for function-calling
          ) -> dict[str, Any]:
              """Calls the OpenAI API and returns the API response."""
-             if encoded_image:
+             if files_list:
+                 encoded_image = files_list[0].base64_string
                  content = [{"type": "text", "text": user_prompt}]
                  content.append(
                      {
@@ -187,12 +188,15 @@
              else:
                  content = user_prompt
              client = self.async_client()
+             messages = [
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": content},
+             ]
+             if system_prompt == "" and self.omit_system_prompt_if_empty:
+                 messages = messages[1:]
              params = {
                  "model": self.model,
-                 "messages": [
-                     {"role": "system", "content": system_prompt},
-                     {"role": "user", "content": content},
-                 ],
+                 "messages": messages,
                  "temperature": self.temperature,
                  "max_tokens": self.max_tokens,
                  "top_p": self.top_p,