edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +169 -116
- edsl/__init__.py +14 -6
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +358 -146
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +88 -36
- edsl/agents/InvigilatorBase.py +59 -70
- edsl/agents/PromptConstructor.py +117 -219
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionOptionProcessor.py +172 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +104 -42
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +21 -14
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +33 -12
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +20 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +0 -3
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +209 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +2 -11
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +105 -80
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +1 -4
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -8
- edsl/inference_services/data_structures.py +62 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +40 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +48 -0
- edsl/jobs/Jobs.py +102 -243
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +5 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +77 -380
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +14 -15
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +137 -234
- edsl/language_models/ModelList.py +11 -13
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +0 -1
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/registry.py +49 -59
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/AnswerValidatorMixin.py +47 -2
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/LoopProcessor.py +149 -0
- edsl/questions/QuestionBase.py +37 -192
- edsl/questions/QuestionBaseGenMixin.py +52 -48
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionExtract.py +1 -1
- edsl/questions/QuestionFreeText.py +1 -2
- edsl/questions/QuestionList.py +3 -5
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +66 -22
- edsl/questions/QuestionNumerical.py +1 -3
- edsl/questions/QuestionRank.py +6 -16
- edsl/questions/ResponseValidatorABC.py +37 -11
- edsl/questions/ResponseValidatorFactory.py +28 -0
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +1 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/question_registry.py +1 -1
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +224 -302
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +192 -206
- edsl/results/Results.py +120 -113
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/Selector.py +23 -13
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DirectoryScanner.py +96 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +118 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioJoin.py +10 -6
- edsl/scenarios/ScenarioList.py +383 -240
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/ScenarioSelector.py +156 -0
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +38 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +199 -771
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
- edsl-0.1.39.dev2.dist-info/RECORD +352 -0
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.39.dev1.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
edsl/exceptions/__init__.py
CHANGED
@@ -1,54 +1,54 @@
-from .agents import (
-    # AgentAttributeLookupCallbackError,
-    AgentCombinationError,
-    # AgentLacksLLMError,
-    # AgentRespondedWithBadJSONError,
-)
-from .configuration import (
-    InvalidEnvironmentVariableError,
-    MissingEnvironmentVariableError,
-)
-from .data import (
-    DatabaseConnectionError,
-    DatabaseCRUDError,
-    DatabaseIntegrityError,
-)
+# from .agents import (
+#     # AgentAttributeLookupCallbackError,
+#     AgentCombinationError,
+#     # AgentLacksLLMError,
+#     # AgentRespondedWithBadJSONError,
+# )
+# from .configuration import (
+#     InvalidEnvironmentVariableError,
+#     MissingEnvironmentVariableError,
+# )
+# from .data import (
+#     DatabaseConnectionError,
+#     DatabaseCRUDError,
+#     DatabaseIntegrityError,
+# )
 
-from .scenarios import (
-    ScenarioError,
-)
+# from .scenarios import (
+#     ScenarioError,
+# )
 
-from .general import MissingAPIKeyError
+# from .general import MissingAPIKeyError
 
-from .jobs import JobsRunError, InterviewErrorPriorTaskCanceled, InterviewTimeoutError
+# from .jobs import JobsRunError, InterviewErrorPriorTaskCanceled, InterviewTimeoutError
 
-from .language_models import (
-    LanguageModelResponseNotJSONError,
-    LanguageModelMissingAttributeError,
-    LanguageModelAttributeTypeError,
-    LanguageModelDoNotAddError,
-)
-from .questions import (
-    QuestionAnswerValidationError,
-    QuestionAttributeMissing,
-    QuestionCreationValidationError,
-    QuestionResponseValidationError,
-    QuestionSerializationError,
-    QuestionScenarioRenderError,
-)
-from .results import (
-    ResultsBadMutationstringError,
-    ResultsColumnNotFoundError,
-    ResultsInvalidNameError,
-    ResultsMutateError,
-)
-from .surveys import (
-    SurveyCreationError,
-    SurveyHasNoRulesError,
-    SurveyRuleCannotEvaluateError,
-    SurveyRuleCollectionHasNoRulesAtNodeError,
-    SurveyRuleReferenceInRuleToUnknownQuestionError,
-    SurveyRuleRefersToFutureStateError,
-    SurveyRuleSendsYouBackwardsError,
-    SurveyRuleSkipLogicSyntaxError,
-)
+# from .language_models import (
+#     LanguageModelResponseNotJSONError,
+#     LanguageModelMissingAttributeError,
+#     LanguageModelAttributeTypeError,
+#     LanguageModelDoNotAddError,
+# )
+# from .questions import (
+#     QuestionAnswerValidationError,
+#     QuestionAttributeMissing,
+#     QuestionCreationValidationError,
+#     QuestionResponseValidationError,
+#     QuestionSerializationError,
+#     QuestionScenarioRenderError,
+# )
+# from .results import (
+#     ResultsBadMutationstringError,
+#     ResultsColumnNotFoundError,
+#     ResultsInvalidNameError,
+#     ResultsMutateError,
+# )
+# from .surveys import (
+#     SurveyCreationError,
+#     SurveyHasNoRulesError,
+#     SurveyRuleCannotEvaluateError,
+#     SurveyRuleCollectionHasNoRulesAtNodeError,
+#     SurveyRuleReferenceInRuleToUnknownQuestionError,
+#     SurveyRuleRefersToFutureStateError,
+#     SurveyRuleSendsYouBackwardsError,
+#     SurveyRuleSkipLogicSyntaxError,
+# )
edsl/exceptions/agents.py
CHANGED
@@ -1,6 +1,18 @@
 from edsl.exceptions.BaseException import BaseException
 
 
+# from edsl.utilities.utilities import is_notebook
+
+# from IPython.core.error import UsageError
+
+# class AgentListErrorAlternative(UsageError):
+#     def __init__(self, message):
+#         super().__init__(message)
+
+import sys
+from edsl.utilities.is_notebook import is_notebook
+
+
 class AgentListError(BaseException):
     relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html#agent-lists"
 
edsl/exceptions/questions.py
CHANGED
@@ -1,6 +1,6 @@
 from typing import Any, SupportsIndex
-from jinja2 import Template
 import json
+from pydantic import ValidationError
 
 
 class QuestionErrors(Exception):
@@ -20,17 +20,35 @@ class QuestionAnswerValidationError(QuestionErrors):
     For example, if the question is a multiple choice question, the answer should be drawn from the list of options provided.
     """
 
-    def __init__(self, message="Invalid answer.", data: dict = None, model=None):
+    def __init__(
+        self,
+        message="Invalid answer.",
+        pydantic_error: ValidationError = None,
+        data: dict = None,
+        model=None,
+    ):
         self.message = message
+        self.pydantic_error = pydantic_error
         self.data = data
         self.model = model
         super().__init__(self.message)
 
     def __str__(self):
-        return f"""{repr(self)}
-        Data being validated: {self.data}
-        Pydnantic Model: {self.model}.
-        Reported error: {self.message}."""
+        if isinstance(self.message, ValidationError):
+            # If it's a ValidationError, just return the core error message
+            return str(self.message)
+        elif hasattr(self.message, "errors"):
+            # Handle the case where it's already been converted to a string but has errors
+            error_list = self.message.errors()
+            if error_list:
+                return str(error_list[0].get("msg", "Unknown error"))
+            return str(self.message)
+
+    # def __str__(self):
+    #     return f"""{repr(self)}
+    #     Data being validated: {self.data}
+    #     Pydnantic Model: {self.model}.
+    #     Reported error: {self.message}."""
 
     def to_html_dict(self):
         return {
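Note: the snippet below is an editor-added usage sketch, not part of the diff. It shows the new pydantic-aware behavior of QuestionAnswerValidationError; AnswerModel and its failing payload are hypothetical stand-ins for a question's real response model, and the sketch assumes edsl 0.1.39.dev2 with pydantic v2 installed.

from pydantic import BaseModel, ValidationError

from edsl.exceptions.questions import QuestionAnswerValidationError


class AnswerModel(BaseModel):  # hypothetical stand-in for a question's response model
    answer: int


try:
    AnswerModel(answer="not a number")  # fails validation: str is not an int
except ValidationError as e:
    err = QuestionAnswerValidationError(
        message=e,  # __str__ now returns just the pydantic error text
        pydantic_error=e,
        data={"answer": "not a number"},
        model=AnswerModel,
    )
    print(str(err))

Where the old __str__ dumped the repr, the validated data, and the model into one message, the new version surfaces only the core validation error.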
edsl/exceptions/scenarios.py
CHANGED
@@ -1,6 +1,13 @@
 import re
 import textwrap
 
+# from IPython.core.error import UsageError
+
+
+class AgentListError(Exception):
+    def __init__(self, message):
+        super().__init__(message)
+
 
 class ScenarioError(Exception):
     documentation = "https://docs.expectedparrot.com/en/latest/scenarios.html#module-edsl.scenarios.Scenario"
edsl/inference_services/AvailableModelCacheHandler.py
ADDED
@@ -0,0 +1,184 @@
+from typing import List, Optional, get_args, Union
+from pathlib import Path
+import sqlite3
+from datetime import datetime
+import tempfile
+from platformdirs import user_cache_dir
+from dataclasses import dataclass
+import os
+
+from edsl.inference_services.data_structures import LanguageModelInfo, AvailableModels
+from edsl.enums import InferenceServiceLiteral
+
+
+class AvailableModelCacheHandler:
+    MAX_ROWS = 1000
+    CACHE_VALIDITY_HOURS = 48
+
+    def __init__(
+        self,
+        cache_validity_hours: int = 48,
+        verbose: bool = False,
+        testing_db_name: str = None,
+    ):
+        self.cache_validity_hours = cache_validity_hours
+        self.verbose = verbose
+
+        if testing_db_name:
+            self.cache_dir = Path(tempfile.mkdtemp())
+            self.db_path = self.cache_dir / testing_db_name
+        else:
+            self.cache_dir = Path(user_cache_dir("edsl", "model_availability"))
+            self.db_path = self.cache_dir / "available_models.db"
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+        if os.path.exists(self.db_path):
+            if self.verbose:
+                print(f"Using existing cache DB: {self.db_path}")
+        else:
+            self._initialize_db()
+
+    @property
+    def path_to_db(self):
+        return self.db_path
+
+    def _initialize_db(self):
+        """Initialize the SQLite database with the required schema."""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            # Drop the old table if it exists (for migration)
+            cursor.execute("DROP TABLE IF EXISTS model_cache")
+            cursor.execute(
+                """
+                CREATE TABLE IF NOT EXISTS model_cache (
+                    timestamp DATETIME NOT NULL,
+                    model_name TEXT NOT NULL,
+                    service_name TEXT NOT NULL,
+                    UNIQUE(model_name, service_name)
+                )
+                """
+            )
+            conn.commit()
+
+    def _prune_old_entries(self, conn: sqlite3.Connection):
+        """Delete oldest entries when MAX_ROWS is exceeded."""
+        cursor = conn.cursor()
+        cursor.execute("SELECT COUNT(*) FROM model_cache")
+        count = cursor.fetchone()[0]
+
+        if count > self.MAX_ROWS:
+            cursor.execute(
+                """
+                DELETE FROM model_cache
+                WHERE rowid IN (
+                    SELECT rowid
+                    FROM model_cache
+                    ORDER BY timestamp ASC
+                    LIMIT ?
+                )
+                """,
+                (count - self.MAX_ROWS,),
+            )
+            conn.commit()
+
+    @classmethod
+    def example_models(cls) -> List[LanguageModelInfo]:
+        return [
+            LanguageModelInfo(
+                "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "deep_infra"
+            ),
+            LanguageModelInfo("openai/gpt-4", "openai"),
+        ]
+
+    def add_models_to_cache(self, models_data: List[LanguageModelInfo]):
+        """Add new models to the cache, updating timestamps for existing entries."""
+        current_time = datetime.now()
+
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            for model in models_data:
+                cursor.execute(
+                    """
+                    INSERT INTO model_cache (timestamp, model_name, service_name)
+                    VALUES (?, ?, ?)
+                    ON CONFLICT(model_name, service_name)
+                    DO UPDATE SET timestamp = excluded.timestamp
+                    """,
+                    (current_time, model.model_name, model.service_name),
+                )
+
+            # self._prune_old_entries(conn)
+            conn.commit()
+
+    def reset_cache(self):
+        """Clear all entries from the cache."""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute("DELETE FROM model_cache")
+            conn.commit()
+
+    @property
+    def num_cache_entries(self):
+        """Return the number of entries in the cache."""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute("SELECT COUNT(*) FROM model_cache")
+            count = cursor.fetchone()[0]
+            return count
+
+    def models(
+        self,
+        service: Optional[InferenceServiceLiteral],
+    ) -> Union[None, AvailableModels]:
+        """Return the available models within the cache validity period."""
+        # if service is not None:
+        #     assert service in get_args(InferenceServiceLiteral)
+
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            valid_time = datetime.now().timestamp() - (self.cache_validity_hours * 3600)
+
+            if self.verbose:
+                print(f"Fetching all with timestamp greater than {valid_time}")
+
+            cursor.execute(
+                """
+                SELECT DISTINCT model_name, service_name
+                FROM model_cache
+                WHERE timestamp > ?
+                ORDER BY timestamp DESC
+                """,
+                (valid_time,),
+            )
+
+            results = cursor.fetchall()
+            if not results:
+                if self.verbose:
+                    print("No results found in cache DB.")
+                return None
+
+            matching_models = [
+                LanguageModelInfo(model_name=row[0], service_name=row[1])
+                for row in results
+            ]
+
+            if self.verbose:
+                print(f"Found {len(matching_models)} models in cache DB.")
+            if service:
+                matching_models = [
+                    model for model in matching_models if model.service_name == service
+                ]
+
+            return AvailableModels(matching_models)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()
+    # cache_handler = AvailableModelCacheHandler(verbose=True)
+    # models_data = cache_handler.example_models()
+    # cache_handler.add_models_to_cache(models_data)
+    # print(cache_handler.models())
+    # cache_handler.clear_cache()
+    # print(cache_handler.models())
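Note: the snippet below is an editor-added usage sketch, not part of the diff. It exercises the new SQLite-backed cache handler end to end, using the testing_db_name path so the example writes to a throwaway temp directory instead of the real user cache; it assumes edsl 0.1.39.dev2 is installed.

from edsl.inference_services.AvailableModelCacheHandler import (
    AvailableModelCacheHandler,
)

# testing_db_name routes the DB into a fresh tempfile.mkdtemp() directory
handler = AvailableModelCacheHandler(verbose=True, testing_db_name="models_test.db")

# Seed the cache with the class's own example entries (deep_infra + openai)
handler.add_models_to_cache(AvailableModelCacheHandler.example_models())
print(handler.num_cache_entries)  # expect 2

# Filter cached entries by service name
print(handler.models(service="openai"))

handler.reset_cache()
print(handler.models(service=None))  # None once the cache is empty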
edsl/inference_services/AvailableModelFetcher.py
ADDED
@@ -0,0 +1,209 @@
+from typing import Any, List, Tuple, Optional, Dict, TYPE_CHECKING, Union, Generator
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from collections import UserList
+
+from edsl.inference_services.ServiceAvailability import ServiceAvailability
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.inference_services.data_structures import ModelNamesList
+from edsl.enums import InferenceServiceLiteral
+
+from edsl.inference_services.data_structures import LanguageModelInfo
+from edsl.inference_services.AvailableModelCacheHandler import (
+    AvailableModelCacheHandler,
+)
+
+
+from edsl.inference_services.data_structures import AvailableModels
+
+
+class AvailableModelFetcher:
+    """Fetches available models from the various services with JSON caching."""
+
+    service_availability = ServiceAvailability()
+    CACHE_VALIDITY_HOURS = 48  # Cache validity period in hours
+
+    def __init__(
+        self,
+        services: List["InferenceServiceABC"],
+        added_models: Dict[str, List[str]],
+        verbose: bool = False,
+        use_cache: bool = True,
+    ):
+        self.services = services
+        self.added_models = added_models
+        self._service_map = {
+            service._inference_service_: service for service in services
+        }
+        self.verbose = verbose
+        if use_cache:
+            self.cache_handler = AvailableModelCacheHandler()
+        else:
+            self.cache_handler = None
+
+    @property
+    def num_cache_entries(self):
+        return self.cache_handler.num_cache_entries
+
+    @property
+    def path_to_db(self):
+        return self.cache_handler.path_to_db
+
+    def reset_cache(self):
+        if self.cache_handler:
+            self.cache_handler.reset_cache()
+
+    def available(
+        self,
+        service: Optional[InferenceServiceABC] = None,
+        force_refresh: bool = False,
+    ) -> List[LanguageModelInfo]:
+        """
+        Get available models from all services, using cached data when available.
+
+        :param service: Optional[InferenceServiceABC] - If specified, only fetch models for this service.
+
+        >>> from edsl.inference_services.OpenAIService import OpenAIService
+        >>> af = AvailableModelFetcher([OpenAIService()], {})
+        >>> af.available(service="openai")
+        [LanguageModelInfo(model_name='...', service_name='openai'), ...]
+
+        Returns a list of [model, service_name, index] entries.
+        """
+
+        if service:  # they passed a specific service
+            matching_models, _ = self.get_available_models_by_service(
+                service=service, force_refresh=force_refresh
+            )
+            return matching_models
+
+        # Nope, we need to fetch them all
+        all_models = self._get_all_models()
+
+        # if self.cache_handler:
+        #     self.cache_handler.add_models_to_cache(all_models)
+
+        return all_models
+
+    def get_available_models_by_service(
+        self,
+        service: Union["InferenceServiceABC", InferenceServiceLiteral],
+        force_refresh: bool = False,
+    ) -> Tuple[AvailableModels, InferenceServiceLiteral]:
+        """Get models for a single service.
+
+        :param service: InferenceServiceABC - e.g., OpenAIService or "openai"
+        :return: Tuple[List[LanguageModelInfo], InferenceServiceLiteral]
+        """
+        if isinstance(service, str):
+            service = self._fetch_service_by_service_name(service)
+
+        if not force_refresh:
+            models_from_cache = self.cache_handler.models(
+                service=service._inference_service_
+            )
+            if self.verbose:
+                print(
+                    "Searching cache for models with service name:",
+                    service._inference_service_,
+                )
+                print("Got models from cache:", models_from_cache)
+        else:
+            models_from_cache = None
+
+        if models_from_cache:
+            # print(f"Models from cache for {service}: {models_from_cache}")
+            # print(hasattr(models_from_cache[0], "service_name"))
+            return models_from_cache, service._inference_service_
+        else:
+            return self.get_available_models_by_service_fresh(service)
+
+    def get_available_models_by_service_fresh(
+        self, service: Union["InferenceServiceABC", InferenceServiceLiteral]
+    ) -> Tuple[AvailableModels, InferenceServiceLiteral]:
+        """Get models for a single service. This method always fetches fresh data.
+
+        :param service: InferenceServiceABC - e.g., OpenAIService or "openai"
+        :return: Tuple[List[LanguageModelInfo], InferenceServiceLiteral]
+        """
+        if isinstance(service, str):
+            service = self._fetch_service_by_service_name(service)
+
+        service_models: ModelNamesList = (
+            self.service_availability.get_service_available(service, warn=False)
+        )
+        service_name = service._inference_service_
+
+        models_list = AvailableModels(
+            [
+                LanguageModelInfo(
+                    model_name=model_name,
+                    service_name=service_name,
+                )
+                for model_name in service_models
+            ]
+        )
+        self.cache_handler.add_models_to_cache(models_list)  # update the cache
+        return models_list, service_name
+
+    def _fetch_service_by_service_name(
+        self, service_name: InferenceServiceLiteral
+    ) -> "InferenceServiceABC":
+        """The service name is the _inference_service_ attribute of the service."""
+        if service_name in self._service_map:
+            return self._service_map[service_name]
+        raise ValueError(f"Service {service_name} not found")
+
+    def _get_all_models(self, force_refresh=False) -> List[LanguageModelInfo]:
+        all_models = []
+        with ThreadPoolExecutor(max_workers=min(len(self.services), 10)) as executor:
+            future_to_service = {
+                executor.submit(
+                    self.get_available_models_by_service, service, force_refresh
+                ): service
+                for service in self.services
+            }
+
+            for future in as_completed(future_to_service):
+                try:
+                    models, service_name = future.result()
+                    all_models.extend(models)
+
+                    # Add any additional models for this service
+                    for model in self.added_models.get(service_name, []):
+                        all_models.append(
+                            LanguageModelInfo(
+                                model_name=model, service_name=service_name
+                            )
+                        )
+
+                except Exception as exc:
+                    print(f"Service query failed: {exc}")
+                    continue
+
+        return AvailableModels(all_models)
+
+
+def main():
+    from edsl.inference_services.OpenAIService import OpenAIService
+
+    af = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
+    # print(af.available(service="openai"))
+    all_models = AvailableModelFetcher([OpenAIService()], {})._get_all_models(
+        force_refresh=True
+    )
+    print(all_models)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
+    # main()
+
+    # from edsl.inference_services.OpenAIService import OpenAIService
+
+    # af = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
+    # # print(af.available(service="openai"))
+
+    # all_models = AvailableModelFetcher([OpenAIService()], {})._get_all_models()
+    # print(all_models)
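Note: the snippet below is an editor-added usage sketch, not part of the diff; it mirrors the module's own main(). A live fetch requires valid service credentials (e.g. an OpenAI API key) and assumes edsl 0.1.39.dev2 is installed.

from edsl.inference_services.OpenAIService import OpenAIService
from edsl.inference_services.AvailableModelFetcher import AvailableModelFetcher

fetcher = AvailableModelFetcher([OpenAIService()], added_models={}, verbose=True)

# The first lookup hits the service and writes through to the SQLite cache;
# repeat lookups within CACHE_VALIDITY_HOURS are served from the cache.
models = fetcher.available(service="openai")

# force_refresh=True bypasses the cache and re-queries the service.
fresh_models, service_name = fetcher.get_available_models_by_service(
    "openai", force_refresh=True
)

print(fetcher.num_cache_entries, fetcher.path_to_db)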
edsl/inference_services/GoogleService.py
CHANGED
@@ -1,11 +1,11 @@
-import os
+# import os
 from typing import Any, Dict, List, Optional
 import google
 import google.generativeai as genai
 from google.generativeai.types import GenerationConfig
 from google.api_core.exceptions import InvalidArgument
 
-from edsl.exceptions import MissingAPIKeyError
+# from edsl.exceptions.general import MissingAPIKeyError
 from edsl.language_models.LanguageModel import LanguageModel
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.coop import Coop
@@ -39,10 +39,6 @@ class GoogleService(InferenceServiceABC):
 
     model_exclude_list = []
 
-    # @classmethod
-    # def available(cls) -> List[str]:
-    #     return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
-
    @classmethod
     def available(cls) -> List[str]:
         model_list = []
@@ -66,9 +62,6 @@ class GoogleService(InferenceServiceABC):
         output_token_name = cls.output_token_name
         _inference_service_ = cls._inference_service_
 
-        _tpm = cls.get_tpm(cls)
-        _rpm = cls.get_rpm(cls)
-
         _parameters_ = {
             "temperature": 0.5,
             "topP": 1,
@@ -77,7 +70,6 @@ class GoogleService(InferenceServiceABC):
             "stopSequences": [],
         }
 
-        api_token = None
         model = None
 
     def __init__(self, *args, **kwargs):
@@ -102,7 +94,6 @@ class GoogleService(InferenceServiceABC):
 
         if files_list is None:
             files_list = []
-
         genai.configure(api_key=self.api_token)
         if (
             system_prompt is not None