edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +28 -0
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -16
- edsl/agents/Invigilator.py +13 -14
- edsl/agents/InvigilatorBase.py +4 -1
- edsl/agents/PromptConstructor.py +42 -22
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/coop/coop.py +21 -5
- edsl/data/Cache.py +29 -18
- edsl/data/CacheHandler.py +0 -2
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/hack.py +10 -0
- edsl/enums.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -16
- edsl/inference_services/AvailableModelFetcher.py +7 -1
- edsl/inference_services/GoogleService.py +5 -1
- edsl/inference_services/InferenceServicesCollection.py +18 -2
- edsl/inference_services/OpenAIService.py +46 -31
- edsl/inference_services/TestService.py +1 -3
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +74 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
- edsl/jobs/FetchInvigilator.py +10 -3
- edsl/jobs/InterviewsConstructor.py +6 -4
- edsl/jobs/Jobs.py +299 -233
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/interviews/Interview.py +80 -42
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
- edsl/jobs/runners/JobsRunnerStatus.py +131 -164
- edsl/jobs/tasks/TaskHistory.py +24 -3
- edsl/language_models/LanguageModel.py +59 -4
- edsl/language_models/ModelList.py +19 -8
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +35 -26
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +9 -15
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/questions/data_structures.py +20 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
- edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
- edsl/results/DatasetExportMixin.py +60 -119
- edsl/results/Result.py +109 -3
- edsl/results/Results.py +50 -39
- edsl/results/file_exports.py +252 -0
- edsl/scenarios/ScenarioList.py +35 -7
- edsl/surveys/Survey.py +71 -20
- edsl/test_h +1 -0
- edsl/utilities/gcp_bucket/example.py +50 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
- edsl/language_models/registry.py +0 -180
- /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
- /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
- /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{Selector.py → results_selector.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
- /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
- /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/data/Cache.py
CHANGED
@@ -6,11 +6,9 @@ from __future__ import annotations
 import json
 import os
 import warnings
-from typing import Optional, Union
+from typing import Optional, Union, TYPE_CHECKING
 from edsl.Base import Base
 
-
-# from edsl.utilities.decorators import remove_edsl_version
 from edsl.utilities.remove_edsl_version import remove_edsl_version
 from edsl.exceptions.cache import CacheError
 
@@ -83,10 +81,6 @@ class Cache(Base):
 
         self._perform_checks()
 
-    # def rich_print(sefl):
-    #     pass
-    #     # raise NotImplementedError("This method is not implemented yet.")
-
     def code(sefl):
         pass
         # raise NotImplementedError("This method is not implemented yet.")
@@ -293,8 +287,8 @@ class Cache(Base):
 
         CACHE_PATH = CONFIG.get("EDSL_DATABASE_PATH")
         path = CACHE_PATH.replace("sqlite:///", "")
-        db_path = os.path.join(os.path.dirname(path), "data.db")
-        return cls.from_sqlite_db(db_path)
+        # db_path = os.path.join(os.path.dirname(path), "data.db")
+        return cls.from_sqlite_db(path)
 
     @classmethod
     def from_jsonl(cls, jsonlfile: str, db_path: Optional[str] = None) -> Cache:
@@ -368,12 +362,32 @@ class Cache(Base):
 
             scenarios.append(s)
         return ScenarioList(scenarios)
 
-
-
-
-
-
-
+    def __floordiv__(self, other: "Cache") -> "Cache":
+        """
+        Return a new Cache containing entries that are in self but not in other.
+        Uses // operator as alternative to subtraction.
+
+        :param other: Another Cache object to compare against
+        :return: A new Cache object containing unique entries
+
+        >>> from edsl.data.CacheEntry import CacheEntry
+        >>> ce1 = CacheEntry.example(randomize = True)
+        >>> ce2 = CacheEntry.example(randomize = True)
+        >>> ce2 = CacheEntry.example(randomize = True)
+        >>> c1 = Cache(data={ce1.key: ce1, ce2.key: ce2})
+        >>> c2 = Cache(data={ce1.key: ce1})
+        >>> c3 = c1 // c2
+        >>> len(c3)
+        1
+        >>> c3.data[ce2.key] == ce2
+        True
+        """
+        if not isinstance(other, Cache):
+            raise CacheError("Can only compare two caches")
+
+        diff_data = {k: v for k, v in self.data.items() if k not in other.data}
+        return Cache(data=diff_data, immediate_write=self.immediate_write)
+
     @classmethod
     def from_url(cls, db_path=None) -> Cache:
         """
@@ -399,9 +413,6 @@ class Cache(Base):
 
         if self.filename:
             self.write(self.filename)
 
-    ####################
-    # DUNDER / USEFUL
-    ####################
     def __hash__(self):
         """Return the hash of the Cache."""
         from edsl.utilities.utilities import dict_hash
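The new __floordiv__ gives Cache a set-difference operator. A minimal usage sketch based on the doctest above (CacheEntry.example and the // operator come from the diff; variable names here are illustrative):

from edsl.data.Cache import Cache
from edsl.data.CacheEntry import CacheEntry

ce1 = CacheEntry.example(randomize=True)
ce2 = CacheEntry.example(randomize=True)
local = Cache(data={ce1.key: ce1, ce2.key: ce2})
already_synced = Cache(data={ce1.key: ce1})

pending = local // already_synced  # entries in local but not in already_synced
assert len(pending) == 1 and ce2.key in pending.data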
edsl/data/CacheHandler.py
CHANGED
edsl/data/RemoteCacheSync.py
CHANGED
@@ -1,71 +1,166 @@
-
+from typing import List, Dict, Any, Optional, TYPE_CHECKING, Callable
+from dataclasses import dataclass
+from contextlib import AbstractContextManager
+from collections import UserList
+
+if TYPE_CHECKING:
+    from .Cache import Cache
+    from edsl.coop.coop import Coop
+    from .CacheEntry import CacheEntry
+
+from logging import Logger
+
+
+class CacheKeyList(UserList):
+    def __init__(self, data: List[str]):
+        super().__init__(data)
+        self.data = data
+
+    def __repr__(self):
+        import reprlib
+
+        keys_repr = reprlib.repr(self.data)
+        return f"CacheKeyList({keys_repr})"
+
+
+class CacheEntriesList(UserList):
+    def __init__(self, data: List["CacheEntry"]):
+        super().__init__(data)
+        self.data = data
+
+    def __repr__(self):
+        import reprlib
+
+        entries_repr = reprlib.repr(self.data)
+        return f"CacheEntries({entries_repr})"
+
+    def to_cache(self) -> "Cache":
+        from edsl.data.Cache import Cache
+
+        return Cache({entry.key: entry for entry in self.data})
+
+
+@dataclass
+class CacheDifference:
+    client_missing_entries: CacheEntriesList
+    server_missing_keys: List[str]
+
+    def __repr__(self):
+        """Returns a string representation of the CacheDifference object."""
+        import reprlib
+
+        missing_entries_repr = reprlib.repr(self.client_missing_entries)
+        missing_keys_repr = reprlib.repr(self.server_missing_keys)
+        return f"CacheDifference(client_missing_entries={missing_entries_repr}, server_missing_keys={missing_keys_repr})"
+
+
+class RemoteCacheSync(AbstractContextManager):
+    """Synchronizes a local cache with a remote cache.
+
+    Handles bidirectional synchronization:
+    - Downloads missing entries from remote to local cache
+    - Uploads new local entries to remote cache
+    """
+
     def __init__(
-        self,
+        self,
+        coop: "Coop",
+        cache: "Cache",
+        output_func: Callable,
+        remote_cache: bool = True,
+        remote_cache_description: str = "",
     ):
+        """
+        Initializes a RemoteCacheSync object.
+
+        :param coop: Coop object for interacting with the remote cache
+        :param cache: Cache object for local cache
+        :param output_func: Function for outputting messages
+        :param remote_cache: Whether to enable remote cache synchronization
+        :param remote_cache_description: Description for remote cache entries
+
+        """
         self.coop = coop
         self.cache = cache
         self._output = output_func
-        self.remote_cache = remote_cache
-        self.old_entry_keys = []
-        self.new_cache_entries = []
+        self.remote_cache_enabled = remote_cache
         self.remote_cache_description = remote_cache_description
+        self.initial_cache_keys = []
 
-    def __enter__(self):
-        if self.remote_cache:
+    def __enter__(self) -> "RemoteCacheSync":
+        if self.remote_cache_enabled:
             self._sync_from_remote()
-        self.old_entry_keys = self.cache.keys()
+        self.initial_cache_keys = list(self.cache.keys())
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
-        if self.remote_cache:
+        if self.remote_cache_enabled:
             self._sync_to_remote()
         return False  # Propagate exceptions
 
-    def _sync_from_remote(self):
-
-
-
+    def _get_cache_difference(self) -> CacheDifference:
+        """Retrieves differences between local and remote caches."""
+        diff = self.coop.remote_cache_get_diff(self.cache.keys())
+        return CacheDifference(
+            client_missing_entries=diff.get("client_missing_cacheentries", []),
+            server_missing_keys=diff.get("server_missing_cacheentry_keys", []),
         )
-        missing_entry_count = len(client_missing_cacheentries)
 
-
-
-
-
-
-
-            {entry.key: entry for entry in client_missing_cacheentries}
-        )
-            self._output("Local cache updated!")
-        else:
+    def _sync_from_remote(self) -> None:
+        """Downloads missing entries from remote cache to local cache."""
+        diff: CacheDifference = self._get_cache_difference()
+        missing_count = len(diff.client_missing_entries)
+
+        if missing_count == 0:
             self._output("No new entries to add to local cache.")
+            return
 
-
-
-
-                "server_missing_cacheentry_keys", []
+        self._output(
+            f"Updating local cache with {missing_count:,} new "
+            f"{'entry' if missing_count == 1 else 'entries'} from remote..."
         )
-
-
-        for
-
-
-
-
-
-
-
-
-
-
-
+
+        self.cache.add_from_dict(
+            {entry.key: entry for entry in diff.client_missing_entries}
+        )
+        self._output("Local cache updated!")
+
+    def _get_entries_to_upload(self, diff: CacheDifference) -> CacheEntriesList:
+        """Determines which entries need to be uploaded to remote cache."""
+        # Get entries for keys missing from server
+        server_missing_entries = CacheEntriesList(
+            [
+                entry
+                for key in diff.server_missing_keys
+                if (entry := self.cache.data.get(key)) is not None
+            ]
+        )
+
+        # Get newly added entries since sync started
+        new_entries = CacheEntriesList(
+            [
+                entry
+                for entry in self.cache.values()
+                if entry.key not in self.initial_cache_keys
+            ]
+        )
+
+        return server_missing_entries + new_entries
+
+    def _sync_to_remote(self) -> None:
+        """Uploads new local entries to remote cache."""
+        diff: CacheDifference = self._get_cache_difference()
+        entries_to_upload: CacheEntriesList = self._get_entries_to_upload(diff)
+        upload_count = len(entries_to_upload)
+
+        if upload_count > 0:
             self._output(
-                f"Updating remote cache with {
-                f"{'entry' if 
+                f"Updating remote cache with {upload_count:,} new "
+                f"{'entry' if upload_count == 1 else 'entries'}..."
             )
+
             self.coop.remote_cache_create_many(
-
+                entries_to_upload,
                 visibility="private",
                 description=self.remote_cache_description,
             )
@@ -76,3 +171,16 @@ class RemoteCacheSync:
         self._output(
             f"There are {len(self.cache.keys()):,} entries in the local cache."
         )
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()
+
+    from edsl.coop.coop import Coop
+    from edsl.data.Cache import Cache
+    from edsl.data.CacheEntry import CacheEntry
+
+    r = RemoteCacheSync(Coop(), Cache(), print)
+    diff = r._get_cache_difference()
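RemoteCacheSync is meant to wrap a block of work as a context manager: missing entries are downloaded on __enter__, and entries created inside the block are uploaded on __exit__. A sketch of the intended call pattern, following the __main__ block above (assumes Coop credentials are already configured):

from edsl.coop.coop import Coop
from edsl.data.Cache import Cache

cache = Cache()
with RemoteCacheSync(Coop(), cache, print, remote_cache_description="my run"):
    # work that reads from / writes to `cache` goes here;
    # new local entries are pushed to the remote cache on exit
    pass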
edsl/data/hack.py
ADDED
edsl/enums.py
CHANGED
@@ -86,6 +86,13 @@ InferenceServiceLiteral = Literal[
     "perplexity",
 ]
 
+available_models_urls = {
+    "anthropic": "https://docs.anthropic.com/en/docs/about-claude/models",
+    "openai": "https://platform.openai.com/docs/models/gp",
+    "groq": "https://console.groq.com/docs/models",
+    "google": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models",
+}
+
 
 service_to_api_keyname = {
     InferenceServiceType.BEDROCK.value: "TBD",
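A small lookup sketch for the new available_models_urls mapping (the service variable is illustrative):

service = "anthropic"
url = available_models_urls.get(service)
if url:
    print(f"Current {service} models are documented at {url}")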
edsl/inference_services/AnthropicService.py
CHANGED
@@ -11,21 +11,27 @@ class AnthropicService(InferenceServiceABC):
 
     _inference_service_ = "anthropic"
     _env_key_name_ = "ANTHROPIC_API_KEY"
-    key_sequence = ["content", 0, "text"]
+    key_sequence = ["content", 0, "text"]
     usage_sequence = ["usage"]
     input_token_name = "input_tokens"
    output_token_name = "output_tokens"
     model_exclude_list = []
 
+    @classmethod
+    def get_model_list(cls, api_key: str = None):
+
+        import requests
+
+        if api_key is None:
+            api_key = os.environ.get("ANTHROPIC_API_KEY")
+        headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
+        response = requests.get("https://api.anthropic.com/v1/models", headers=headers)
+        model_names = [m["id"] for m in response.json()["data"]]
+        return model_names
+
     @classmethod
     def available(cls):
-
-        return [
-            "claude-3-5-sonnet-20240620",
-            "claude-3-opus-20240229",
-            "claude-3-sonnet-20240229",
-            "claude-3-haiku-20240307",
-        ]
+        return cls.get_model_list()
 
     @classmethod
     def create_model(
@@ -62,20 +68,36 @@ class AnthropicService(InferenceServiceABC):
         system_prompt: str = "",
         files_list: Optional[List["Files"]] = None,
     ) -> dict[str, Any]:
-        """Calls the
+        """Calls the Anthropic API and returns the API response."""
 
-
-        client = AsyncAnthropic(api_key=self.api_token)
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": user_prompt}],
+            }
+        ]
+        if files_list:
+            for file_entry in files_list:
+                encoded_image = file_entry.base64_string
+                messages[0]["content"].append(
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": file_entry.mime_type,
+                            "data": encoded_image,
+                        },
+                    }
+                )
+        # breakpoint()
+        client = AsyncAnthropic(api_key=self.api_token)
 
         response = await client.messages.create(
             model=model_name,
             max_tokens=self.max_tokens,
             temperature=self.temperature,
-            system=system_prompt,
-            messages=[
-                # {"role": "system", "content": system_prompt},
-                {"role": "user", "content": user_prompt},
-            ],
+            system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
+            messages=messages,
         )
         return response.model_dump()
 
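The dynamic model list replaces the hard-coded Claude list above. The listing endpoint can also be exercised directly; a standalone sketch matching the headers in the diff (assumes ANTHROPIC_API_KEY is set in the environment):

import os
import requests

headers = {
    "x-api-key": os.environ["ANTHROPIC_API_KEY"],
    "anthropic-version": "2023-06-01",
}
response = requests.get("https://api.anthropic.com/v1/models", headers=headers)
print([m["id"] for m in response.json()["data"]])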
edsl/inference_services/AvailableModelFetcher.py
CHANGED
@@ -133,6 +133,12 @@ class AvailableModelFetcher:
         )
         service_name = service._inference_service_
 
+        if not service_models:
+            import warnings
+
+            warnings.warn(f"No models found for service {service_name}")
+            return [], service_name
+
         models_list = AvailableModels(
             [
                 LanguageModelInfo(
@@ -177,7 +183,7 @@ class AvailableModelFetcher:
             )
 
         except Exception as exc:
-            print(f"Service query failed: {exc}")
+            print(f"Service query failed for service {service_name}: {exc}")
             continue
 
         return AvailableModels(all_models)
edsl/inference_services/GoogleService.py
CHANGED
@@ -40,13 +40,17 @@ class GoogleService(InferenceServiceABC):
     model_exclude_list = []
 
     @classmethod
-    def available(cls):
+    def get_model_list(cls):
         model_list = []
         for m in genai.list_models():
             if "generateContent" in m.supported_generation_methods:
                 model_list.append(m.name.split("/")[-1])
         return model_list
 
+    @classmethod
+    def available(cls) -> List[str]:
+        return cls.get_model_list()
+
     @classmethod
     def create_model(
         cls, model_name: str = "gemini-pro", model_class_name=None
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -71,7 +71,12 @@ class ModelResolver:
             self._models_to_services[model_name] = service
             return service
 
-        raise InferenceServiceError(f"Model {model_name} not found in any services")
+        raise InferenceServiceError(
+            f"""Model {model_name} not found in any services.
+            If you know the service that has this model, use the service_name parameter directly.
+            E.g., Model("gpt-4o", service_name="openai")
+            """
+        )
 
 
 class InferenceServicesCollection:
@@ -93,6 +98,9 @@ class InferenceServicesCollection:
         if service_name not in cls.added_models:
             cls.added_models[service_name].append(model_name)
 
+    def service_names_to_classes(self) -> Dict[str, InferenceServiceABC]:
+        return {service._inference_service_: service for service in self.services}
+
     def available(
         self,
         service: Optional[str] = None,
@@ -112,7 +120,15 @@ class InferenceServicesCollection:
     def create_model_factory(
         self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
    ) -> "LanguageModel":
-        service = self.resolver.resolve_model(model_name, service_name)
+
+        if service_name is None:  # we try to find the right service
+            service = self.resolver.resolve_model(model_name, service_name)
+        else:  # if they passed a service, we'll use that
+            service = self.service_names_to_classes().get(service_name)
+
+        if not service:  # but if we can't find it, we'll raise an error
+            raise InferenceServiceError(f"Service {service_name} not found")
+
         return service.create_model(model_name)
 
 
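The new error text points at the second resolution path: create_model_factory now skips the resolver entirely when a service_name is supplied. A usage sketch (assumes the Model entry point forwards service_name to create_model_factory, as the error message implies):

from edsl import Model

m1 = Model("gpt-4o")                         # resolver searches all services
m2 = Model("gpt-4o", service_name="openai")  # direct: no model lookup needed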
edsl/inference_services/OpenAIService.py
CHANGED
@@ -1,7 +1,8 @@
 from __future__ import annotations
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Dict, NewType
 import os
 
+
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -11,6 +12,8 @@ from edsl.utilities.utilities import fix_partial_correct_response
 
 from edsl.config import CONFIG
 
+APIToken = NewType("APIToken", str)
+
 
 class OpenAIService(InferenceServiceABC):
     """OpenAI service class."""
@@ -22,35 +25,43 @@ class OpenAIService(InferenceServiceABC):
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI
 
-    _sync_client_instance = None
-    _async_client_instance = None
+    _sync_client_instances: Dict[APIToken, openai.OpenAI] = {}
+    _async_client_instances: Dict[APIToken, openai.AsyncOpenAI] = {}
 
     key_sequence = ["choices", 0, "message", "content"]
     usage_sequence = ["usage"]
     input_token_name = "prompt_tokens"
     output_token_name = "completion_tokens"
 
+    available_models_url = "https://platform.openai.com/docs/models/gp"
+
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
-        # so subclasses have to create their own instances of the clients
-        cls._sync_client_instance = None
-        cls._async_client_instance = None
+        # so subclasses that use the OpenAI api key have to create their own instances of the clients
+        cls._sync_client_instances = {}
+        cls._async_client_instances = {}
 
     @classmethod
-    def sync_client(cls):
-        if cls._sync_client_instance is None:
-            cls._sync_client_instance = cls._sync_client_(
-                api_key=os.getenv(cls._env_key_name_),
+    def sync_client(cls, api_key):
+        if api_key not in cls._sync_client_instances:
+            client = cls._sync_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
             )
-        return cls._sync_client_instance
+            cls._sync_client_instances[api_key] = client
+        client = cls._sync_client_instances[api_key]
+        return client
 
     @classmethod
-    def async_client(cls):
-        if cls._async_client_instance is None:
-            cls._async_client_instance = cls._async_client_(
-                api_key=os.getenv(cls._env_key_name_),
+    def async_client(cls, api_key):
+        if api_key not in cls._async_client_instances:
+            client = cls._async_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
            )
-        return cls._async_client_instance
+            cls._async_client_instances[api_key] = client
+        client = cls._async_client_instances[api_key]
+        return client
 
     model_exclude_list = [
         "whisper-1",
@@ -72,20 +83,24 @@ class OpenAIService(InferenceServiceABC):
     _models_list_cache: List[str] = []
 
     @classmethod
-    def get_model_list(cls):
-        raw_list = cls.sync_client().models.list()
+    def get_model_list(cls, api_key=None):
+        if api_key is None:
+            api_key = os.getenv(cls._env_key_name_)
+        raw_list = cls.sync_client(api_key).models.list()
         if hasattr(raw_list, "data"):
             return raw_list.data
         else:
             return raw_list
 
     @classmethod
-    def available(cls) -> List[str]:
+    def available(cls, api_token=None) -> List[str]:
+        if api_token is None:
+            api_token = os.getenv(cls._env_key_name_)
         if not cls._models_list_cache:
             try:
                 cls._models_list_cache = [
                     m.id
-                    for m in cls.get_model_list()
+                    for m in cls.get_model_list(api_key=api_token)
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
@@ -120,10 +135,10 @@ class OpenAIService(InferenceServiceABC):
         }
 
     def sync_client(self):
-        return cls.sync_client()
+        return cls.sync_client(api_key=self.api_token)
 
     def async_client(self):
-        return cls.async_client()
+        return cls.async_client(api_key=self.api_token)
 
     @classmethod
     def available(cls) -> list[str]:
@@ -172,16 +187,16 @@ class OpenAIService(InferenceServiceABC):
     ) -> dict[str, Any]:
         """Calls the OpenAI API and returns the API response."""
         if files_list:
-            encoded_image = files_list[0].base64_string
             content = [{"type": "text", "text": user_prompt}]
-            content.append(
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/jpeg;base64,{encoded_image}"
-                    },
-                }
-            )
+            for file_entry in files_list:
+                content.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:{file_entry.mime_type};base64,{file_entry.base64_string}"
+                        },
+                    }
+                )
         else:
             content = user_prompt
         client = self.async_client()