edsl 0.1.39__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- edsl/Base.py +0 -28
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +17 -9
- edsl/agents/Invigilator.py +14 -13
- edsl/agents/InvigilatorBase.py +1 -4
- edsl/agents/PromptConstructor.py +22 -42
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/coop/coop.py +5 -21
- edsl/data/Cache.py +18 -29
- edsl/data/CacheHandler.py +2 -0
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/enums.py +0 -7
- edsl/inference_services/AnthropicService.py +16 -38
- edsl/inference_services/AvailableModelFetcher.py +1 -7
- edsl/inference_services/GoogleService.py +1 -5
- edsl/inference_services/InferenceServicesCollection.py +2 -18
- edsl/inference_services/OpenAIService.py +31 -46
- edsl/inference_services/TestService.py +3 -1
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/inference_services/data_structures.py +2 -74
- edsl/jobs/AnswerQuestionFunctionConstructor.py +113 -148
- edsl/jobs/FetchInvigilator.py +3 -10
- edsl/jobs/InterviewsConstructor.py +4 -6
- edsl/jobs/Jobs.py +233 -299
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +136 -160
- edsl/jobs/interviews/Interview.py +42 -80
- edsl/jobs/runners/JobsRunnerAsyncio.py +358 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/TaskHistory.py +3 -24
- edsl/language_models/LanguageModel.py +4 -59
- edsl/language_models/ModelList.py +8 -19
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/registry.py +180 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +26 -35
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +49 -52
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +15 -9
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +18 -6
- edsl/questions/{response_validator_factory.py → ResponseValidatorFactory.py} +1 -7
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/results/DatasetExportMixin.py +119 -60
- edsl/results/Result.py +3 -109
- edsl/results/Results.py +39 -50
- edsl/scenarios/FileStore.py +0 -32
- edsl/scenarios/ScenarioList.py +7 -35
- edsl/scenarios/handlers/csv.py +0 -11
- edsl/surveys/Survey.py +20 -71
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +1 -1
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/RECORD +78 -84
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +1 -1
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/model.py +0 -256
- edsl/questions/data_structures.py +0 -20
- edsl/results/file_exports.py +0 -252
- /edsl/agents/{question_option_processor.py → QuestionOptionProcessor.py} +0 -0
- /edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +0 -0
- /edsl/questions/{loop_processor.py → LoopProcessor.py} +0 -0
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- /edsl/results/{results_selector.py → Selector.py} +0 -0
- /edsl/scenarios/{directory_scanner.py → DirectoryScanner.py} +0 -0
- /edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +0 -0
- /edsl/scenarios/{scenario_selector.py → ScenarioSelector.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
edsl/data/Cache.py
CHANGED
@@ -6,9 +6,11 @@ from __future__ import annotations
 import json
 import os
 import warnings
-from typing import Optional, Union
+from typing import Optional, Union
 from edsl.Base import Base
 
+
+# from edsl.utilities.decorators import remove_edsl_version
 from edsl.utilities.remove_edsl_version import remove_edsl_version
 from edsl.exceptions.cache import CacheError
 
@@ -81,6 +83,10 @@ class Cache(Base):
 
         self._perform_checks()
 
+    # def rich_print(sefl):
+    #     pass
+    #     # raise NotImplementedError("This method is not implemented yet.")
+
     def code(sefl):
         pass
         # raise NotImplementedError("This method is not implemented yet.")
@@ -287,8 +293,8 @@ class Cache(Base):
 
         CACHE_PATH = CONFIG.get("EDSL_DATABASE_PATH")
         path = CACHE_PATH.replace("sqlite:///", "")
-
-        return cls.from_sqlite_db(
+        db_path = os.path.join(os.path.dirname(path), "data.db")
+        return cls.from_sqlite_db(db_path=db_path)
 
     @classmethod
     def from_jsonl(cls, jsonlfile: str, db_path: Optional[str] = None) -> Cache:
@@ -362,32 +368,12 @@ class Cache(Base):
             scenarios.append(s)
         return ScenarioList(scenarios)
 
-
-
-
-
-
-
-        :return: A new Cache object containing unique entries
-
-        >>> from edsl.data.CacheEntry import CacheEntry
-        >>> ce1 = CacheEntry.example(randomize = True)
-        >>> ce2 = CacheEntry.example(randomize = True)
-        >>> ce2 = CacheEntry.example(randomize = True)
-        >>> c1 = Cache(data={ce1.key: ce1, ce2.key: ce2})
-        >>> c2 = Cache(data={ce1.key: ce1})
-        >>> c3 = c1 // c2
-        >>> len(c3)
-        1
-        >>> c3.data[ce2.key] == ce2
-        True
-        """
-        if not isinstance(other, Cache):
-            raise CacheError("Can only compare two caches")
-
-        diff_data = {k: v for k, v in self.data.items() if k not in other.data}
-        return Cache(data=diff_data, immediate_write=self.immediate_write)
-
+    ####################
+    # REMOTE
+    ####################
+    # TODO: Make this work
+    # - Need to decide whether the cache belongs to a user and what can be shared
+    # - I.e., some cache entries? all or nothing?
     @classmethod
     def from_url(cls, db_path=None) -> Cache:
         """
@@ -413,6 +399,9 @@ class Cache(Base):
         if self.filename:
             self.write(self.filename)
 
+    ####################
+    # DUNDER / USEFUL
+    ####################
     def __hash__(self):
         """Return the hash of the Cache."""
         from edsl.utilities.utilities import dict_hash
edsl/data/CacheHandler.py
CHANGED
edsl/data/RemoteCacheSync.py
CHANGED
@@ -1,166 +1,71 @@
-
-from dataclasses import dataclass
-from contextlib import AbstractContextManager
-from collections import UserList
-
-if TYPE_CHECKING:
-    from .Cache import Cache
-    from edsl.coop.coop import Coop
-    from .CacheEntry import CacheEntry
-
-    from logging import Logger
-
-
-class CacheKeyList(UserList):
-    def __init__(self, data: List[str]):
-        super().__init__(data)
-        self.data = data
-
-    def __repr__(self):
-        import reprlib
-
-        keys_repr = reprlib.repr(self.data)
-        return f"CacheKeyList({keys_repr})"
-
-
-class CacheEntriesList(UserList):
-    def __init__(self, data: List["CacheEntry"]):
-        super().__init__(data)
-        self.data = data
-
-    def __repr__(self):
-        import reprlib
-
-        entries_repr = reprlib.repr(self.data)
-        return f"CacheEntries({entries_repr})"
-
-    def to_cache(self) -> "Cache":
-        from edsl.data.Cache import Cache
-
-        return Cache({entry.key: entry for entry in self.data})
-
-
-@dataclass
-class CacheDifference:
-    client_missing_entries: CacheEntriesList
-    server_missing_keys: List[str]
-
-    def __repr__(self):
-        """Returns a string representation of the CacheDifference object."""
-        import reprlib
-
-        missing_entries_repr = reprlib.repr(self.client_missing_entries)
-        missing_keys_repr = reprlib.repr(self.server_missing_keys)
-        return f"CacheDifference(client_missing_entries={missing_entries_repr}, server_missing_keys={missing_keys_repr})"
-
-
-class RemoteCacheSync(AbstractContextManager):
-    """Synchronizes a local cache with a remote cache.
-
-    Handles bidirectional synchronization:
-    - Downloads missing entries from remote to local cache
-    - Uploads new local entries to remote cache
-    """
-
+class RemoteCacheSync:
     def __init__(
-        self,
-        coop: "Coop",
-        cache: "Cache",
-        output_func: Callable,
-        remote_cache: bool = True,
-        remote_cache_description: str = "",
+        self, coop, cache, output_func, remote_cache=True, remote_cache_description=""
     ):
-        """
-        Initializes a RemoteCacheSync object.
-
-        :param coop: Coop object for interacting with the remote cache
-        :param cache: Cache object for local cache
-        :param output_func: Function for outputting messages
-        :param remote_cache: Whether to enable remote cache synchronization
-        :param remote_cache_description: Description for remote cache entries
-
-        """
         self.coop = coop
         self.cache = cache
         self._output = output_func
-        self.
+        self.remote_cache = remote_cache
+        self.old_entry_keys = []
+        self.new_cache_entries = []
         self.remote_cache_description = remote_cache_description
-        self.initial_cache_keys = []
 
-    def __enter__(self)
-        if self.
+    def __enter__(self):
+        if self.remote_cache:
             self._sync_from_remote()
-        self.
+        self.old_entry_keys = list(self.cache.keys())
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
-        if self.
+        if self.remote_cache:
             self._sync_to_remote()
         return False  # Propagate exceptions
 
-    def
-
-
-
-            client_missing_entries=diff.get("client_missing_cacheentries", []),
-            server_missing_keys=diff.get("server_missing_cacheentry_keys", []),
+    def _sync_from_remote(self):
+        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
+        client_missing_cacheentries = cache_difference.get(
+            "client_missing_cacheentries", []
         )
+        missing_entry_count = len(client_missing_cacheentries)
 
-
-
-
-
-
-
+        if missing_entry_count > 0:
+            self._output(
+                f"Updating local cache with {missing_entry_count:,} new "
+                f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
+            )
+            self.cache.add_from_dict(
+                {entry.key: entry for entry in client_missing_cacheentries}
+            )
+            self._output("Local cache updated!")
+        else:
             self._output("No new entries to add to local cache.")
-        return
 
-
-
-
+    def _sync_to_remote(self):
+        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
+        server_missing_cacheentry_keys = cache_difference.get(
+            "server_missing_cacheentry_keys", []
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-        # Get newly added entries since sync started
-        new_entries = CacheEntriesList(
-            [
-                entry
-                for entry in self.cache.values()
-                if entry.key not in self.initial_cache_keys
-            ]
-        )
-
-        return server_missing_entries + new_entries
-
-    def _sync_to_remote(self) -> None:
-        """Uploads new local entries to remote cache."""
-        diff: CacheDifference = self._get_cache_difference()
-        entries_to_upload: CacheEntriesList = self._get_entries_to_upload(diff)
-        upload_count = len(entries_to_upload)
-
-        if upload_count > 0:
+        server_missing_cacheentries = [
+            entry
+            for key in server_missing_cacheentry_keys
+            if (entry := self.cache.data.get(key)) is not None
+        ]
+
+        new_cache_entries = [
+            entry
+            for entry in self.cache.values()
+            if entry.key not in self.old_entry_keys
+        ]
+        server_missing_cacheentries.extend(new_cache_entries)
+        new_entry_count = len(server_missing_cacheentries)
+
+        if new_entry_count > 0:
             self._output(
-                f"Updating remote cache with {
-                f"{'entry' if
+                f"Updating remote cache with {new_entry_count:,} new "
+                f"{'entry' if new_entry_count == 1 else 'entries'}..."
             )
-
             self.coop.remote_cache_create_many(
-
+                server_missing_cacheentries,
                 visibility="private",
                 description=self.remote_cache_description,
             )
@@ -171,16 +76,3 @@ class RemoteCacheSync(AbstractContextManager):
         self._output(
             f"There are {len(self.cache.keys()):,} entries in the local cache."
         )
-
-
-if __name__ == "__main__":
-    import doctest
-
-    doctest.testmod()
-
-    from edsl.coop.coop import Coop
-    from edsl.data.Cache import Cache
-    from edsl.data.CacheEntry import CacheEntry
-
-    r = RemoteCacheSync(Coop(), Cache(), print)
-    diff = r._get_cache_difference()
edsl/enums.py
CHANGED
@@ -86,13 +86,6 @@ InferenceServiceLiteral = Literal[
     "perplexity",
 ]
 
-available_models_urls = {
-    "anthropic": "https://docs.anthropic.com/en/docs/about-claude/models",
-    "openai": "https://platform.openai.com/docs/models/gp",
-    "groq": "https://console.groq.com/docs/models",
-    "google": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models",
-}
-
 
 service_to_api_keyname = {
     InferenceServiceType.BEDROCK.value: "TBD",
edsl/inference_services/AnthropicService.py
CHANGED
@@ -11,27 +11,21 @@ class AnthropicService(InferenceServiceABC):
 
     _inference_service_ = "anthropic"
     _env_key_name_ = "ANTHROPIC_API_KEY"
-    key_sequence = ["content", 0, "text"]
+    key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
     usage_sequence = ["usage"]
     input_token_name = "input_tokens"
    output_token_name = "output_tokens"
     model_exclude_list = []
 
-    @classmethod
-    def get_model_list(cls, api_key: str = None):
-
-        import requests
-
-        if api_key is None:
-            api_key = os.environ.get("ANTHROPIC_API_KEY")
-        headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
-        response = requests.get("https://api.anthropic.com/v1/models", headers=headers)
-        model_names = [m["id"] for m in response.json()["data"]]
-        return model_names
-
     @classmethod
     def available(cls):
-
+        # TODO - replace with an API call
+        return [
+            "claude-3-5-sonnet-20240620",
+            "claude-3-opus-20240229",
+            "claude-3-sonnet-20240229",
+            "claude-3-haiku-20240307",
+        ]
 
     @classmethod
     def create_model(
@@ -68,36 +62,20 @@ class AnthropicService(InferenceServiceABC):
         system_prompt: str = "",
         files_list: Optional[List["Files"]] = None,
     ) -> dict[str, Any]:
-        """Calls the
+        """Calls the OpenAI API and returns the API response."""
 
-
-
-                "role": "user",
-                "content": [{"type": "text", "text": user_prompt}],
-            }
-        ]
-        if files_list:
-            for file_entry in files_list:
-                encoded_image = file_entry.base64_string
-                messages[0]["content"].append(
-                    {
-                        "type": "image",
-                        "source": {
-                            "type": "base64",
-                            "media_type": file_entry.mime_type,
-                            "data": encoded_image,
-                        },
-                    }
-                )
-        # breakpoint()
-        client = AsyncAnthropic(api_key=self.api_token)
+        api_key = os.environ.get("ANTHROPIC_API_KEY")
+        client = AsyncAnthropic(api_key=api_key)
 
         response = await client.messages.create(
             model=model_name,
             max_tokens=self.max_tokens,
             temperature=self.temperature,
-            system=system_prompt,
-            messages=
+            system=system_prompt,
+            messages=[
+                # {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
         )
         return response.model_dump()
 
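The dev2 side replaces the live model-list lookup with a hardcoded list, so the call no longer needs network access or an API key. The observable difference, as a sketch:

from edsl.inference_services.AnthropicService import AnthropicService

# 0.1.39: issued a GET to https://api.anthropic.com/v1/models (API key required).
# 0.1.39.dev2: returns the static list below with no network call.
print(AnthropicService.available())
# ['claude-3-5-sonnet-20240620', 'claude-3-opus-20240229',
#  'claude-3-sonnet-20240229', 'claude-3-haiku-20240307']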
edsl/inference_services/AvailableModelFetcher.py
CHANGED
@@ -133,12 +133,6 @@ class AvailableModelFetcher:
         )
         service_name = service._inference_service_
 
-        if not service_models:
-            import warnings
-
-            warnings.warn(f"No models found for service {service_name}")
-            return [], service_name
-
         models_list = AvailableModels(
             [
                 LanguageModelInfo(
@@ -183,7 +177,7 @@ class AvailableModelFetcher:
             )
 
         except Exception as exc:
-            print(f"Service query failed
+            print(f"Service query failed: {exc}")
             continue
 
         return AvailableModels(all_models)
edsl/inference_services/GoogleService.py
CHANGED
@@ -40,17 +40,13 @@ class GoogleService(InferenceServiceABC):
     model_exclude_list = []
 
     @classmethod
-    def
+    def available(cls) -> List[str]:
         model_list = []
         for m in genai.list_models():
             if "generateContent" in m.supported_generation_methods:
                 model_list.append(m.name.split("/")[-1])
         return model_list
 
-    @classmethod
-    def available(cls) -> List[str]:
-        return cls.get_model_list()
-
     @classmethod
     def create_model(
         cls, model_name: str = "gemini-pro", model_class_name=None
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -71,12 +71,7 @@ class ModelResolver:
             self._models_to_services[model_name] = service
             return service
 
-        raise InferenceServiceError(
-            f"""Model {model_name} not found in any services.
-            If you know the service that has this model, use the service_name parameter directly.
-            E.g., Model("gpt-4o", service_name="openai")
-            """
-        )
+        raise InferenceServiceError(f"Model {model_name} not found in any services")
 
 
 class InferenceServicesCollection:
@@ -98,9 +93,6 @@ class InferenceServicesCollection:
         if service_name not in cls.added_models:
             cls.added_models[service_name].append(model_name)
 
-    def service_names_to_classes(self) -> Dict[str, InferenceServiceABC]:
-        return {service._inference_service_: service for service in self.services}
-
     def available(
         self,
         service: Optional[str] = None,
@@ -120,15 +112,7 @@ class InferenceServicesCollection:
     def create_model_factory(
         self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
     ) -> "LanguageModel":
-
-        if service_name is None:  # we try to find the right service
-            service = self.resolver.resolve_model(model_name, service_name)
-        else:  # if they passed a service, we'll use that
-            service = self.service_names_to_classes().get(service_name)
-
-        if not service:  # but if we can't find it, we'll raise an error
-            raise InferenceServiceError(f"Service {service_name} not found")
-
+        service = self.resolver.resolve_model(model_name, service_name)
         return service.create_model(model_name)
 
 
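Behavioral note on the last hunk: dev2 always delegates to ModelResolver.resolve_model, even when a service_name is supplied, whereas 0.1.39 looked the service up by name and raised a distinct "Service ... not found" error. A hedged sketch of the dev2 call path (the name `services` for the package's InferenceServicesCollection instance is an assumption; the wiring is not shown in this diff):

# assumed: `services` is the InferenceServicesCollection the package wires up
model_cls = services.create_model_factory("gpt-4o", service_name="openai")
# dev2: the outcome depends solely on ModelResolver.resolve_model; an unknown
# model raises InferenceServiceError("Model ... not found in any services")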
edsl/inference_services/OpenAIService.py
CHANGED
@@ -1,8 +1,7 @@
 from __future__ import annotations
-from typing import Any, List, Optional
+from typing import Any, List, Optional
 import os
 
-
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -12,8 +11,6 @@ from edsl.utilities.utilities import fix_partial_correct_response
 
 from edsl.config import CONFIG
 
-APIToken = NewType("APIToken", str)
-
 
 class OpenAIService(InferenceServiceABC):
     """OpenAI service class."""
@@ -25,43 +22,35 @@ class OpenAIService(InferenceServiceABC):
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI
 
-
-
+    _sync_client_instance = None
+    _async_client_instance = None
 
     key_sequence = ["choices", 0, "message", "content"]
     usage_sequence = ["usage"]
     input_token_name = "prompt_tokens"
     output_token_name = "completion_tokens"
 
-    available_models_url = "https://platform.openai.com/docs/models/gp"
-
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
-        # so subclasses
-        cls.
-        cls.
+        # so subclasses have to create their own instances of the clients
+        cls._sync_client_instance = None
+        cls._async_client_instance = None
 
     @classmethod
-    def sync_client(cls
-        if
-
-                api_key=
-                base_url=cls._base_url_,
+    def sync_client(cls):
+        if cls._sync_client_instance is None:
+            cls._sync_client_instance = cls._sync_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
             )
-
-        client = cls._sync_client_instances[api_key]
-        return client
+        return cls._sync_client_instance
 
     @classmethod
-    def async_client(cls
-        if
-
-                api_key=
-                base_url=cls._base_url_,
+    def async_client(cls):
+        if cls._async_client_instance is None:
+            cls._async_client_instance = cls._async_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
             )
-
-        client = cls._async_client_instances[api_key]
-        return client
+        return cls._async_client_instance
 
     model_exclude_list = [
         "whisper-1",
@@ -83,24 +72,20 @@ class OpenAIService(InferenceServiceABC):
     _models_list_cache: List[str] = []
 
     @classmethod
-    def get_model_list(cls
-
-        api_key = os.getenv(cls._env_key_name_)
-        raw_list = cls.sync_client(api_key).models.list()
+    def get_model_list(cls):
+        raw_list = cls.sync_client().models.list()
         if hasattr(raw_list, "data"):
             return raw_list.data
         else:
             return raw_list
 
     @classmethod
-    def available(cls
-        if api_token is None:
-            api_token = os.getenv(cls._env_key_name_)
+    def available(cls) -> List[str]:
         if not cls._models_list_cache:
             try:
                 cls._models_list_cache = [
                     m.id
-                    for m in cls.get_model_list(
+                    for m in cls.get_model_list()
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
@@ -135,10 +120,10 @@ class OpenAIService(InferenceServiceABC):
         }
 
     def sync_client(self):
-        return cls.sync_client(
+        return cls.sync_client()
 
     def async_client(self):
-        return cls.async_client(
+        return cls.async_client()
 
     @classmethod
     def available(cls) -> list[str]:
@@ -187,16 +172,16 @@ class OpenAIService(InferenceServiceABC):
     ) -> dict[str, Any]:
         """Calls the OpenAI API and returns the API response."""
         if files_list:
+            encoded_image = files_list[0].base64_string
             content = [{"type": "text", "text": user_prompt}]
-
-
-
-
-                        "
-
-
-
-                )
+            content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{encoded_image}"
+                    },
+                }
+            )
         else:
             content = user_prompt
         client = self.async_client()
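The client-handling change above drops 0.1.39's per-API-key client dictionaries in favor of one lazily built client per class, reset for each subclass in __init_subclass__. The dev2 pattern, isolated as a minimal sketch (the _env_key_name_ and _base_url_ values are placeholders; each concrete service class sets its own):

import os
import openai

class ServiceSketch:
    _env_key_name_ = "OPENAI_API_KEY"  # set per service in the real classes
    _base_url_ = None                  # None lets the OpenAI client use its default
    _sync_client_ = openai.OpenAI
    _sync_client_instance = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # each subclass gets its own slot so it never reuses the parent's client
        cls._sync_client_instance = None

    @classmethod
    def sync_client(cls):
        # build the client once per class, on first use
        if cls._sync_client_instance is None:
            cls._sync_client_instance = cls._sync_client_(
                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
            )
        return cls._sync_client_instance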
edsl/inference_services/TestService.py
CHANGED
@@ -51,7 +51,6 @@ class TestService(InferenceServiceABC):
     @property
     def _canned_response(self):
         if hasattr(self, "canned_response"):
-
             return self.canned_response
         else:
             return "Hello, world"
@@ -64,6 +63,9 @@ class TestService(InferenceServiceABC):
         files_list: Optional[List["File"]] = None,
     ) -> dict[str, Any]:
         await asyncio.sleep(0.1)
+        # return {"message": """{"answer": "Hello, world"}"""}
+
+        # breakpoint()
 
         if hasattr(self, "throw_exception") and self.throw_exception:
             if hasattr(self, "exception_probability"):