edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (212)
  1. edsl/Base.py +197 -116
  2. edsl/__init__.py +15 -7
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +351 -147
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +101 -50
  7. edsl/agents/InvigilatorBase.py +62 -70
  8. edsl/agents/PromptConstructor.py +143 -225
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  11. edsl/agents/__init__.py +0 -1
  12. edsl/agents/prompt_helpers.py +3 -3
  13. edsl/agents/question_option_processor.py +172 -0
  14. edsl/auto/AutoStudy.py +18 -5
  15. edsl/auto/StageBase.py +53 -40
  16. edsl/auto/StageQuestions.py +2 -1
  17. edsl/auto/utilities.py +0 -6
  18. edsl/config.py +22 -2
  19. edsl/conversation/car_buying.py +2 -1
  20. edsl/coop/CoopFunctionsMixin.py +15 -0
  21. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  22. edsl/coop/PriceFetcher.py +1 -1
  23. edsl/coop/coop.py +125 -47
  24. edsl/coop/utils.py +14 -14
  25. edsl/data/Cache.py +45 -27
  26. edsl/data/CacheEntry.py +12 -15
  27. edsl/data/CacheHandler.py +31 -12
  28. edsl/data/RemoteCacheSync.py +154 -46
  29. edsl/data/__init__.py +4 -3
  30. edsl/data_transfer_models.py +2 -1
  31. edsl/enums.py +27 -0
  32. edsl/exceptions/__init__.py +50 -50
  33. edsl/exceptions/agents.py +12 -0
  34. edsl/exceptions/inference_services.py +5 -0
  35. edsl/exceptions/questions.py +24 -6
  36. edsl/exceptions/scenarios.py +7 -0
  37. edsl/inference_services/AnthropicService.py +38 -19
  38. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  39. edsl/inference_services/AvailableModelFetcher.py +215 -0
  40. edsl/inference_services/AwsBedrock.py +0 -2
  41. edsl/inference_services/AzureAI.py +0 -2
  42. edsl/inference_services/GoogleService.py +7 -12
  43. edsl/inference_services/InferenceServiceABC.py +18 -85
  44. edsl/inference_services/InferenceServicesCollection.py +120 -79
  45. edsl/inference_services/MistralAIService.py +0 -3
  46. edsl/inference_services/OpenAIService.py +47 -35
  47. edsl/inference_services/PerplexityService.py +0 -3
  48. edsl/inference_services/ServiceAvailability.py +135 -0
  49. edsl/inference_services/TestService.py +11 -10
  50. edsl/inference_services/TogetherAIService.py +5 -3
  51. edsl/inference_services/data_structures.py +134 -0
  52. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  53. edsl/jobs/Answers.py +1 -14
  54. edsl/jobs/FetchInvigilator.py +47 -0
  55. edsl/jobs/InterviewTaskManager.py +98 -0
  56. edsl/jobs/InterviewsConstructor.py +50 -0
  57. edsl/jobs/Jobs.py +356 -431
  58. edsl/jobs/JobsChecks.py +35 -10
  59. edsl/jobs/JobsComponentConstructor.py +189 -0
  60. edsl/jobs/JobsPrompts.py +6 -4
  61. edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
  62. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  63. edsl/jobs/RequestTokenEstimator.py +30 -0
  64. edsl/jobs/async_interview_runner.py +138 -0
  65. edsl/jobs/buckets/BucketCollection.py +44 -3
  66. edsl/jobs/buckets/TokenBucket.py +53 -21
  67. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  68. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  69. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  70. edsl/jobs/data_structures.py +120 -0
  71. edsl/jobs/decorators.py +35 -0
  72. edsl/jobs/interviews/Interview.py +143 -408
  73. edsl/jobs/jobs_status_enums.py +9 -0
  74. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  75. edsl/jobs/results_exceptions_handler.py +98 -0
  76. edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
  77. edsl/jobs/runners/JobsRunnerStatus.py +133 -165
  78. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  79. edsl/jobs/tasks/TaskHistory.py +38 -18
  80. edsl/jobs/tasks/task_status_enum.py +0 -2
  81. edsl/language_models/ComputeCost.py +63 -0
  82. edsl/language_models/LanguageModel.py +194 -236
  83. edsl/language_models/ModelList.py +28 -19
  84. edsl/language_models/PriceManager.py +127 -0
  85. edsl/language_models/RawResponseHandler.py +106 -0
  86. edsl/language_models/ServiceDataSources.py +0 -0
  87. edsl/language_models/__init__.py +1 -2
  88. edsl/language_models/key_management/KeyLookup.py +63 -0
  89. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  90. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  91. edsl/language_models/key_management/__init__.py +0 -0
  92. edsl/language_models/key_management/models.py +131 -0
  93. edsl/language_models/model.py +256 -0
  94. edsl/language_models/repair.py +2 -2
  95. edsl/language_models/utilities.py +5 -4
  96. edsl/notebooks/Notebook.py +19 -14
  97. edsl/notebooks/NotebookToLaTeX.py +142 -0
  98. edsl/prompts/Prompt.py +29 -39
  99. edsl/questions/ExceptionExplainer.py +77 -0
  100. edsl/questions/HTMLQuestion.py +103 -0
  101. edsl/questions/QuestionBase.py +68 -214
  102. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  103. edsl/questions/QuestionBudget.py +1 -1
  104. edsl/questions/QuestionCheckBox.py +3 -3
  105. edsl/questions/QuestionExtract.py +5 -7
  106. edsl/questions/QuestionFreeText.py +2 -3
  107. edsl/questions/QuestionList.py +10 -18
  108. edsl/questions/QuestionMatrix.py +265 -0
  109. edsl/questions/QuestionMultipleChoice.py +67 -23
  110. edsl/questions/QuestionNumerical.py +2 -4
  111. edsl/questions/QuestionRank.py +7 -17
  112. edsl/questions/SimpleAskMixin.py +4 -3
  113. edsl/questions/__init__.py +2 -1
  114. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
  115. edsl/questions/data_structures.py +20 -0
  116. edsl/questions/derived/QuestionLinearScale.py +6 -3
  117. edsl/questions/derived/QuestionTopK.py +1 -1
  118. edsl/questions/descriptors.py +17 -3
  119. edsl/questions/loop_processor.py +149 -0
  120. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
  121. edsl/questions/question_registry.py +1 -1
  122. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
  123. edsl/questions/response_validator_factory.py +34 -0
  124. edsl/questions/templates/matrix/__init__.py +1 -0
  125. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  126. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  127. edsl/results/CSSParameterizer.py +1 -1
  128. edsl/results/Dataset.py +170 -7
  129. edsl/results/DatasetExportMixin.py +168 -305
  130. edsl/results/DatasetTree.py +28 -8
  131. edsl/results/MarkdownToDocx.py +122 -0
  132. edsl/results/MarkdownToPDF.py +111 -0
  133. edsl/results/Result.py +298 -206
  134. edsl/results/Results.py +149 -131
  135. edsl/results/ResultsExportMixin.py +2 -0
  136. edsl/results/TableDisplay.py +98 -171
  137. edsl/results/TextEditor.py +50 -0
  138. edsl/results/__init__.py +1 -1
  139. edsl/results/file_exports.py +252 -0
  140. edsl/results/{Selector.py → results_selector.py} +23 -13
  141. edsl/results/smart_objects.py +96 -0
  142. edsl/results/table_data_class.py +12 -0
  143. edsl/results/table_renderers.py +118 -0
  144. edsl/scenarios/ConstructDownloadLink.py +109 -0
  145. edsl/scenarios/DocumentChunker.py +102 -0
  146. edsl/scenarios/DocxScenario.py +16 -0
  147. edsl/scenarios/FileStore.py +150 -239
  148. edsl/scenarios/PdfExtractor.py +40 -0
  149. edsl/scenarios/Scenario.py +90 -193
  150. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  151. edsl/scenarios/ScenarioList.py +415 -244
  152. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  153. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  154. edsl/scenarios/__init__.py +1 -2
  155. edsl/scenarios/directory_scanner.py +96 -0
  156. edsl/scenarios/file_methods.py +85 -0
  157. edsl/scenarios/handlers/__init__.py +13 -0
  158. edsl/scenarios/handlers/csv.py +49 -0
  159. edsl/scenarios/handlers/docx.py +76 -0
  160. edsl/scenarios/handlers/html.py +37 -0
  161. edsl/scenarios/handlers/json.py +111 -0
  162. edsl/scenarios/handlers/latex.py +5 -0
  163. edsl/scenarios/handlers/md.py +51 -0
  164. edsl/scenarios/handlers/pdf.py +68 -0
  165. edsl/scenarios/handlers/png.py +39 -0
  166. edsl/scenarios/handlers/pptx.py +105 -0
  167. edsl/scenarios/handlers/py.py +294 -0
  168. edsl/scenarios/handlers/sql.py +313 -0
  169. edsl/scenarios/handlers/sqlite.py +149 -0
  170. edsl/scenarios/handlers/txt.py +33 -0
  171. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
  172. edsl/scenarios/scenario_selector.py +156 -0
  173. edsl/study/ObjectEntry.py +1 -1
  174. edsl/study/SnapShot.py +1 -1
  175. edsl/study/Study.py +5 -12
  176. edsl/surveys/ConstructDAG.py +92 -0
  177. edsl/surveys/EditSurvey.py +221 -0
  178. edsl/surveys/InstructionHandler.py +100 -0
  179. edsl/surveys/MemoryManagement.py +72 -0
  180. edsl/surveys/Rule.py +5 -4
  181. edsl/surveys/RuleCollection.py +25 -27
  182. edsl/surveys/RuleManager.py +172 -0
  183. edsl/surveys/Simulator.py +75 -0
  184. edsl/surveys/Survey.py +270 -791
  185. edsl/surveys/SurveyCSS.py +20 -8
  186. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  187. edsl/surveys/SurveyToApp.py +141 -0
  188. edsl/surveys/__init__.py +4 -2
  189. edsl/surveys/descriptors.py +6 -2
  190. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  191. edsl/surveys/instructions/Instruction.py +4 -13
  192. edsl/surveys/instructions/InstructionCollection.py +11 -6
  193. edsl/templates/error_reporting/interview_details.html +1 -1
  194. edsl/templates/error_reporting/report.html +1 -1
  195. edsl/tools/plotting.py +1 -1
  196. edsl/utilities/PrettyList.py +56 -0
  197. edsl/utilities/is_notebook.py +18 -0
  198. edsl/utilities/is_valid_variable_name.py +11 -0
  199. edsl/utilities/remove_edsl_version.py +24 -0
  200. edsl/utilities/utilities.py +35 -23
  201. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
  202. edsl-0.1.39.dist-info/RECORD +358 -0
  203. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  204. edsl/language_models/KeyLookup.py +0 -30
  205. edsl/language_models/registry.py +0 -190
  206. edsl/language_models/unused/ReplicateBase.py +0 -83
  207. edsl/results/ResultsDBMixin.py +0 -238
  208. edsl-0.1.38.dev4.dist-info/RECORD +0 -277
  209. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  210. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  211. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  212. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
edsl/inference_services/InferenceServiceABC.py
@@ -1,5 +1,4 @@
 from abc import abstractmethod, ABC
-import os
 import re
 from datetime import datetime, timedelta
 from edsl.config import CONFIG
@@ -8,31 +7,32 @@ from edsl.config import CONFIG
 class InferenceServiceABC(ABC):
     """
     Abstract class for inference services.
-    Anthropic: https://docs.anthropic.com/en/api/rate-limits
     """
 
     _coop_config_vars = None
 
-    default_levels = {
-        "google": {"tpm": 2_000_000, "rpm": 15},
-        "openai": {"tpm": 2_000_000, "rpm": 10_000},
-        "anthropic": {"tpm": 2_000_000, "rpm": 500},
-    }
-
     def __init_subclass__(cls):
         """
         Check that the subclass has the required attributes.
         - `key_sequence` attribute determines...
         - `model_exclude_list` attribute determines...
         """
-        if not hasattr(cls, "key_sequence"):
-            raise NotImplementedError(
-                f"Class {cls.__name__} must have a 'key_sequence' attribute."
-            )
-        if not hasattr(cls, "model_exclude_list"):
-            raise NotImplementedError(
-                f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
-            )
+        must_have_attributes = [
+            "key_sequence",
+            "model_exclude_list",
+            "usage_sequence",
+            "input_token_name",
+            "output_token_name",
+        ]
+        for attr in must_have_attributes:
+            if not hasattr(cls, attr):
+                raise NotImplementedError(
+                    f"Class {cls.__name__} must have a '{attr}' attribute."
+                )
+
+    @property
+    def service_name(self):
+        return self._inference_service_
 
     @classmethod
     def _should_refresh_coop_config_vars(cls):
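The practical effect of the broader must_have_attributes check: a concrete service must now declare not just how to extract the reply text, but also where token usage lives in the provider's raw payload. A minimal sketch of a conforming subclass (all values here are illustrative placeholders, not EDSL's real settings):

# Hypothetical subclass, shown only to illustrate the five attributes
# that __init_subclass__ now enforces at class-definition time.
class MyService(InferenceServiceABC):
    _inference_service_ = "my_service"
    key_sequence = ["choices", 0, "message", "content"]  # path to reply text
    usage_sequence = ["usage"]                           # path to usage block
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"
    model_exclude_list = []

    def available(self) -> list[str]:
        return ["my-model-v1"]

# Omitting any required attribute raises NotImplementedError as soon as
# the class statement executes, rather than failing later at runtime.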
@@ -44,44 +44,6 @@ class InferenceServiceABC(ABC):
             return True
         return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)
 
-    @classmethod
-    def _get_limt(cls, limit_type: str) -> int:
-        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
-        if key in os.environ:
-            return int(os.getenv(key))
-
-        if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
-            try:
-                from edsl import Coop
-
-                c = Coop()
-                cls._coop_config_vars = c.fetch_rate_limit_config_vars()
-                cls._last_config_fetch = datetime.now()
-                if key in cls._coop_config_vars:
-                    return cls._coop_config_vars[key]
-            except Exception:
-                cls._coop_config_vars = None
-        else:
-            if key in cls._coop_config_vars:
-                return cls._coop_config_vars[key]
-
-        if cls._inference_service_ in cls.default_levels:
-            return int(cls.default_levels[cls._inference_service_][limit_type])
-
-        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
-
-    def get_tpm(cls) -> int:
-        """
-        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
-        """
-        return cls._get_limt(limit_type="tpm")
-
-    def get_rpm(cls):
-        """
-        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
-        """
-        return cls._get_limt(limit_type="rpm")
-
     @abstractmethod
     def available() -> list[str]:
         """
@@ -113,35 +75,6 @@ class InferenceServiceABC(ABC):
 
 
 if __name__ == "__main__":
-    pass
-    # deep_infra_service = DeepInfraService("deep_infra", "DEEP_INFRA_API_KEY")
-    # deep_infra_service.available()
-    # m = deep_infra_service.create_model("microsoft/WizardLM-2-7B")
-    # response = m().hello()
-    # print(response)
-
-    # anthropic_service = AnthropicService("anthropic", "ANTHROPIC_API_KEY")
-    # anthropic_service.available()
-    # m = anthropic_service.create_model("claude-3-opus-20240229")
-    # response = m().hello()
-    # print(response)
-    # factory = OpenAIService("openai", "OPENAI_API")
-    # factory.available()
-    # m = factory.create_model("gpt-3.5-turbo")
-    # response = m().hello()
-
-    # from edsl import QuestionFreeText
-    # results = QuestionFreeText.example().by(m()).run()
-
-    # collection = InferenceServicesCollection([
-    #     OpenAIService,
-    #     AnthropicService,
-    #     DeepInfraService
-    # ])
+    import doctest
 
-    # available = collection.available()
-    # factory = collection.create_model_factory(*available[0])
-    # m = factory()
-    # from edsl import QuestionFreeText
-    # results = QuestionFreeText.example().by(m).run()
-    # print(results)
+    doctest.testmod()
edsl/inference_services/InferenceServicesCollection.py
@@ -1,97 +1,138 @@
+from functools import lru_cache
+from collections import defaultdict
+from typing import Optional, Protocol, Dict, List, Tuple, TYPE_CHECKING, Literal
+
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
-import warnings
+from edsl.inference_services.AvailableModelFetcher import AvailableModelFetcher
+from edsl.exceptions.inference_services import InferenceServiceError
+
+if TYPE_CHECKING:
+    from edsl.language_models.LanguageModel import LanguageModel
+    from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+
+
+class ModelCreator(Protocol):
+    def create_model(self, model_name: str) -> "LanguageModel":
+        ...
+
+
+from edsl.enums import InferenceServiceLiteral
+
+
+class ModelResolver:
+    def __init__(
+        self,
+        services: List[InferenceServiceLiteral],
+        models_to_services: Dict[InferenceServiceLiteral, InferenceServiceABC],
+        availability_fetcher: "AvailableModelFetcher",
+    ):
+        """
+        Class for determining which service to use for a given model.
+        """
+        self.services = services
+        self._models_to_services = models_to_services
+        self.availability_fetcher = availability_fetcher
+        self._service_names_to_classes = {
+            service._inference_service_: service for service in services
+        }
+
+    def resolve_model(
+        self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
+    ) -> InferenceServiceABC:
+        """Returns an InferenceServiceABC object for the given model name.
+
+        :param model_name: The name of the model to resolve. E.g., 'gpt-4o'
+        :param service_name: The name of the service to use. E.g., 'openai'
+        :return: An InferenceServiceABC object
+
+        """
+        if model_name == "test":
+            from edsl.inference_services.TestService import TestService
+
+            return TestService()
+
+        if service_name is not None:
+            service: InferenceServiceABC = self._service_names_to_classes.get(
+                service_name
+            )
+            if not service:
+                raise InferenceServiceError(f"Service {service_name} not found")
+            return service
+
+        if model_name in self._models_to_services:  # maybe we've seen it before!
+            return self._models_to_services[model_name]
+
+        for service in self.services:
+            (
+                available_models,
+                service_name,
+            ) = self.availability_fetcher.get_available_models_by_service(service)
+            if model_name in available_models:
+                self._models_to_services[model_name] = service
+                return service
+
+        raise InferenceServiceError(
+            f"""Model {model_name} not found in any services.
+            If you know the service that has this model, use the service_name parameter directly.
+            E.g., Model("gpt-4o", service_name="openai")
+            """
+        )
 
 
 class InferenceServicesCollection:
-    added_models = {}
+    added_models = defaultdict(list)  # Moved back to class level
 
-    def __init__(self, services: list[InferenceServiceABC] = None):
+    def __init__(self, services: Optional[List[InferenceServiceABC]] = None):
         self.services = services or []
+        self._models_to_services: Dict[str, InferenceServiceABC] = {}
+
+        self.availability_fetcher = AvailableModelFetcher(
+            self.services, self.added_models
+        )
+        self.resolver = ModelResolver(
+            self.services, self._models_to_services, self.availability_fetcher
+        )
 
     @classmethod
-    def add_model(cls, service_name, model_name):
+    def add_model(cls, service_name: str, model_name: str) -> None:
         if service_name not in cls.added_models:
-            cls.added_models[service_name] = []
-        cls.added_models[service_name].append(model_name)
-
-    @staticmethod
-    def _get_service_available(service, warn: bool = False) -> list[str]:
-        try:
-            service_models = service.available()
-        except Exception:
-            if warn:
-                warnings.warn(
-                    f"""Error getting models for {service._inference_service_}.
-                    Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
-                    See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
-                    Relying on Coop.""",
-                    UserWarning,
-                )
-
-            # Use the list of models on Coop as a fallback
-            try:
-                from edsl import Coop
-
-                c = Coop()
-                models_from_coop = c.fetch_models()
-                service_models = models_from_coop.get(service._inference_service_, [])
-
-                # cache results
-                service._models_list_cache = service_models
-
-            # Finally, use the available models cache from the Python file
-            except Exception:
-                if warn:
-                    warnings.warn(
-                        f"""Error getting models for {service._inference_service_}.
-                        Relying on EDSL cache.""",
-                        UserWarning,
-                    )
-
-                from edsl.inference_services.models_available_cache import (
-                    models_available,
-                )
-
-                service_models = models_available.get(service._inference_service_, [])
-
-                # cache results
-                service._models_list_cache = service_models
-
-        return service_models
-
-    def available(self):
-        total_models = []
-        for service in self.services:
-            service_models = self._get_service_available(service)
-            for model in service_models:
-                total_models.append([model, service._inference_service_, -1])
+            cls.added_models[service_name].append(model_name)
 
-            for model in self.added_models.get(service._inference_service_, []):
-                total_models.append([model, service._inference_service_, -1])
+    def service_names_to_classes(self) -> Dict[str, InferenceServiceABC]:
+        return {service._inference_service_: service for service in self.services}
 
-        sorted_models = sorted(total_models)
-        for i, model in enumerate(sorted_models):
-            model[2] = i
-            model = tuple(model)
-        return sorted_models
+    def available(
+        self,
+        service: Optional[str] = None,
+    ) -> List[Tuple[str, str, int]]:
+        return self.availability_fetcher.available(service)
 
-    def register(self, service):
+    def reset_cache(self) -> None:
+        self.availability_fetcher.reset_cache()
+
+    @property
+    def num_cache_entries(self) -> int:
+        return self.availability_fetcher.num_cache_entries
+
+    def register(self, service: InferenceServiceABC) -> None:
         self.services.append(service)
 
-    def create_model_factory(self, model_name: str, service_name=None, index=None):
-        from edsl.inference_services.TestService import TestService
+    def create_model_factory(
+        self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
+    ) -> "LanguageModel":
 
-        if model_name == "test":
-            return TestService.create_model(model_name)
+        if service_name is None:  # we try to find the right service
+            service = self.resolver.resolve_model(model_name, service_name)
+        else:  # if they passed a service, we'll use that
+            service = self.service_names_to_classes().get(service_name)
 
-        if service_name:
-            for service in self.services:
-                if service_name == service._inference_service_:
-                    return service.create_model(model_name)
+        if not service:  # but if we can't find it, we'll raise an error
+            raise InferenceServiceError(f"Service {service_name} not found")
 
-        for service in self.services:
-            if model_name in self._get_service_available(service):
-                if service_name is None or service_name == service._inference_service_:
-                    return service.create_model(model_name)
+        return service.create_model(model_name)
+
+
+if __name__ == "__main__":
+    import doctest
 
-    raise Exception(f"Model {model_name} not found in any of the services")
+    doctest.testmod()
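The resolution order in the rewritten create_model_factory: a model named "test" short-circuits to TestService, an explicit service_name is looked up by name, and otherwise ModelResolver consults its memo of previously resolved models before querying each service's available-model list. A rough usage sketch under those assumptions (service wiring is illustrative; EDSL's default collection registers the services for you):

# Sketch only; assumes an instantiated service is registered.
collection = InferenceServicesCollection(services=[OpenAIService()])

# Explicit service: resolved by name, no model listing required.
model_class = collection.create_model_factory("gpt-4o", service_name="openai")

# No service given: ModelResolver scans each service's available models
# and memoizes the match, so a second lookup of "gpt-4o" is a dict hit.
model_class = collection.create_model_factory("gpt-4o")

# Unknown models or services now raise InferenceServiceError instead of
# the bare Exception the old implementation used.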
edsl/inference_services/MistralAIService.py
@@ -85,9 +85,6 @@ class MistralAIService(InferenceServiceABC):
                 "top_p": 0.9,
             }
 
-            _tpm = cls.get_tpm(cls)
-            _rpm = cls.get_rpm(cls)
-
             def sync_client(self):
                 return cls.sync_client()
 
edsl/inference_services/OpenAIService.py
@@ -1,16 +1,19 @@
 from __future__ import annotations
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Dict, NewType
 import os
 
+
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
-from edsl.language_models import LanguageModel
+from edsl.language_models.LanguageModel import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
 from edsl.utilities.utilities import fix_partial_correct_response
 
 from edsl.config import CONFIG
 
+APIToken = NewType("APIToken", str)
+
 
 class OpenAIService(InferenceServiceABC):
     """OpenAI service class."""
@@ -22,35 +25,43 @@ class OpenAIService(InferenceServiceABC):
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI
 
-    _sync_client_instance = None
-    _async_client_instance = None
+    _sync_client_instances: Dict[APIToken, openai.OpenAI] = {}
+    _async_client_instances: Dict[APIToken, openai.AsyncOpenAI] = {}
 
     key_sequence = ["choices", 0, "message", "content"]
     usage_sequence = ["usage"]
     input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"
 
+    available_models_url = "https://platform.openai.com/docs/models/gp"
+
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
-        # so subclasses have to create their own instances of the clients
-        cls._sync_client_instance = None
-        cls._async_client_instance = None
+        # so subclasses that use the OpenAI api key have to create their own instances of the clients
+        cls._sync_client_instances = {}
+        cls._async_client_instances = {}
 
     @classmethod
-    def sync_client(cls):
-        if cls._sync_client_instance is None:
-            cls._sync_client_instance = cls._sync_client_(
-                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+    def sync_client(cls, api_key):
+        if api_key not in cls._sync_client_instances:
+            client = cls._sync_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
             )
-        return cls._sync_client_instance
+            cls._sync_client_instances[api_key] = client
+        client = cls._sync_client_instances[api_key]
+        return client
 
     @classmethod
-    def async_client(cls):
-        if cls._async_client_instance is None:
-            cls._async_client_instance = cls._async_client_(
-                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+    def async_client(cls, api_key):
+        if api_key not in cls._async_client_instances:
+            client = cls._async_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
             )
-        return cls._async_client_instance
+            cls._async_client_instances[api_key] = client
+        client = cls._async_client_instances[api_key]
+        return client
 
     model_exclude_list = [
         "whisper-1",
@@ -72,20 +83,24 @@ class OpenAIService(InferenceServiceABC):
     _models_list_cache: List[str] = []
 
     @classmethod
-    def get_model_list(cls):
-        raw_list = cls.sync_client().models.list()
+    def get_model_list(cls, api_key=None):
+        if api_key is None:
+            api_key = os.getenv(cls._env_key_name_)
+        raw_list = cls.sync_client(api_key).models.list()
         if hasattr(raw_list, "data"):
             return raw_list.data
         else:
             return raw_list
 
     @classmethod
-    def available(cls) -> List[str]:
+    def available(cls, api_token=None) -> List[str]:
+        if api_token is None:
+            api_token = os.getenv(cls._env_key_name_)
         if not cls._models_list_cache:
             try:
                 cls._models_list_cache = [
                     m.id
-                    for m in cls.get_model_list()
+                    for m in cls.get_model_list(api_key=api_token)
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
@@ -107,9 +122,6 @@ class OpenAIService(InferenceServiceABC):
             input_token_name = cls.input_token_name
             output_token_name = cls.output_token_name
 
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
-
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -123,10 +135,10 @@ class OpenAIService(InferenceServiceABC):
             }
 
             def sync_client(self):
-                return cls.sync_client()
+                return cls.sync_client(api_key=self.api_token)
 
             def async_client(self):
-                return cls.async_client()
+                return cls.async_client(api_key=self.api_token)
 
             @classmethod
             def available(cls) -> list[str]:
@@ -175,16 +187,16 @@ class OpenAIService(InferenceServiceABC):
             ) -> dict[str, Any]:
                 """Calls the OpenAI API and returns the API response."""
                 if files_list:
-                    encoded_image = files_list[0].base64_string
                     content = [{"type": "text", "text": user_prompt}]
-                    content.append(
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{encoded_image}"
-                            },
-                        }
-                    )
+                    for file_entry in files_list:
+                        content.append(
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": f"data:{file_entry.mime_type};base64,{file_entry.base64_string}"
+                                },
+                            }
+                        )
                 else:
                     content = user_prompt
                 client = self.async_client()
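The rewritten block attaches every file in files_list, each with its own MIME type, where the old code sent only the first file and hard-coded image/jpeg. For two attached images, the resulting content payload would look roughly like this (values illustrative, base64 strings truncated):

content = [
    {"type": "text", "text": "Describe these images."},
    {
        "type": "image_url",
        "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
    },
    {
        "type": "image_url",
        "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRg..."},
    },
]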
edsl/inference_services/PerplexityService.py
@@ -51,9 +51,6 @@ class PerplexityService(OpenAIService):
             input_token_name = cls.input_token_name
             output_token_name = cls.output_token_name
 
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
-
             _inference_service_ = cls._inference_service_
             _model_ = model_name
 
edsl/inference_services/ServiceAvailability.py
@@ -0,0 +1,135 @@
+from enum import Enum
+from typing import List, Optional, TYPE_CHECKING
+from functools import partial
+import warnings
+
+from edsl.inference_services.data_structures import AvailableModels, ModelNamesList
+
+if TYPE_CHECKING:
+    from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+
+
+class ModelSource(Enum):
+    LOCAL = "local"
+    COOP = "coop"
+    CACHE = "cache"
+
+
+class ServiceAvailability:
+    """This class is responsible for fetching the available models from different sources."""
+
+    _coop_model_list = None
+
+    def __init__(self, source_order: Optional[List[ModelSource]] = None):
+        """
+        Initialize with custom source order.
+        Default order is LOCAL -> COOP -> CACHE
+        """
+        self.source_order = source_order or [
+            ModelSource.LOCAL,
+            ModelSource.COOP,
+            ModelSource.CACHE,
+        ]
+
+        # Map sources to their fetch functions
+        self._source_fetchers = {
+            ModelSource.LOCAL: self._fetch_from_local_service,
+            ModelSource.COOP: self._fetch_from_coop,
+            ModelSource.CACHE: self._fetch_from_cache,
+        }
+
+    @classmethod
+    def models_from_coop(cls) -> AvailableModels:
+        if not cls._coop_model_list:
+            from edsl.coop.coop import Coop
+
+            c = Coop()
+            coop_model_list = c.fetch_models()
+            cls._coop_model_list = coop_model_list
+        return cls._coop_model_list
+
+    def get_service_available(
+        self, service: "InferenceServiceABC", warn: bool = False
+    ) -> ModelNamesList:
+        """
+        Try to fetch available models from sources in specified order.
+        Returns first successful result.
+        """
+        last_error = None
+
+        for source in self.source_order:
+            try:
+                fetch_func = partial(self._source_fetchers[source], service)
+                result = fetch_func()
+
+                # Cache successful result
+                service._models_list_cache = result
+                return result
+
+            except Exception as e:
+                last_error = e
+                if warn:
+                    self._warn_source_failed(service, source)
+                continue
+
+        # If we get here, all sources failed
+        raise RuntimeError(
+            f"All sources failed to fetch models. Last error: {last_error}"
+        )
+
+    @staticmethod
+    def _fetch_from_local_service(service: "InferenceServiceABC") -> ModelNamesList:
+        """Attempt to fetch models directly from the service."""
+        return service.available()
+
+    @classmethod
+    def _fetch_from_coop(cls, service: "InferenceServiceABC") -> ModelNamesList:
+        """Fetch models from Coop."""
+        models_from_coop = cls.models_from_coop()
+        return models_from_coop.get(service._inference_service_, [])
+
+    @staticmethod
+    def _fetch_from_cache(service: "InferenceServiceABC") -> ModelNamesList:
+        """Fetch models from local cache."""
+        from edsl.inference_services.models_available_cache import models_available
+
+        return models_available.get(service._inference_service_, [])
+
+    def _warn_source_failed(self, service: "InferenceServiceABC", source: ModelSource):
+        """Display appropriate warning message based on failed source."""
+        messages = {
+            ModelSource.LOCAL: f"""Error getting models for {service._inference_service_}.
+            Check that you have properly stored your Expected Parrot API key and activated remote inference,
+            or stored your own API keys for the language models that you want to use.
+            See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
+            Trying next source.""",
+            ModelSource.COOP: f"Error getting models from Coop for {service._inference_service_}. Trying next source.",
+            ModelSource.CACHE: f"Error getting models from cache for {service._inference_service_}.",
+        }
+        warnings.warn(messages[source], UserWarning)
+
+
+if __name__ == "__main__":
+    # sa = ServiceAvailability()
+    # models_from_coop = sa.models_from_coop()
+    # print(models_from_coop)
+    from edsl.inference_services.OpenAIService import OpenAIService
+
+    openai_models = ServiceAvailability._fetch_from_local_service(OpenAIService())
+    print(openai_models)
+
+    # Example usage:
+    """
+    # Default order (LOCAL -> COOP -> CACHE)
+    availability = ServiceAvailability()
+
+    # Custom order (COOP -> LOCAL -> CACHE)
+    availability_coop_first = ServiceAvailability([
+        ModelSource.COOP,
+        ModelSource.LOCAL,
+        ModelSource.CACHE
+    ])
+
+    # Get available models using custom order
+    models = availability_coop_first.get_service_available(service, warn=True)
+    """