edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +169 -116
- edsl/__init__.py +14 -6
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +358 -146
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +88 -36
- edsl/agents/InvigilatorBase.py +59 -70
- edsl/agents/PromptConstructor.py +117 -219
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionOptionProcessor.py +172 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +104 -42
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +21 -14
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +33 -12
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +20 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +0 -3
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +209 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +2 -11
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +105 -80
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +1 -4
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -8
- edsl/inference_services/data_structures.py +62 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +40 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +48 -0
- edsl/jobs/Jobs.py +102 -243
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +5 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +77 -380
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +14 -15
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +137 -234
- edsl/language_models/ModelList.py +11 -13
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +0 -1
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/registry.py +49 -59
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/AnswerValidatorMixin.py +47 -2
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/LoopProcessor.py +149 -0
- edsl/questions/QuestionBase.py +37 -192
- edsl/questions/QuestionBaseGenMixin.py +52 -48
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionExtract.py +1 -1
- edsl/questions/QuestionFreeText.py +1 -2
- edsl/questions/QuestionList.py +3 -5
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +66 -22
- edsl/questions/QuestionNumerical.py +1 -3
- edsl/questions/QuestionRank.py +6 -16
- edsl/questions/ResponseValidatorABC.py +37 -11
- edsl/questions/ResponseValidatorFactory.py +28 -0
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +1 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/question_registry.py +1 -1
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +224 -302
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +192 -206
- edsl/results/Results.py +120 -113
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/Selector.py +23 -13
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DirectoryScanner.py +96 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +118 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioJoin.py +10 -6
- edsl/scenarios/ScenarioList.py +383 -240
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/ScenarioSelector.py +156 -0
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +38 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +199 -771
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
- edsl-0.1.39.dev2.dist-info/RECORD +352 -0
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.39.dev1.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
@@ -1,5 +1,4 @@
|
|
1
1
|
from abc import abstractmethod, ABC
|
2
|
-
import os
|
3
2
|
import re
|
4
3
|
from datetime import datetime, timedelta
|
5
4
|
from edsl.config import CONFIG
|
@@ -8,31 +7,32 @@ from edsl.config import CONFIG
|
|
8
7
|
class InferenceServiceABC(ABC):
|
9
8
|
"""
|
10
9
|
Abstract class for inference services.
|
11
|
-
Anthropic: https://docs.anthropic.com/en/api/rate-limits
|
12
10
|
"""
|
13
11
|
|
14
12
|
_coop_config_vars = None
|
15
13
|
|
16
|
-
default_levels = {
|
17
|
-
"google": {"tpm": 2_000_000, "rpm": 15},
|
18
|
-
"openai": {"tpm": 2_000_000, "rpm": 10_000},
|
19
|
-
"anthropic": {"tpm": 2_000_000, "rpm": 500},
|
20
|
-
}
|
21
|
-
|
22
14
|
def __init_subclass__(cls):
|
23
15
|
"""
|
24
16
|
Check that the subclass has the required attributes.
|
25
17
|
- `key_sequence` attribute determines...
|
26
18
|
- `model_exclude_list` attribute determines...
|
27
19
|
"""
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
20
|
+
must_have_attributes = [
|
21
|
+
"key_sequence",
|
22
|
+
"model_exclude_list",
|
23
|
+
"usage_sequence",
|
24
|
+
"input_token_name",
|
25
|
+
"output_token_name",
|
26
|
+
]
|
27
|
+
for attr in must_have_attributes:
|
28
|
+
if not hasattr(cls, attr):
|
29
|
+
raise NotImplementedError(
|
30
|
+
f"Class {cls.__name__} must have a '{attr}' attribute."
|
31
|
+
)
|
32
|
+
|
33
|
+
@property
|
34
|
+
def service_name(self):
|
35
|
+
return self._inference_service_
|
36
36
|
|
37
37
|
@classmethod
|
38
38
|
def _should_refresh_coop_config_vars(cls):
|
@@ -44,44 +44,6 @@ class InferenceServiceABC(ABC):
|
|
44
44
|
return True
|
45
45
|
return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)
|
46
46
|
|
47
|
-
@classmethod
|
48
|
-
def _get_limt(cls, limit_type: str) -> int:
|
49
|
-
key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
|
50
|
-
if key in os.environ:
|
51
|
-
return int(os.getenv(key))
|
52
|
-
|
53
|
-
if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
|
54
|
-
try:
|
55
|
-
from edsl import Coop
|
56
|
-
|
57
|
-
c = Coop()
|
58
|
-
cls._coop_config_vars = c.fetch_rate_limit_config_vars()
|
59
|
-
cls._last_config_fetch = datetime.now()
|
60
|
-
if key in cls._coop_config_vars:
|
61
|
-
return cls._coop_config_vars[key]
|
62
|
-
except Exception:
|
63
|
-
cls._coop_config_vars = None
|
64
|
-
else:
|
65
|
-
if key in cls._coop_config_vars:
|
66
|
-
return cls._coop_config_vars[key]
|
67
|
-
|
68
|
-
if cls._inference_service_ in cls.default_levels:
|
69
|
-
return int(cls.default_levels[cls._inference_service_][limit_type])
|
70
|
-
|
71
|
-
return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
|
72
|
-
|
73
|
-
def get_tpm(cls) -> int:
|
74
|
-
"""
|
75
|
-
Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
|
76
|
-
"""
|
77
|
-
return cls._get_limt(limit_type="tpm")
|
78
|
-
|
79
|
-
def get_rpm(cls):
|
80
|
-
"""
|
81
|
-
Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
|
82
|
-
"""
|
83
|
-
return cls._get_limt(limit_type="rpm")
|
84
|
-
|
85
47
|
@abstractmethod
|
86
48
|
def available() -> list[str]:
|
87
49
|
"""
|
@@ -113,35 +75,6 @@ class InferenceServiceABC(ABC):
|
|
113
75
|
|
114
76
|
|
115
77
|
if __name__ == "__main__":
|
116
|
-
|
117
|
-
# deep_infra_service = DeepInfraService("deep_infra", "DEEP_INFRA_API_KEY")
|
118
|
-
# deep_infra_service.available()
|
119
|
-
# m = deep_infra_service.create_model("microsoft/WizardLM-2-7B")
|
120
|
-
# response = m().hello()
|
121
|
-
# print(response)
|
122
|
-
|
123
|
-
# anthropic_service = AnthropicService("anthropic", "ANTHROPIC_API_KEY")
|
124
|
-
# anthropic_service.available()
|
125
|
-
# m = anthropic_service.create_model("claude-3-opus-20240229")
|
126
|
-
# response = m().hello()
|
127
|
-
# print(response)
|
128
|
-
# factory = OpenAIService("openai", "OPENAI_API")
|
129
|
-
# factory.available()
|
130
|
-
# m = factory.create_model("gpt-3.5-turbo")
|
131
|
-
# response = m().hello()
|
132
|
-
|
133
|
-
# from edsl import QuestionFreeText
|
134
|
-
# results = QuestionFreeText.example().by(m()).run()
|
135
|
-
|
136
|
-
# collection = InferenceServicesCollection([
|
137
|
-
# OpenAIService,
|
138
|
-
# AnthropicService,
|
139
|
-
# DeepInfraService
|
140
|
-
# ])
|
78
|
+
import doctest
|
141
79
|
|
142
|
-
|
143
|
-
# factory = collection.create_model_factory(*available[0])
|
144
|
-
# m = factory()
|
145
|
-
# from edsl import QuestionFreeText
|
146
|
-
# results = QuestionFreeText.example().by(m).run()
|
147
|
-
# print(results)
|
80
|
+
doctest.testmod()
|
@@ -1,97 +1,122 @@
|
|
1
|
+
from functools import lru_cache
|
2
|
+
from collections import defaultdict
|
3
|
+
from typing import Optional, Protocol, Dict, List, Tuple, TYPE_CHECKING, Literal
|
4
|
+
|
1
5
|
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
2
|
-
import
|
6
|
+
from edsl.inference_services.AvailableModelFetcher import AvailableModelFetcher
|
7
|
+
from edsl.exceptions.inference_services import InferenceServiceError
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
11
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
12
|
+
|
13
|
+
|
14
|
+
class ModelCreator(Protocol):
|
15
|
+
def create_model(self, model_name: str) -> "LanguageModel":
|
16
|
+
...
|
17
|
+
|
18
|
+
|
19
|
+
from edsl.enums import InferenceServiceLiteral
|
20
|
+
|
21
|
+
|
22
|
+
class ModelResolver:
|
23
|
+
def __init__(
|
24
|
+
self,
|
25
|
+
services: List[InferenceServiceLiteral],
|
26
|
+
models_to_services: Dict[InferenceServiceLiteral, InferenceServiceABC],
|
27
|
+
availability_fetcher: "AvailableModelFetcher",
|
28
|
+
):
|
29
|
+
"""
|
30
|
+
Class for determining which service to use for a given model.
|
31
|
+
"""
|
32
|
+
self.services = services
|
33
|
+
self._models_to_services = models_to_services
|
34
|
+
self.availability_fetcher = availability_fetcher
|
35
|
+
self._service_names_to_classes = {
|
36
|
+
service._inference_service_: service for service in services
|
37
|
+
}
|
38
|
+
|
39
|
+
def resolve_model(
|
40
|
+
self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
|
41
|
+
) -> InferenceServiceABC:
|
42
|
+
"""Returns an InferenceServiceABC object for the given model name.
|
43
|
+
|
44
|
+
:param model_name: The name of the model to resolve. E.g., 'gpt-4o'
|
45
|
+
:param service_name: The name of the service to use. E.g., 'openai'
|
46
|
+
:return: An InferenceServiceABC object
|
47
|
+
|
48
|
+
"""
|
49
|
+
if model_name == "test":
|
50
|
+
from edsl.inference_services.TestService import TestService
|
51
|
+
|
52
|
+
return TestService()
|
53
|
+
|
54
|
+
if service_name is not None:
|
55
|
+
service: InferenceServiceABC = self._service_names_to_classes.get(
|
56
|
+
service_name
|
57
|
+
)
|
58
|
+
if not service:
|
59
|
+
raise InferenceServiceError(f"Service {service_name} not found")
|
60
|
+
return service
|
61
|
+
|
62
|
+
if model_name in self._models_to_services: # maybe we've seen it before!
|
63
|
+
return self._models_to_services[model_name]
|
64
|
+
|
65
|
+
for service in self.services:
|
66
|
+
(
|
67
|
+
available_models,
|
68
|
+
service_name,
|
69
|
+
) = self.availability_fetcher.get_available_models_by_service(service)
|
70
|
+
if model_name in available_models:
|
71
|
+
self._models_to_services[model_name] = service
|
72
|
+
return service
|
73
|
+
|
74
|
+
raise InferenceServiceError(f"Model {model_name} not found in any services")
|
3
75
|
|
4
76
|
|
5
77
|
class InferenceServicesCollection:
|
6
|
-
added_models =
|
78
|
+
added_models = defaultdict(list) # Moved back to class level
|
7
79
|
|
8
|
-
def __init__(self, services:
|
80
|
+
def __init__(self, services: Optional[List[InferenceServiceABC]] = None):
|
9
81
|
self.services = services or []
|
82
|
+
self._models_to_services: Dict[str, InferenceServiceABC] = {}
|
83
|
+
|
84
|
+
self.availability_fetcher = AvailableModelFetcher(
|
85
|
+
self.services, self.added_models
|
86
|
+
)
|
87
|
+
self.resolver = ModelResolver(
|
88
|
+
self.services, self._models_to_services, self.availability_fetcher
|
89
|
+
)
|
10
90
|
|
11
91
|
@classmethod
|
12
|
-
def add_model(cls, service_name, model_name):
|
92
|
+
def add_model(cls, service_name: str, model_name: str) -> None:
|
13
93
|
if service_name not in cls.added_models:
|
14
|
-
cls.added_models[service_name]
|
15
|
-
cls.added_models[service_name].append(model_name)
|
16
|
-
|
17
|
-
@staticmethod
|
18
|
-
def _get_service_available(service, warn: bool = False) -> list[str]:
|
19
|
-
try:
|
20
|
-
service_models = service.available()
|
21
|
-
except Exception:
|
22
|
-
if warn:
|
23
|
-
warnings.warn(
|
24
|
-
f"""Error getting models for {service._inference_service_}.
|
25
|
-
Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
|
26
|
-
See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
|
27
|
-
Relying on Coop.""",
|
28
|
-
UserWarning,
|
29
|
-
)
|
30
|
-
|
31
|
-
# Use the list of models on Coop as a fallback
|
32
|
-
try:
|
33
|
-
from edsl import Coop
|
34
|
-
|
35
|
-
c = Coop()
|
36
|
-
models_from_coop = c.fetch_models()
|
37
|
-
service_models = models_from_coop.get(service._inference_service_, [])
|
38
|
-
|
39
|
-
# cache results
|
40
|
-
service._models_list_cache = service_models
|
41
|
-
|
42
|
-
# Finally, use the available models cache from the Python file
|
43
|
-
except Exception:
|
44
|
-
if warn:
|
45
|
-
warnings.warn(
|
46
|
-
f"""Error getting models for {service._inference_service_}.
|
47
|
-
Relying on EDSL cache.""",
|
48
|
-
UserWarning,
|
49
|
-
)
|
50
|
-
|
51
|
-
from edsl.inference_services.models_available_cache import (
|
52
|
-
models_available,
|
53
|
-
)
|
54
|
-
|
55
|
-
service_models = models_available.get(service._inference_service_, [])
|
56
|
-
|
57
|
-
# cache results
|
58
|
-
service._models_list_cache = service_models
|
59
|
-
|
60
|
-
return service_models
|
61
|
-
|
62
|
-
def available(self):
|
63
|
-
total_models = []
|
64
|
-
for service in self.services:
|
65
|
-
service_models = self._get_service_available(service)
|
66
|
-
for model in service_models:
|
67
|
-
total_models.append([model, service._inference_service_, -1])
|
94
|
+
cls.added_models[service_name].append(model_name)
|
68
95
|
|
69
|
-
|
70
|
-
|
96
|
+
def available(
|
97
|
+
self,
|
98
|
+
service: Optional[str] = None,
|
99
|
+
) -> List[Tuple[str, str, int]]:
|
100
|
+
return self.availability_fetcher.available(service)
|
71
101
|
|
72
|
-
|
73
|
-
|
74
|
-
model[2] = i
|
75
|
-
model = tuple(model)
|
76
|
-
return sorted_models
|
102
|
+
def reset_cache(self) -> None:
|
103
|
+
self.availability_fetcher.reset_cache()
|
77
104
|
|
78
|
-
|
79
|
-
|
105
|
+
@property
|
106
|
+
def num_cache_entries(self) -> int:
|
107
|
+
return self.availability_fetcher.num_cache_entries
|
80
108
|
|
81
|
-
def
|
82
|
-
|
109
|
+
def register(self, service: InferenceServiceABC) -> None:
|
110
|
+
self.services.append(service)
|
83
111
|
|
84
|
-
|
85
|
-
|
112
|
+
def create_model_factory(
|
113
|
+
self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
|
114
|
+
) -> "LanguageModel":
|
115
|
+
service = self.resolver.resolve_model(model_name, service_name)
|
116
|
+
return service.create_model(model_name)
|
86
117
|
|
87
|
-
if service_name:
|
88
|
-
for service in self.services:
|
89
|
-
if service_name == service._inference_service_:
|
90
|
-
return service.create_model(model_name)
|
91
118
|
|
92
|
-
|
93
|
-
|
94
|
-
if service_name is None or service_name == service._inference_service_:
|
95
|
-
return service.create_model(model_name)
|
119
|
+
if __name__ == "__main__":
|
120
|
+
import doctest
|
96
121
|
|
97
|
-
|
122
|
+
doctest.testmod()
|
@@ -5,7 +5,7 @@ import os
|
|
5
5
|
import openai
|
6
6
|
|
7
7
|
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
8
|
-
from edsl.language_models import LanguageModel
|
8
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
9
9
|
from edsl.inference_services.rate_limits_cache import rate_limits
|
10
10
|
from edsl.utilities.utilities import fix_partial_correct_response
|
11
11
|
|
@@ -107,9 +107,6 @@ class OpenAIService(InferenceServiceABC):
|
|
107
107
|
input_token_name = cls.input_token_name
|
108
108
|
output_token_name = cls.output_token_name
|
109
109
|
|
110
|
-
_rpm = cls.get_rpm(cls)
|
111
|
-
_tpm = cls.get_tpm(cls)
|
112
|
-
|
113
110
|
_inference_service_ = cls._inference_service_
|
114
111
|
_model_ = model_name
|
115
112
|
_parameters_ = {
|
@@ -51,9 +51,6 @@ class PerplexityService(OpenAIService):
|
|
51
51
|
input_token_name = cls.input_token_name
|
52
52
|
output_token_name = cls.output_token_name
|
53
53
|
|
54
|
-
_rpm = cls.get_rpm(cls)
|
55
|
-
_tpm = cls.get_tpm(cls)
|
56
|
-
|
57
54
|
_inference_service_ = cls._inference_service_
|
58
55
|
_model_ = model_name
|
59
56
|
|
@@ -0,0 +1,135 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import List, Optional, TYPE_CHECKING
|
3
|
+
from functools import partial
|
4
|
+
import warnings
|
5
|
+
|
6
|
+
from edsl.inference_services.data_structures import AvailableModels, ModelNamesList
|
7
|
+
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
10
|
+
|
11
|
+
|
12
|
+
class ModelSource(Enum):
|
13
|
+
LOCAL = "local"
|
14
|
+
COOP = "coop"
|
15
|
+
CACHE = "cache"
|
16
|
+
|
17
|
+
|
18
|
+
class ServiceAvailability:
|
19
|
+
"""This class is responsible for fetching the available models from different sources."""
|
20
|
+
|
21
|
+
_coop_model_list = None
|
22
|
+
|
23
|
+
def __init__(self, source_order: Optional[List[ModelSource]] = None):
|
24
|
+
"""
|
25
|
+
Initialize with custom source order.
|
26
|
+
Default order is LOCAL -> COOP -> CACHE
|
27
|
+
"""
|
28
|
+
self.source_order = source_order or [
|
29
|
+
ModelSource.LOCAL,
|
30
|
+
ModelSource.COOP,
|
31
|
+
ModelSource.CACHE,
|
32
|
+
]
|
33
|
+
|
34
|
+
# Map sources to their fetch functions
|
35
|
+
self._source_fetchers = {
|
36
|
+
ModelSource.LOCAL: self._fetch_from_local_service,
|
37
|
+
ModelSource.COOP: self._fetch_from_coop,
|
38
|
+
ModelSource.CACHE: self._fetch_from_cache,
|
39
|
+
}
|
40
|
+
|
41
|
+
@classmethod
|
42
|
+
def models_from_coop(cls) -> AvailableModels:
|
43
|
+
if not cls._coop_model_list:
|
44
|
+
from edsl.coop.coop import Coop
|
45
|
+
|
46
|
+
c = Coop()
|
47
|
+
coop_model_list = c.fetch_models()
|
48
|
+
cls._coop_model_list = coop_model_list
|
49
|
+
return cls._coop_model_list
|
50
|
+
|
51
|
+
def get_service_available(
|
52
|
+
self, service: "InferenceServiceABC", warn: bool = False
|
53
|
+
) -> ModelNamesList:
|
54
|
+
"""
|
55
|
+
Try to fetch available models from sources in specified order.
|
56
|
+
Returns first successful result.
|
57
|
+
"""
|
58
|
+
last_error = None
|
59
|
+
|
60
|
+
for source in self.source_order:
|
61
|
+
try:
|
62
|
+
fetch_func = partial(self._source_fetchers[source], service)
|
63
|
+
result = fetch_func()
|
64
|
+
|
65
|
+
# Cache successful result
|
66
|
+
service._models_list_cache = result
|
67
|
+
return result
|
68
|
+
|
69
|
+
except Exception as e:
|
70
|
+
last_error = e
|
71
|
+
if warn:
|
72
|
+
self._warn_source_failed(service, source)
|
73
|
+
continue
|
74
|
+
|
75
|
+
# If we get here, all sources failed
|
76
|
+
raise RuntimeError(
|
77
|
+
f"All sources failed to fetch models. Last error: {last_error}"
|
78
|
+
)
|
79
|
+
|
80
|
+
@staticmethod
|
81
|
+
def _fetch_from_local_service(service: "InferenceServiceABC") -> ModelNamesList:
|
82
|
+
"""Attempt to fetch models directly from the service."""
|
83
|
+
return service.available()
|
84
|
+
|
85
|
+
@classmethod
|
86
|
+
def _fetch_from_coop(cls, service: "InferenceServiceABC") -> ModelNamesList:
|
87
|
+
"""Fetch models from Coop."""
|
88
|
+
models_from_coop = cls.models_from_coop()
|
89
|
+
return models_from_coop.get(service._inference_service_, [])
|
90
|
+
|
91
|
+
@staticmethod
|
92
|
+
def _fetch_from_cache(service: "InferenceServiceABC") -> ModelNamesList:
|
93
|
+
"""Fetch models from local cache."""
|
94
|
+
from edsl.inference_services.models_available_cache import models_available
|
95
|
+
|
96
|
+
return models_available.get(service._inference_service_, [])
|
97
|
+
|
98
|
+
def _warn_source_failed(self, service: "InferenceServiceABC", source: ModelSource):
|
99
|
+
"""Display appropriate warning message based on failed source."""
|
100
|
+
messages = {
|
101
|
+
ModelSource.LOCAL: f"""Error getting models for {service._inference_service_}.
|
102
|
+
Check that you have properly stored your Expected Parrot API key and activated remote inference,
|
103
|
+
or stored your own API keys for the language models that you want to use.
|
104
|
+
See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
|
105
|
+
Trying next source.""",
|
106
|
+
ModelSource.COOP: f"Error getting models from Coop for {service._inference_service_}. Trying next source.",
|
107
|
+
ModelSource.CACHE: f"Error getting models from cache for {service._inference_service_}.",
|
108
|
+
}
|
109
|
+
warnings.warn(messages[source], UserWarning)
|
110
|
+
|
111
|
+
|
112
|
+
if __name__ == "__main__":
|
113
|
+
# sa = ServiceAvailability()
|
114
|
+
# models_from_coop = sa.models_from_coop()
|
115
|
+
# print(models_from_coop)
|
116
|
+
from edsl.inference_services.OpenAIService import OpenAIService
|
117
|
+
|
118
|
+
openai_models = ServiceAvailability._fetch_from_local_service(OpenAIService())
|
119
|
+
print(openai_models)
|
120
|
+
|
121
|
+
# Example usage:
|
122
|
+
"""
|
123
|
+
# Default order (LOCAL -> COOP -> CACHE)
|
124
|
+
availability = ServiceAvailability()
|
125
|
+
|
126
|
+
# Custom order (COOP -> LOCAL -> CACHE)
|
127
|
+
availability_coop_first = ServiceAvailability([
|
128
|
+
ModelSource.COOP,
|
129
|
+
ModelSource.LOCAL,
|
130
|
+
ModelSource.CACHE
|
131
|
+
])
|
132
|
+
|
133
|
+
# Get available models using custom order
|
134
|
+
models = availability_coop_first.get_service_available(service, warn=True)
|
135
|
+
"""
|
@@ -2,7 +2,7 @@ from typing import Any, List, Optional
|
|
2
2
|
import os
|
3
3
|
import asyncio
|
4
4
|
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
5
|
-
from edsl.language_models import LanguageModel
|
5
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
6
6
|
from edsl.inference_services.rate_limits_cache import rate_limits
|
7
7
|
from edsl.utilities.utilities import fix_partial_correct_response
|
8
8
|
|
@@ -65,13 +65,7 @@ class TestService(InferenceServiceABC):
|
|
65
65
|
await asyncio.sleep(0.1)
|
66
66
|
# return {"message": """{"answer": "Hello, world"}"""}
|
67
67
|
|
68
|
-
|
69
|
-
return {
|
70
|
-
"message": [
|
71
|
-
{"text": self.func(user_prompt, system_prompt, files_list)}
|
72
|
-
],
|
73
|
-
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
|
74
|
-
}
|
68
|
+
# breakpoint()
|
75
69
|
|
76
70
|
if hasattr(self, "throw_exception") and self.throw_exception:
|
77
71
|
if hasattr(self, "exception_probability"):
|
@@ -81,6 +75,15 @@ class TestService(InferenceServiceABC):
|
|
81
75
|
|
82
76
|
if random.random() < p:
|
83
77
|
raise Exception("This is a test error")
|
78
|
+
|
79
|
+
if hasattr(self, "func"):
|
80
|
+
return {
|
81
|
+
"message": [
|
82
|
+
{"text": self.func(user_prompt, system_prompt, files_list)}
|
83
|
+
],
|
84
|
+
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
|
85
|
+
}
|
86
|
+
|
84
87
|
return {
|
85
88
|
"message": [{"text": f"{self._canned_response}"}],
|
86
89
|
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
|
@@ -0,0 +1,62 @@
|
|
1
|
+
from collections import UserDict, defaultdict, UserList
|
2
|
+
from typing import Union
|
3
|
+
from edsl.enums import InferenceServiceLiteral
|
4
|
+
from dataclasses import dataclass
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
|
8
|
+
class LanguageModelInfo:
|
9
|
+
model_name: str
|
10
|
+
service_name: str
|
11
|
+
|
12
|
+
def __getitem__(self, key: int) -> str:
|
13
|
+
import warnings
|
14
|
+
|
15
|
+
warnings.warn(
|
16
|
+
"Accessing LanguageModelInfo via index is deprecated. "
|
17
|
+
"Please use .model_name, .service_name, or .index attributes instead.",
|
18
|
+
DeprecationWarning,
|
19
|
+
stacklevel=2,
|
20
|
+
)
|
21
|
+
|
22
|
+
if key == 0:
|
23
|
+
return self.model_name
|
24
|
+
elif key == 1:
|
25
|
+
return self.service_name
|
26
|
+
else:
|
27
|
+
raise IndexError("Index out of range")
|
28
|
+
|
29
|
+
|
30
|
+
class ModelNamesList(UserList):
|
31
|
+
pass
|
32
|
+
|
33
|
+
|
34
|
+
class AvailableModels(UserList):
|
35
|
+
def __init__(self, data: list) -> None:
|
36
|
+
super().__init__(data)
|
37
|
+
|
38
|
+
def __contains__(self, model_name: str) -> bool:
|
39
|
+
for model_entry in self:
|
40
|
+
if model_entry.model_name == model_name:
|
41
|
+
return True
|
42
|
+
return False
|
43
|
+
|
44
|
+
|
45
|
+
class ServiceToModelsMapping(UserDict):
|
46
|
+
def __init__(self, data: dict) -> None:
|
47
|
+
super().__init__(data)
|
48
|
+
|
49
|
+
@property
|
50
|
+
def service_names(self) -> list[str]:
|
51
|
+
return list(self.data.keys())
|
52
|
+
|
53
|
+
def _validate_service_names(self):
|
54
|
+
for service in self.service_names:
|
55
|
+
if service not in InferenceServiceLiteral:
|
56
|
+
raise ValueError(f"Invalid service name: {service}")
|
57
|
+
|
58
|
+
def model_to_services(self) -> dict:
|
59
|
+
self._model_to_service = defaultdict(list)
|
60
|
+
for service, models in self.data.items():
|
61
|
+
for model in models:
|
62
|
+
self._model_to_service[model].append(service)
|