edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. edsl/Base.py +169 -116
  2. edsl/__init__.py +14 -6
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +358 -146
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +88 -36
  7. edsl/agents/InvigilatorBase.py +59 -70
  8. edsl/agents/PromptConstructor.py +117 -219
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionOptionProcessor.py +172 -0
  11. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  12. edsl/agents/__init__.py +0 -1
  13. edsl/agents/prompt_helpers.py +3 -3
  14. edsl/config.py +22 -2
  15. edsl/conversation/car_buying.py +2 -1
  16. edsl/coop/CoopFunctionsMixin.py +15 -0
  17. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  18. edsl/coop/PriceFetcher.py +1 -1
  19. edsl/coop/coop.py +104 -42
  20. edsl/coop/utils.py +14 -14
  21. edsl/data/Cache.py +21 -14
  22. edsl/data/CacheEntry.py +12 -15
  23. edsl/data/CacheHandler.py +33 -12
  24. edsl/data/__init__.py +4 -3
  25. edsl/data_transfer_models.py +2 -1
  26. edsl/enums.py +20 -0
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +12 -0
  29. edsl/exceptions/inference_services.py +5 -0
  30. edsl/exceptions/questions.py +24 -6
  31. edsl/exceptions/scenarios.py +7 -0
  32. edsl/inference_services/AnthropicService.py +0 -3
  33. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  34. edsl/inference_services/AvailableModelFetcher.py +209 -0
  35. edsl/inference_services/AwsBedrock.py +0 -2
  36. edsl/inference_services/AzureAI.py +0 -2
  37. edsl/inference_services/GoogleService.py +2 -11
  38. edsl/inference_services/InferenceServiceABC.py +18 -85
  39. edsl/inference_services/InferenceServicesCollection.py +105 -80
  40. edsl/inference_services/MistralAIService.py +0 -3
  41. edsl/inference_services/OpenAIService.py +1 -4
  42. edsl/inference_services/PerplexityService.py +0 -3
  43. edsl/inference_services/ServiceAvailability.py +135 -0
  44. edsl/inference_services/TestService.py +11 -8
  45. edsl/inference_services/data_structures.py +62 -0
  46. edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
  47. edsl/jobs/Answers.py +1 -14
  48. edsl/jobs/FetchInvigilator.py +40 -0
  49. edsl/jobs/InterviewTaskManager.py +98 -0
  50. edsl/jobs/InterviewsConstructor.py +48 -0
  51. edsl/jobs/Jobs.py +102 -243
  52. edsl/jobs/JobsChecks.py +35 -10
  53. edsl/jobs/JobsComponentConstructor.py +189 -0
  54. edsl/jobs/JobsPrompts.py +5 -3
  55. edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
  56. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  57. edsl/jobs/RequestTokenEstimator.py +30 -0
  58. edsl/jobs/buckets/BucketCollection.py +44 -3
  59. edsl/jobs/buckets/TokenBucket.py +53 -21
  60. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  61. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  62. edsl/jobs/decorators.py +35 -0
  63. edsl/jobs/interviews/Interview.py +77 -380
  64. edsl/jobs/jobs_status_enums.py +9 -0
  65. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  66. edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
  67. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  68. edsl/jobs/tasks/TaskHistory.py +14 -15
  69. edsl/jobs/tasks/task_status_enum.py +0 -2
  70. edsl/language_models/ComputeCost.py +63 -0
  71. edsl/language_models/LanguageModel.py +137 -234
  72. edsl/language_models/ModelList.py +11 -13
  73. edsl/language_models/PriceManager.py +127 -0
  74. edsl/language_models/RawResponseHandler.py +106 -0
  75. edsl/language_models/ServiceDataSources.py +0 -0
  76. edsl/language_models/__init__.py +0 -1
  77. edsl/language_models/key_management/KeyLookup.py +63 -0
  78. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  79. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  80. edsl/language_models/key_management/__init__.py +0 -0
  81. edsl/language_models/key_management/models.py +131 -0
  82. edsl/language_models/registry.py +49 -59
  83. edsl/language_models/repair.py +2 -2
  84. edsl/language_models/utilities.py +5 -4
  85. edsl/notebooks/Notebook.py +19 -14
  86. edsl/notebooks/NotebookToLaTeX.py +142 -0
  87. edsl/prompts/Prompt.py +29 -39
  88. edsl/questions/AnswerValidatorMixin.py +47 -2
  89. edsl/questions/ExceptionExplainer.py +77 -0
  90. edsl/questions/HTMLQuestion.py +103 -0
  91. edsl/questions/LoopProcessor.py +149 -0
  92. edsl/questions/QuestionBase.py +37 -192
  93. edsl/questions/QuestionBaseGenMixin.py +52 -48
  94. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  95. edsl/questions/QuestionCheckBox.py +1 -1
  96. edsl/questions/QuestionExtract.py +1 -1
  97. edsl/questions/QuestionFreeText.py +1 -2
  98. edsl/questions/QuestionList.py +3 -5
  99. edsl/questions/QuestionMatrix.py +265 -0
  100. edsl/questions/QuestionMultipleChoice.py +66 -22
  101. edsl/questions/QuestionNumerical.py +1 -3
  102. edsl/questions/QuestionRank.py +6 -16
  103. edsl/questions/ResponseValidatorABC.py +37 -11
  104. edsl/questions/ResponseValidatorFactory.py +28 -0
  105. edsl/questions/SimpleAskMixin.py +4 -3
  106. edsl/questions/__init__.py +1 -0
  107. edsl/questions/derived/QuestionLinearScale.py +6 -3
  108. edsl/questions/derived/QuestionTopK.py +1 -1
  109. edsl/questions/descriptors.py +17 -3
  110. edsl/questions/question_registry.py +1 -1
  111. edsl/questions/templates/matrix/__init__.py +1 -0
  112. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  113. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  114. edsl/results/CSSParameterizer.py +1 -1
  115. edsl/results/Dataset.py +170 -7
  116. edsl/results/DatasetExportMixin.py +224 -302
  117. edsl/results/DatasetTree.py +28 -8
  118. edsl/results/MarkdownToDocx.py +122 -0
  119. edsl/results/MarkdownToPDF.py +111 -0
  120. edsl/results/Result.py +192 -206
  121. edsl/results/Results.py +120 -113
  122. edsl/results/ResultsExportMixin.py +2 -0
  123. edsl/results/Selector.py +23 -13
  124. edsl/results/TableDisplay.py +98 -171
  125. edsl/results/TextEditor.py +50 -0
  126. edsl/results/__init__.py +1 -1
  127. edsl/results/smart_objects.py +96 -0
  128. edsl/results/table_data_class.py +12 -0
  129. edsl/results/table_renderers.py +118 -0
  130. edsl/scenarios/ConstructDownloadLink.py +109 -0
  131. edsl/scenarios/DirectoryScanner.py +96 -0
  132. edsl/scenarios/DocumentChunker.py +102 -0
  133. edsl/scenarios/DocxScenario.py +16 -0
  134. edsl/scenarios/FileStore.py +118 -239
  135. edsl/scenarios/PdfExtractor.py +40 -0
  136. edsl/scenarios/Scenario.py +90 -193
  137. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  138. edsl/scenarios/ScenarioJoin.py +10 -6
  139. edsl/scenarios/ScenarioList.py +383 -240
  140. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  141. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  142. edsl/scenarios/ScenarioSelector.py +156 -0
  143. edsl/scenarios/__init__.py +1 -2
  144. edsl/scenarios/file_methods.py +85 -0
  145. edsl/scenarios/handlers/__init__.py +13 -0
  146. edsl/scenarios/handlers/csv.py +38 -0
  147. edsl/scenarios/handlers/docx.py +76 -0
  148. edsl/scenarios/handlers/html.py +37 -0
  149. edsl/scenarios/handlers/json.py +111 -0
  150. edsl/scenarios/handlers/latex.py +5 -0
  151. edsl/scenarios/handlers/md.py +51 -0
  152. edsl/scenarios/handlers/pdf.py +68 -0
  153. edsl/scenarios/handlers/png.py +39 -0
  154. edsl/scenarios/handlers/pptx.py +105 -0
  155. edsl/scenarios/handlers/py.py +294 -0
  156. edsl/scenarios/handlers/sql.py +313 -0
  157. edsl/scenarios/handlers/sqlite.py +149 -0
  158. edsl/scenarios/handlers/txt.py +33 -0
  159. edsl/study/ObjectEntry.py +1 -1
  160. edsl/study/SnapShot.py +1 -1
  161. edsl/study/Study.py +5 -12
  162. edsl/surveys/ConstructDAG.py +92 -0
  163. edsl/surveys/EditSurvey.py +221 -0
  164. edsl/surveys/InstructionHandler.py +100 -0
  165. edsl/surveys/MemoryManagement.py +72 -0
  166. edsl/surveys/Rule.py +5 -4
  167. edsl/surveys/RuleCollection.py +25 -27
  168. edsl/surveys/RuleManager.py +172 -0
  169. edsl/surveys/Simulator.py +75 -0
  170. edsl/surveys/Survey.py +199 -771
  171. edsl/surveys/SurveyCSS.py +20 -8
  172. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  173. edsl/surveys/SurveyToApp.py +141 -0
  174. edsl/surveys/__init__.py +4 -2
  175. edsl/surveys/descriptors.py +6 -2
  176. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  177. edsl/surveys/instructions/Instruction.py +4 -13
  178. edsl/surveys/instructions/InstructionCollection.py +11 -6
  179. edsl/templates/error_reporting/interview_details.html +1 -1
  180. edsl/templates/error_reporting/report.html +1 -1
  181. edsl/tools/plotting.py +1 -1
  182. edsl/utilities/PrettyList.py +56 -0
  183. edsl/utilities/is_notebook.py +18 -0
  184. edsl/utilities/is_valid_variable_name.py +11 -0
  185. edsl/utilities/remove_edsl_version.py +24 -0
  186. edsl/utilities/utilities.py +35 -23
  187. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
  188. edsl-0.1.39.dev2.dist-info/RECORD +352 -0
  189. edsl/language_models/KeyLookup.py +0 -30
  190. edsl/language_models/unused/ReplicateBase.py +0 -83
  191. edsl/results/ResultsDBMixin.py +0 -238
  192. edsl-0.1.39.dev1.dist-info/RECORD +0 -277
  193. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
  194. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
edsl/inference_services/InferenceServiceABC.py
@@ -1,5 +1,4 @@
 from abc import abstractmethod, ABC
-import os
 import re
 from datetime import datetime, timedelta
 from edsl.config import CONFIG
@@ -8,31 +7,32 @@ from edsl.config import CONFIG
 class InferenceServiceABC(ABC):
     """
     Abstract class for inference services.
-    Anthropic: https://docs.anthropic.com/en/api/rate-limits
     """
 
     _coop_config_vars = None
 
-    default_levels = {
-        "google": {"tpm": 2_000_000, "rpm": 15},
-        "openai": {"tpm": 2_000_000, "rpm": 10_000},
-        "anthropic": {"tpm": 2_000_000, "rpm": 500},
-    }
-
     def __init_subclass__(cls):
         """
         Check that the subclass has the required attributes.
         - `key_sequence` attribute determines...
         - `model_exclude_list` attribute determines...
         """
-        if not hasattr(cls, "key_sequence"):
-            raise NotImplementedError(
-                f"Class {cls.__name__} must have a 'key_sequence' attribute."
-            )
-        if not hasattr(cls, "model_exclude_list"):
-            raise NotImplementedError(
-                f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
-            )
+        must_have_attributes = [
+            "key_sequence",
+            "model_exclude_list",
+            "usage_sequence",
+            "input_token_name",
+            "output_token_name",
+        ]
+        for attr in must_have_attributes:
+            if not hasattr(cls, attr):
+                raise NotImplementedError(
+                    f"Class {cls.__name__} must have a '{attr}' attribute."
+                )
+
+    @property
+    def service_name(self):
+        return self._inference_service_
 
     @classmethod
     def _should_refresh_coop_config_vars(cls):
@@ -44,44 +44,6 @@ class InferenceServiceABC(ABC):
             return True
         return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)
 
-    @classmethod
-    def _get_limt(cls, limit_type: str) -> int:
-        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
-        if key in os.environ:
-            return int(os.getenv(key))
-
-        if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
-            try:
-                from edsl import Coop
-
-                c = Coop()
-                cls._coop_config_vars = c.fetch_rate_limit_config_vars()
-                cls._last_config_fetch = datetime.now()
-                if key in cls._coop_config_vars:
-                    return cls._coop_config_vars[key]
-            except Exception:
-                cls._coop_config_vars = None
-        else:
-            if key in cls._coop_config_vars:
-                return cls._coop_config_vars[key]
-
-        if cls._inference_service_ in cls.default_levels:
-            return int(cls.default_levels[cls._inference_service_][limit_type])
-
-        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
-
-    def get_tpm(cls) -> int:
-        """
-        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
-        """
-        return cls._get_limt(limit_type="tpm")
-
-    def get_rpm(cls):
-        """
-        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
-        """
-        return cls._get_limt(limit_type="rpm")
-
     @abstractmethod
     def available() -> list[str]:
         """
@@ -113,35 +75,6 @@ class InferenceServiceABC(ABC):
 
 
 if __name__ == "__main__":
-    pass
-    # deep_infra_service = DeepInfraService("deep_infra", "DEEP_INFRA_API_KEY")
-    # deep_infra_service.available()
-    # m = deep_infra_service.create_model("microsoft/WizardLM-2-7B")
-    # response = m().hello()
-    # print(response)
-
-    # anthropic_service = AnthropicService("anthropic", "ANTHROPIC_API_KEY")
-    # anthropic_service.available()
-    # m = anthropic_service.create_model("claude-3-opus-20240229")
-    # response = m().hello()
-    # print(response)
-    # factory = OpenAIService("openai", "OPENAI_API")
-    # factory.available()
-    # m = factory.create_model("gpt-3.5-turbo")
-    # response = m().hello()
-
-    # from edsl import QuestionFreeText
-    # results = QuestionFreeText.example().by(m()).run()
-
-    # collection = InferenceServicesCollection([
-    #     OpenAIService,
-    #     AnthropicService,
-    #     DeepInfraService
-    # ])
+    import doctest
 
-    # available = collection.available()
-    # factory = collection.create_model_factory(*available[0])
-    # m = factory()
-    # from edsl import QuestionFreeText
-    # results = QuestionFreeText.example().by(m).run()
-    # print(results)
+    doctest.testmod()
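
Taken together, the InferenceServiceABC changes move rate-limit lookup (get_tpm/get_rpm and the default_levels table) out of the ABC and widen the subclass contract from two required class attributes to five. A minimal sketch of a subclass that passes the new __init_subclass__ check; every attribute value below is illustrative rather than copied from a real service:

from edsl.inference_services.InferenceServiceABC import InferenceServiceABC


class MyService(InferenceServiceABC):
    _inference_service_ = "my_service"  # surfaced by the new service_name property
    key_sequence = ["choices", 0, "message", "content"]  # path to text in a raw response
    usage_sequence = ["usage"]  # path to the usage block in a raw response
    model_exclude_list = []
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"

    def available(self) -> list[str]:
        return ["my-model-v1"]

    def create_model(self, model_name: str):
        raise NotImplementedError  # a real service returns a LanguageModel subclass

Omitting any of the five listed attributes now raises NotImplementedError at class-definition time rather than failing later at call time.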
edsl/inference_services/InferenceServicesCollection.py
@@ -1,97 +1,122 @@
+from functools import lru_cache
+from collections import defaultdict
+from typing import Optional, Protocol, Dict, List, Tuple, TYPE_CHECKING, Literal
+
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
-import warnings
+from edsl.inference_services.AvailableModelFetcher import AvailableModelFetcher
+from edsl.exceptions.inference_services import InferenceServiceError
+
+if TYPE_CHECKING:
+    from edsl.language_models.LanguageModel import LanguageModel
+    from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+
+
+class ModelCreator(Protocol):
+    def create_model(self, model_name: str) -> "LanguageModel":
+        ...
+
+
+from edsl.enums import InferenceServiceLiteral
+
+
+class ModelResolver:
+    def __init__(
+        self,
+        services: List[InferenceServiceLiteral],
+        models_to_services: Dict[InferenceServiceLiteral, InferenceServiceABC],
+        availability_fetcher: "AvailableModelFetcher",
+    ):
+        """
+        Class for determining which service to use for a given model.
+        """
+        self.services = services
+        self._models_to_services = models_to_services
+        self.availability_fetcher = availability_fetcher
+        self._service_names_to_classes = {
+            service._inference_service_: service for service in services
+        }
+
+    def resolve_model(
+        self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
+    ) -> InferenceServiceABC:
+        """Returns an InferenceServiceABC object for the given model name.
+
+        :param model_name: The name of the model to resolve. E.g., 'gpt-4o'
+        :param service_name: The name of the service to use. E.g., 'openai'
+        :return: An InferenceServiceABC object
+
+        """
+        if model_name == "test":
+            from edsl.inference_services.TestService import TestService
+
+            return TestService()
+
+        if service_name is not None:
+            service: InferenceServiceABC = self._service_names_to_classes.get(
+                service_name
+            )
+            if not service:
+                raise InferenceServiceError(f"Service {service_name} not found")
+            return service
+
+        if model_name in self._models_to_services:  # maybe we've seen it before!
+            return self._models_to_services[model_name]
+
+        for service in self.services:
+            (
+                available_models,
+                service_name,
+            ) = self.availability_fetcher.get_available_models_by_service(service)
+            if model_name in available_models:
+                self._models_to_services[model_name] = service
+                return service
+
+        raise InferenceServiceError(f"Model {model_name} not found in any services")
 
 
 class InferenceServicesCollection:
-    added_models = {}
+    added_models = defaultdict(list)  # Moved back to class level
 
-    def __init__(self, services: list[InferenceServiceABC] = None):
+    def __init__(self, services: Optional[List[InferenceServiceABC]] = None):
         self.services = services or []
+        self._models_to_services: Dict[str, InferenceServiceABC] = {}
+
+        self.availability_fetcher = AvailableModelFetcher(
+            self.services, self.added_models
+        )
+        self.resolver = ModelResolver(
+            self.services, self._models_to_services, self.availability_fetcher
+        )
 
     @classmethod
-    def add_model(cls, service_name, model_name):
+    def add_model(cls, service_name: str, model_name: str) -> None:
         if service_name not in cls.added_models:
-            cls.added_models[service_name] = []
-        cls.added_models[service_name].append(model_name)
-
-    @staticmethod
-    def _get_service_available(service, warn: bool = False) -> list[str]:
-        try:
-            service_models = service.available()
-        except Exception:
-            if warn:
-                warnings.warn(
-                    f"""Error getting models for {service._inference_service_}.
-                    Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
-                    See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
-                    Relying on Coop.""",
-                    UserWarning,
-                )
-
-            # Use the list of models on Coop as a fallback
-            try:
-                from edsl import Coop
-
-                c = Coop()
-                models_from_coop = c.fetch_models()
-                service_models = models_from_coop.get(service._inference_service_, [])
-
-                # cache results
-                service._models_list_cache = service_models
-
-            # Finally, use the available models cache from the Python file
-            except Exception:
-                if warn:
-                    warnings.warn(
-                        f"""Error getting models for {service._inference_service_}.
-                        Relying on EDSL cache.""",
-                        UserWarning,
-                    )
-
-                from edsl.inference_services.models_available_cache import (
-                    models_available,
-                )
-
-                service_models = models_available.get(service._inference_service_, [])
-
-                # cache results
-                service._models_list_cache = service_models
-
-        return service_models
-
-    def available(self):
-        total_models = []
-        for service in self.services:
-            service_models = self._get_service_available(service)
-            for model in service_models:
-                total_models.append([model, service._inference_service_, -1])
+            cls.added_models[service_name].append(model_name)
 
-            for model in self.added_models.get(service._inference_service_, []):
-                total_models.append([model, service._inference_service_, -1])
+    def available(
+        self,
+        service: Optional[str] = None,
+    ) -> List[Tuple[str, str, int]]:
+        return self.availability_fetcher.available(service)
 
-        sorted_models = sorted(total_models)
-        for i, model in enumerate(sorted_models):
-            model[2] = i
-            model = tuple(model)
-        return sorted_models
+    def reset_cache(self) -> None:
+        self.availability_fetcher.reset_cache()
 
-    def register(self, service):
-        self.services.append(service)
+    @property
+    def num_cache_entries(self) -> int:
+        return self.availability_fetcher.num_cache_entries
 
-    def create_model_factory(self, model_name: str, service_name=None, index=None):
-        from edsl.inference_services.TestService import TestService
+    def register(self, service: InferenceServiceABC) -> None:
+        self.services.append(service)
 
-        if model_name == "test":
-            return TestService.create_model(model_name)
+    def create_model_factory(
+        self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
+    ) -> "LanguageModel":
+        service = self.resolver.resolve_model(model_name, service_name)
+        return service.create_model(model_name)
 
-        if service_name:
-            for service in self.services:
-                if service_name == service._inference_service_:
-                    return service.create_model(model_name)
 
-        for service in self.services:
-            if model_name in self._get_service_available(service):
-                if service_name is None or service_name == service._inference_service_:
-                    return service.create_model(model_name)
+if __name__ == "__main__":
+    import doctest
 
-        raise Exception(f"Model {model_name} not found in any of the services")
+    doctest.testmod()
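
create_model_factory now delegates resolution to ModelResolver.resolve_model, which tries, in order: the special "test" name, an explicitly requested service, a memo of models resolved earlier, and finally each service's fetched model list, raising InferenceServiceError instead of a bare Exception. A usage sketch under the new API; the model and service names are illustrative, and instantiating the returned class still requires the usual API-key setup:

from edsl.inference_services.InferenceServicesCollection import (
    InferenceServicesCollection,
)
from edsl.inference_services.OpenAIService import OpenAIService

collection = InferenceServicesCollection([OpenAIService])

# Explicit service: the resolver returns the matching service directly.
model_class = collection.create_model_factory("gpt-4o", service_name="openai")

# No service given: the resolver checks its memo, then the availability fetcher.
model_class = collection.create_model_factory("gpt-4o")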
edsl/inference_services/MistralAIService.py
@@ -85,9 +85,6 @@ class MistralAIService(InferenceServiceABC):
                 "top_p": 0.9,
             }
 
-            _tpm = cls.get_tpm(cls)
-            _rpm = cls.get_rpm(cls)
-
             def sync_client(self):
                 return cls.sync_client()
 
edsl/inference_services/OpenAIService.py
@@ -5,7 +5,7 @@ import os
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
-from edsl.language_models import LanguageModel
+from edsl.language_models.LanguageModel import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
 from edsl.utilities.utilities import fix_partial_correct_response
 
@@ -107,9 +107,6 @@ class OpenAIService(InferenceServiceABC):
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
 
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
-
            _inference_service_ = cls._inference_service_
            _model_ = model_name
            _parameters_ = {
edsl/inference_services/PerplexityService.py
@@ -51,9 +51,6 @@ class PerplexityService(OpenAIService):
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
 
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
-
            _inference_service_ = cls._inference_service_
            _model_ = model_name
 
edsl/inference_services/ServiceAvailability.py (new file)
@@ -0,0 +1,135 @@
+from enum import Enum
+from typing import List, Optional, TYPE_CHECKING
+from functools import partial
+import warnings
+
+from edsl.inference_services.data_structures import AvailableModels, ModelNamesList
+
+if TYPE_CHECKING:
+    from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+
+
+class ModelSource(Enum):
+    LOCAL = "local"
+    COOP = "coop"
+    CACHE = "cache"
+
+
+class ServiceAvailability:
+    """This class is responsible for fetching the available models from different sources."""
+
+    _coop_model_list = None
+
+    def __init__(self, source_order: Optional[List[ModelSource]] = None):
+        """
+        Initialize with custom source order.
+        Default order is LOCAL -> COOP -> CACHE
+        """
+        self.source_order = source_order or [
+            ModelSource.LOCAL,
+            ModelSource.COOP,
+            ModelSource.CACHE,
+        ]
+
+        # Map sources to their fetch functions
+        self._source_fetchers = {
+            ModelSource.LOCAL: self._fetch_from_local_service,
+            ModelSource.COOP: self._fetch_from_coop,
+            ModelSource.CACHE: self._fetch_from_cache,
+        }
+
+    @classmethod
+    def models_from_coop(cls) -> AvailableModels:
+        if not cls._coop_model_list:
+            from edsl.coop.coop import Coop
+
+            c = Coop()
+            coop_model_list = c.fetch_models()
+            cls._coop_model_list = coop_model_list
+        return cls._coop_model_list
+
+    def get_service_available(
+        self, service: "InferenceServiceABC", warn: bool = False
+    ) -> ModelNamesList:
+        """
+        Try to fetch available models from sources in specified order.
+        Returns first successful result.
+        """
+        last_error = None
+
+        for source in self.source_order:
+            try:
+                fetch_func = partial(self._source_fetchers[source], service)
+                result = fetch_func()
+
+                # Cache successful result
+                service._models_list_cache = result
+                return result
+
+            except Exception as e:
+                last_error = e
+                if warn:
+                    self._warn_source_failed(service, source)
+                continue
+
+        # If we get here, all sources failed
+        raise RuntimeError(
+            f"All sources failed to fetch models. Last error: {last_error}"
+        )
+
+    @staticmethod
+    def _fetch_from_local_service(service: "InferenceServiceABC") -> ModelNamesList:
+        """Attempt to fetch models directly from the service."""
+        return service.available()
+
+    @classmethod
+    def _fetch_from_coop(cls, service: "InferenceServiceABC") -> ModelNamesList:
+        """Fetch models from Coop."""
+        models_from_coop = cls.models_from_coop()
+        return models_from_coop.get(service._inference_service_, [])
+
+    @staticmethod
+    def _fetch_from_cache(service: "InferenceServiceABC") -> ModelNamesList:
+        """Fetch models from local cache."""
+        from edsl.inference_services.models_available_cache import models_available
+
+        return models_available.get(service._inference_service_, [])
+
+    def _warn_source_failed(self, service: "InferenceServiceABC", source: ModelSource):
+        """Display appropriate warning message based on failed source."""
+        messages = {
+            ModelSource.LOCAL: f"""Error getting models for {service._inference_service_}.
+            Check that you have properly stored your Expected Parrot API key and activated remote inference,
+            or stored your own API keys for the language models that you want to use.
+            See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
+            Trying next source.""",
+            ModelSource.COOP: f"Error getting models from Coop for {service._inference_service_}. Trying next source.",
+            ModelSource.CACHE: f"Error getting models from cache for {service._inference_service_}.",
+        }
+        warnings.warn(messages[source], UserWarning)
+
+
+if __name__ == "__main__":
+    # sa = ServiceAvailability()
+    # models_from_coop = sa.models_from_coop()
+    # print(models_from_coop)
+    from edsl.inference_services.OpenAIService import OpenAIService
+
+    openai_models = ServiceAvailability._fetch_from_local_service(OpenAIService())
+    print(openai_models)
+
+    # Example usage:
+    """
+    # Default order (LOCAL -> COOP -> CACHE)
+    availability = ServiceAvailability()
+
+    # Custom order (COOP -> LOCAL -> CACHE)
+    availability_coop_first = ServiceAvailability([
+        ModelSource.COOP,
+        ModelSource.LOCAL,
+        ModelSource.CACHE
+    ])
+
+    # Get available models using custom order
+    models = availability_coop_first.get_service_available(service, warn=True)
+    """
edsl/inference_services/TestService.py
@@ -2,7 +2,7 @@ from typing import Any, List, Optional
 import os
 import asyncio
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
-from edsl.language_models import LanguageModel
+from edsl.language_models.LanguageModel import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
 from edsl.utilities.utilities import fix_partial_correct_response
 
@@ -65,13 +65,7 @@ class TestService(InferenceServiceABC):
         await asyncio.sleep(0.1)
         # return {"message": """{"answer": "Hello, world"}"""}
 
-        if hasattr(self, "func"):
-            return {
-                "message": [
-                    {"text": self.func(user_prompt, system_prompt, files_list)}
-                ],
-                "usage": {"prompt_tokens": 1, "completion_tokens": 1},
-            }
+        # breakpoint()
 
         if hasattr(self, "throw_exception") and self.throw_exception:
             if hasattr(self, "exception_probability"):
@@ -81,6 +75,15 @@ class TestService(InferenceServiceABC):
 
             if random.random() < p:
                 raise Exception("This is a test error")
+
+        if hasattr(self, "func"):
+            return {
+                "message": [
+                    {"text": self.func(user_prompt, system_prompt, files_list)}
+                ],
+                "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+            }
+
         return {
             "message": [{"text": f"{self._canned_response}"}],
             "usage": {"prompt_tokens": 1, "completion_tokens": 1},
edsl/inference_services/data_structures.py (new file)
@@ -0,0 +1,62 @@
+from collections import UserDict, defaultdict, UserList
+from typing import Union
+from edsl.enums import InferenceServiceLiteral
+from dataclasses import dataclass
+
+
+@dataclass
+class LanguageModelInfo:
+    model_name: str
+    service_name: str
+
+    def __getitem__(self, key: int) -> str:
+        import warnings
+
+        warnings.warn(
+            "Accessing LanguageModelInfo via index is deprecated. "
+            "Please use .model_name, .service_name, or .index attributes instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        if key == 0:
+            return self.model_name
+        elif key == 1:
+            return self.service_name
+        else:
+            raise IndexError("Index out of range")
+
+
+class ModelNamesList(UserList):
+    pass
+
+
+class AvailableModels(UserList):
+    def __init__(self, data: list) -> None:
+        super().__init__(data)
+
+    def __contains__(self, model_name: str) -> bool:
+        for model_entry in self:
+            if model_entry.model_name == model_name:
+                return True
+        return False
+
+
+class ServiceToModelsMapping(UserDict):
+    def __init__(self, data: dict) -> None:
+        super().__init__(data)
+
+    @property
+    def service_names(self) -> list[str]:
+        return list(self.data.keys())
+
+    def _validate_service_names(self):
+        for service in self.service_names:
+            if service not in InferenceServiceLiteral:
+                raise ValueError(f"Invalid service name: {service}")
+
+    def model_to_services(self) -> dict:
+        self._model_to_service = defaultdict(list)
+        for service, models in self.data.items():
+            for model in models:
+                self._model_to_service[model].append(service)
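
These containers keep old call sites working while steering them toward attribute access: indexing a LanguageModelInfo still returns the field but emits a DeprecationWarning, and AvailableModels.__contains__ matches on model_name rather than object identity. A small self-contained sketch:

from edsl.inference_services.data_structures import (
    AvailableModels,
    LanguageModelInfo,
)

info = LanguageModelInfo(model_name="gpt-4o", service_name="openai")
assert info.model_name == "gpt-4o"  # preferred access
assert info[0] == "gpt-4o"          # still works, but warns DeprecationWarning

models = AvailableModels([info])
assert "gpt-4o" in models           # __contains__ compares model_name fields

Note that model_to_services populates self._model_to_service without returning it, despite the dict return annotation; as written, callers should treat it as filling internal state.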