edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (86)
  1. edsl/Base.py +28 -0
  2. edsl/__init__.py +1 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -16
  5. edsl/agents/Invigilator.py +13 -14
  6. edsl/agents/InvigilatorBase.py +4 -1
  7. edsl/agents/PromptConstructor.py +42 -22
  8. edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
  9. edsl/auto/AutoStudy.py +18 -5
  10. edsl/auto/StageBase.py +53 -40
  11. edsl/auto/StageQuestions.py +2 -1
  12. edsl/auto/utilities.py +0 -6
  13. edsl/coop/coop.py +21 -5
  14. edsl/data/Cache.py +29 -18
  15. edsl/data/CacheHandler.py +0 -2
  16. edsl/data/RemoteCacheSync.py +154 -46
  17. edsl/data/hack.py +10 -0
  18. edsl/enums.py +7 -0
  19. edsl/inference_services/AnthropicService.py +38 -16
  20. edsl/inference_services/AvailableModelFetcher.py +7 -1
  21. edsl/inference_services/GoogleService.py +5 -1
  22. edsl/inference_services/InferenceServicesCollection.py +18 -2
  23. edsl/inference_services/OpenAIService.py +46 -31
  24. edsl/inference_services/TestService.py +1 -3
  25. edsl/inference_services/TogetherAIService.py +5 -3
  26. edsl/inference_services/data_structures.py +74 -2
  27. edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
  28. edsl/jobs/FetchInvigilator.py +10 -3
  29. edsl/jobs/InterviewsConstructor.py +6 -4
  30. edsl/jobs/Jobs.py +299 -233
  31. edsl/jobs/JobsChecks.py +2 -2
  32. edsl/jobs/JobsPrompts.py +1 -1
  33. edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
  34. edsl/jobs/async_interview_runner.py +138 -0
  35. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  36. edsl/jobs/data_structures.py +120 -0
  37. edsl/jobs/interviews/Interview.py +80 -42
  38. edsl/jobs/results_exceptions_handler.py +98 -0
  39. edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
  40. edsl/jobs/runners/JobsRunnerStatus.py +131 -164
  41. edsl/jobs/tasks/TaskHistory.py +24 -3
  42. edsl/language_models/LanguageModel.py +59 -4
  43. edsl/language_models/ModelList.py +19 -8
  44. edsl/language_models/__init__.py +1 -1
  45. edsl/language_models/model.py +256 -0
  46. edsl/language_models/repair.py +1 -1
  47. edsl/questions/QuestionBase.py +35 -26
  48. edsl/questions/QuestionBasePromptsMixin.py +1 -1
  49. edsl/questions/QuestionBudget.py +1 -1
  50. edsl/questions/QuestionCheckBox.py +2 -2
  51. edsl/questions/QuestionExtract.py +5 -7
  52. edsl/questions/QuestionFreeText.py +1 -1
  53. edsl/questions/QuestionList.py +9 -15
  54. edsl/questions/QuestionMatrix.py +1 -1
  55. edsl/questions/QuestionMultipleChoice.py +1 -1
  56. edsl/questions/QuestionNumerical.py +1 -1
  57. edsl/questions/QuestionRank.py +1 -1
  58. edsl/questions/SimpleAskMixin.py +1 -1
  59. edsl/questions/__init__.py +1 -1
  60. edsl/questions/data_structures.py +20 -0
  61. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
  62. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
  63. edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
  64. edsl/results/DatasetExportMixin.py +60 -119
  65. edsl/results/Result.py +109 -3
  66. edsl/results/Results.py +50 -39
  67. edsl/results/file_exports.py +252 -0
  68. edsl/scenarios/ScenarioList.py +35 -7
  69. edsl/surveys/Survey.py +71 -20
  70. edsl/test_h +1 -0
  71. edsl/utilities/gcp_bucket/example.py +50 -0
  72. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
  73. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
  74. edsl/language_models/registry.py +0 -180
  75. /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
  76. /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
  77. /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
  78. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  79. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  80. /edsl/results/{Selector.py → results_selector.py} +0 -0
  81. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  82. /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
  83. /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
  84. /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
  85. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
  86. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/data/Cache.py CHANGED
@@ -6,11 +6,9 @@ from __future__ import annotations
 import json
 import os
 import warnings
-from typing import Optional, Union
+from typing import Optional, Union, TYPE_CHECKING
 from edsl.Base import Base

-
-# from edsl.utilities.decorators import remove_edsl_version
 from edsl.utilities.remove_edsl_version import remove_edsl_version
 from edsl.exceptions.cache import CacheError

@@ -83,10 +81,6 @@ class Cache(Base):

         self._perform_checks()

-    # def rich_print(sefl):
-    #     pass
-    #     # raise NotImplementedError("This method is not implemented yet.")
-
     def code(sefl):
         pass
         # raise NotImplementedError("This method is not implemented yet.")
@@ -293,8 +287,8 @@ class Cache(Base):

         CACHE_PATH = CONFIG.get("EDSL_DATABASE_PATH")
         path = CACHE_PATH.replace("sqlite:///", "")
-        db_path = os.path.join(os.path.dirname(path), "data.db")
-        return cls.from_sqlite_db(db_path=db_path)
+        # db_path = os.path.join(os.path.dirname(path), "data.db")
+        return cls.from_sqlite_db(path)

     @classmethod
     def from_jsonl(cls, jsonlfile: str, db_path: Optional[str] = None) -> Cache:
@@ -368,12 +362,32 @@ class Cache(Base):
             scenarios.append(s)
         return ScenarioList(scenarios)

-    ####################
-    # REMOTE
-    ####################
-    # TODO: Make this work
-    # - Need to decide whether the cache belongs to a user and what can be shared
-    # - I.e., some cache entries? all or nothing?
+    def __floordiv__(self, other: "Cache") -> "Cache":
+        """
+        Return a new Cache containing entries that are in self but not in other.
+        Uses // operator as alternative to subtraction.
+
+        :param other: Another Cache object to compare against
+        :return: A new Cache object containing unique entries
+
+        >>> from edsl.data.CacheEntry import CacheEntry
+        >>> ce1 = CacheEntry.example(randomize = True)
+        >>> ce2 = CacheEntry.example(randomize = True)
+        >>> ce2 = CacheEntry.example(randomize = True)
+        >>> c1 = Cache(data={ce1.key: ce1, ce2.key: ce2})
+        >>> c2 = Cache(data={ce1.key: ce1})
+        >>> c3 = c1 // c2
+        >>> len(c3)
+        1
+        >>> c3.data[ce2.key] == ce2
+        True
+        """
+        if not isinstance(other, Cache):
+            raise CacheError("Can only compare two caches")
+
+        diff_data = {k: v for k, v in self.data.items() if k not in other.data}
+        return Cache(data=diff_data, immediate_write=self.immediate_write)
+
     @classmethod
     def from_url(cls, db_path=None) -> Cache:
         """
@@ -399,9 +413,6 @@ class Cache(Base):
         if self.filename:
             self.write(self.filename)

-    ####################
-    # DUNDER / USEFUL
-    ####################
     def __hash__(self):
         """Return the hash of the Cache."""
         from edsl.utilities.utilities import dict_hash
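Beyond the doctest above, a minimal sketch of the new // operator in practice, using only classes that appear in this diff (Cache, CacheEntry); the before/after framing is illustrative:

    from edsl.data.Cache import Cache
    from edsl.data.CacheEntry import CacheEntry

    entry = CacheEntry.example()
    before = Cache(data={})                 # cache state before a run
    after = Cache(data={entry.key: entry})  # cache state after a run

    new_entries = after // before           # entries added by the run
    print(len(new_entries))                 # 1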
edsl/data/CacheHandler.py CHANGED
@@ -67,8 +67,6 @@ class CacheHandler:
         if self.test:
             return Cache(data={})

-        # if self.CACHE_PATH is not None:
-        #     return self.CACHE_PATH
         from edsl.config import CONFIG

         if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
edsl/data/RemoteCacheSync.py CHANGED
@@ -1,71 +1,166 @@
-class RemoteCacheSync:
+from typing import List, Dict, Any, Optional, TYPE_CHECKING, Callable
+from dataclasses import dataclass
+from contextlib import AbstractContextManager
+from collections import UserList
+
+if TYPE_CHECKING:
+    from .Cache import Cache
+    from edsl.coop.coop import Coop
+    from .CacheEntry import CacheEntry
+
+    from logging import Logger
+
+
+class CacheKeyList(UserList):
+    def __init__(self, data: List[str]):
+        super().__init__(data)
+        self.data = data
+
+    def __repr__(self):
+        import reprlib
+
+        keys_repr = reprlib.repr(self.data)
+        return f"CacheKeyList({keys_repr})"
+
+
+class CacheEntriesList(UserList):
+    def __init__(self, data: List["CacheEntry"]):
+        super().__init__(data)
+        self.data = data
+
+    def __repr__(self):
+        import reprlib
+
+        entries_repr = reprlib.repr(self.data)
+        return f"CacheEntries({entries_repr})"
+
+    def to_cache(self) -> "Cache":
+        from edsl.data.Cache import Cache
+
+        return Cache({entry.key: entry for entry in self.data})
+
+
+@dataclass
+class CacheDifference:
+    client_missing_entries: CacheEntriesList
+    server_missing_keys: List[str]
+
+    def __repr__(self):
+        """Returns a string representation of the CacheDifference object."""
+        import reprlib
+
+        missing_entries_repr = reprlib.repr(self.client_missing_entries)
+        missing_keys_repr = reprlib.repr(self.server_missing_keys)
+        return f"CacheDifference(client_missing_entries={missing_entries_repr}, server_missing_keys={missing_keys_repr})"
+
+
+class RemoteCacheSync(AbstractContextManager):
+    """Synchronizes a local cache with a remote cache.
+
+    Handles bidirectional synchronization:
+    - Downloads missing entries from remote to local cache
+    - Uploads new local entries to remote cache
+    """
+
     def __init__(
-        self, coop, cache, output_func, remote_cache=True, remote_cache_description=""
+        self,
+        coop: "Coop",
+        cache: "Cache",
+        output_func: Callable,
+        remote_cache: bool = True,
+        remote_cache_description: str = "",
     ):
+        """
+        Initializes a RemoteCacheSync object.
+
+        :param coop: Coop object for interacting with the remote cache
+        :param cache: Cache object for local cache
+        :param output_func: Function for outputting messages
+        :param remote_cache: Whether to enable remote cache synchronization
+        :param remote_cache_description: Description for remote cache entries
+
+        """
         self.coop = coop
         self.cache = cache
         self._output = output_func
-        self.remote_cache = remote_cache
-        self.old_entry_keys = []
-        self.new_cache_entries = []
+        self.remote_cache_enabled = remote_cache
         self.remote_cache_description = remote_cache_description
+        self.initial_cache_keys = []

-    def __enter__(self):
-        if self.remote_cache:
+    def __enter__(self) -> "RemoteCacheSync":
+        if self.remote_cache_enabled:
             self._sync_from_remote()
-        self.old_entry_keys = list(self.cache.keys())
+        self.initial_cache_keys = list(self.cache.keys())
         return self

     def __exit__(self, exc_type, exc_value, traceback):
-        if self.remote_cache:
+        if self.remote_cache_enabled:
             self._sync_to_remote()
         return False  # Propagate exceptions

-    def _sync_from_remote(self):
-        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
-        client_missing_cacheentries = cache_difference.get(
-            "client_missing_cacheentries", []
+    def _get_cache_difference(self) -> CacheDifference:
+        """Retrieves differences between local and remote caches."""
+        diff = self.coop.remote_cache_get_diff(self.cache.keys())
+        return CacheDifference(
+            client_missing_entries=diff.get("client_missing_cacheentries", []),
+            server_missing_keys=diff.get("server_missing_cacheentry_keys", []),
         )
-        missing_entry_count = len(client_missing_cacheentries)

-        if missing_entry_count > 0:
-            self._output(
-                f"Updating local cache with {missing_entry_count:,} new "
-                f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
-            )
-            self.cache.add_from_dict(
-                {entry.key: entry for entry in client_missing_cacheentries}
-            )
-            self._output("Local cache updated!")
-        else:
+    def _sync_from_remote(self) -> None:
+        """Downloads missing entries from remote cache to local cache."""
+        diff: CacheDifference = self._get_cache_difference()
+        missing_count = len(diff.client_missing_entries)
+
+        if missing_count == 0:
             self._output("No new entries to add to local cache.")
+            return

-    def _sync_to_remote(self):
-        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
-        server_missing_cacheentry_keys = cache_difference.get(
-            "server_missing_cacheentry_keys", []
+        self._output(
+            f"Updating local cache with {missing_count:,} new "
+            f"{'entry' if missing_count == 1 else 'entries'} from remote..."
         )
-        server_missing_cacheentries = [
-            entry
-            for key in server_missing_cacheentry_keys
-            if (entry := self.cache.data.get(key)) is not None
-        ]
-
-        new_cache_entries = [
-            entry
-            for entry in self.cache.values()
-            if entry.key not in self.old_entry_keys
-        ]
-        server_missing_cacheentries.extend(new_cache_entries)
-        new_entry_count = len(server_missing_cacheentries)
-
-        if new_entry_count > 0:
+
+        self.cache.add_from_dict(
+            {entry.key: entry for entry in diff.client_missing_entries}
+        )
+        self._output("Local cache updated!")
+
+    def _get_entries_to_upload(self, diff: CacheDifference) -> CacheEntriesList:
+        """Determines which entries need to be uploaded to remote cache."""
+        # Get entries for keys missing from server
+        server_missing_entries = CacheEntriesList(
+            [
+                entry
+                for key in diff.server_missing_keys
+                if (entry := self.cache.data.get(key)) is not None
+            ]
+        )
+
+        # Get newly added entries since sync started
+        new_entries = CacheEntriesList(
+            [
+                entry
+                for entry in self.cache.values()
+                if entry.key not in self.initial_cache_keys
+            ]
+        )
+
+        return server_missing_entries + new_entries
+
+    def _sync_to_remote(self) -> None:
+        """Uploads new local entries to remote cache."""
+        diff: CacheDifference = self._get_cache_difference()
+        entries_to_upload: CacheEntriesList = self._get_entries_to_upload(diff)
+        upload_count = len(entries_to_upload)
+
+        if upload_count > 0:
             self._output(
-                f"Updating remote cache with {new_entry_count:,} new "
-                f"{'entry' if new_entry_count == 1 else 'entries'}..."
+                f"Updating remote cache with {upload_count:,} new "
+                f"{'entry' if upload_count == 1 else 'entries'}..."
             )
+
             self.coop.remote_cache_create_many(
-                server_missing_cacheentries,
+                entries_to_upload,
                 visibility="private",
                 description=self.remote_cache_description,
             )
@@ -76,3 +171,16 @@ class RemoteCacheSync:
         self._output(
             f"There are {len(self.cache.keys()):,} entries in the local cache."
         )
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()
+
+    from edsl.coop.coop import Coop
+    from edsl.data.Cache import Cache
+    from edsl.data.CacheEntry import CacheEntry
+
+    r = RemoteCacheSync(Coop(), Cache(), print)
+    diff = r._get_cache_difference()
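Taken together, the rewritten class is meant to wrap a job run as a context manager: __enter__ pulls missing entries down, __exit__ pushes new local entries up. A hedged usage sketch (constructor arguments per the diff above; the Coop() setup and the work done inside the with block are assumptions):

    from edsl.coop.coop import Coop
    from edsl.data.Cache import Cache
    from edsl.data.RemoteCacheSync import RemoteCacheSync

    cache = Cache()
    with RemoteCacheSync(
        coop=Coop(),
        cache=cache,
        output_func=print,
        remote_cache_description="example sync",
    ):
        pass  # run work that reads/writes `cache` here
    # on exit, entries added during the block are uploaded to the remote cache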
edsl/data/hack.py ADDED
@@ -0,0 +1,10 @@
+from edsl.data.CacheEntry import CacheEntry
+
+first = 0
+for i in range(0,1000000):
+    if i == 0:
+        first = CacheEntry.example().key
+    if first != "55ce2e13d38aa7fb6ec848053285edb4":
+        print(first)
+        print(CacheEntry.example().__dict__)
+        break
edsl/enums.py CHANGED
@@ -86,6 +86,13 @@ InferenceServiceLiteral = Literal[
     "perplexity",
 ]

+available_models_urls = {
+    "anthropic": "https://docs.anthropic.com/en/docs/about-claude/models",
+    "openai": "https://platform.openai.com/docs/models/gp",
+    "groq": "https://console.groq.com/docs/models",
+    "google": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models",
+}
+

 service_to_api_keyname = {
     InferenceServiceType.BEDROCK.value: "TBD",
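A small sketch of how the new module-level mapping can be consumed; the keys mirror service names from InferenceServiceLiteral, and .get returns None for services without a registered URL:

    from edsl.enums import available_models_urls

    print(available_models_urls.get("anthropic"))
    # https://docs.anthropic.com/en/docs/about-claude/models
    print(available_models_urls.get("bedrock"))  # None: no URL registered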
edsl/inference_services/AnthropicService.py CHANGED
@@ -11,21 +11,27 @@ class AnthropicService(InferenceServiceABC):

     _inference_service_ = "anthropic"
     _env_key_name_ = "ANTHROPIC_API_KEY"
-    key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
+    key_sequence = ["content", 0, "text"]
     usage_sequence = ["usage"]
     input_token_name = "input_tokens"
     output_token_name = "output_tokens"
     model_exclude_list = []

+    @classmethod
+    def get_model_list(cls, api_key: str = None):
+
+        import requests
+
+        if api_key is None:
+            api_key = os.environ.get("ANTHROPIC_API_KEY")
+        headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
+        response = requests.get("https://api.anthropic.com/v1/models", headers=headers)
+        model_names = [m["id"] for m in response.json()["data"]]
+        return model_names
+
     @classmethod
     def available(cls):
-        # TODO - replace with an API call
-        return [
-            "claude-3-5-sonnet-20240620",
-            "claude-3-opus-20240229",
-            "claude-3-sonnet-20240229",
-            "claude-3-haiku-20240307",
-        ]
+        return cls.get_model_list()

     @classmethod
     def create_model(
@@ -62,20 +68,36 @@ class AnthropicService(InferenceServiceABC):
         system_prompt: str = "",
         files_list: Optional[List["Files"]] = None,
     ) -> dict[str, Any]:
-        """Calls the OpenAI API and returns the API response."""
+        """Calls the Anthropic API and returns the API response."""

-        api_key = os.environ.get("ANTHROPIC_API_KEY")
-        client = AsyncAnthropic(api_key=api_key)
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": user_prompt}],
+            }
+        ]
+        if files_list:
+            for file_entry in files_list:
+                encoded_image = file_entry.base64_string
+                messages[0]["content"].append(
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": file_entry.mime_type,
+                            "data": encoded_image,
+                        },
+                    }
+                )
+        # breakpoint()
+        client = AsyncAnthropic(api_key=self.api_token)

         response = await client.messages.create(
             model=model_name,
             max_tokens=self.max_tokens,
             temperature=self.temperature,
-            system=system_prompt,
-            messages=[
-                # {"role": "system", "content": system_prompt},
-                {"role": "user", "content": user_prompt},
-            ],
+            system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
+            messages=messages,
         )
         return response.model_dump()

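The new get_model_list replaces the hard-coded model list with a live call to Anthropic's models endpoint. A standalone sketch of the same request (endpoint and headers exactly as in the diff; assumes ANTHROPIC_API_KEY is set in the environment):

    import os
    import requests

    headers = {
        "x-api-key": os.environ["ANTHROPIC_API_KEY"],
        "anthropic-version": "2023-06-01",
    }
    response = requests.get("https://api.anthropic.com/v1/models", headers=headers)
    print([m["id"] for m in response.json()["data"]])  # e.g. claude-3-5-sonnet-...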
edsl/inference_services/AvailableModelFetcher.py CHANGED
@@ -133,6 +133,12 @@ class AvailableModelFetcher:
         )
         service_name = service._inference_service_

+        if not service_models:
+            import warnings
+
+            warnings.warn(f"No models found for service {service_name}")
+            return [], service_name
+
         models_list = AvailableModels(
             [
                 LanguageModelInfo(
@@ -177,7 +183,7 @@ class AvailableModelFetcher:
                 )

             except Exception as exc:
-                print(f"Service query failed: {exc}")
+                print(f"Service query failed for service {service_name}: {exc}")
                 continue

         return AvailableModels(all_models)
edsl/inference_services/GoogleService.py CHANGED
@@ -40,13 +40,17 @@ class GoogleService(InferenceServiceABC):
     model_exclude_list = []

     @classmethod
-    def available(cls) -> List[str]:
+    def get_model_list(cls):
         model_list = []
         for m in genai.list_models():
             if "generateContent" in m.supported_generation_methods:
                 model_list.append(m.name.split("/")[-1])
         return model_list

+    @classmethod
+    def available(cls) -> List[str]:
+        return cls.get_model_list()
+
     @classmethod
     def create_model(
         cls, model_name: str = "gemini-pro", model_class_name=None
edsl/inference_services/InferenceServicesCollection.py CHANGED
@@ -71,7 +71,12 @@ class ModelResolver:
             self._models_to_services[model_name] = service
             return service

-        raise InferenceServiceError(f"Model {model_name} not found in any services")
+        raise InferenceServiceError(
+            f"""Model {model_name} not found in any services.
+            If you know the service that has this model, use the service_name parameter directly.
+            E.g., Model("gpt-4o", service_name="openai")
+            """
+        )


 class InferenceServicesCollection:
@@ -93,6 +98,9 @@ class InferenceServicesCollection:
         if service_name not in cls.added_models:
             cls.added_models[service_name].append(model_name)

+    def service_names_to_classes(self) -> Dict[str, InferenceServiceABC]:
+        return {service._inference_service_: service for service in self.services}
+
     def available(
         self,
         service: Optional[str] = None,
@@ -112,7 +120,15 @@ class InferenceServicesCollection:
     def create_model_factory(
         self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
     ) -> "LanguageModel":
-        service = self.resolver.resolve_model(model_name, service_name)
+
+        if service_name is None:  # we try to find the right service
+            service = self.resolver.resolve_model(model_name, service_name)
+        else:  # if they passed a service, we'll use that
+            service = self.service_names_to_classes().get(service_name)
+
+        if not service:  # but if we can't find it, we'll raise an error
+            raise InferenceServiceError(f"Service {service_name} not found")
+
         return service.create_model(model_name)

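create_model_factory now takes two paths: a resolver search across services when service_name is None, and a direct lookup otherwise. A sketch of the user-facing difference, assuming the usual from edsl import Model entry point (the explicit form is the one the new error message recommends):

    from edsl import Model

    m1 = Model("gpt-4o")                         # resolver searches services for the model
    m2 = Model("gpt-4o", service_name="openai")  # direct lookup; skips the search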
edsl/inference_services/OpenAIService.py CHANGED
@@ -1,7 +1,8 @@
 from __future__ import annotations
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Dict, NewType
 import os

+
 import openai

 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -11,6 +12,8 @@ from edsl.utilities.utilities import fix_partial_correct_response

 from edsl.config import CONFIG

+APIToken = NewType("APIToken", str)
+

 class OpenAIService(InferenceServiceABC):
     """OpenAI service class."""
@@ -22,35 +25,43 @@ class OpenAIService(InferenceServiceABC):
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI

-    _sync_client_instance = None
-    _async_client_instance = None
+    _sync_client_instances: Dict[APIToken, openai.OpenAI] = {}
+    _async_client_instances: Dict[APIToken, openai.AsyncOpenAI] = {}

     key_sequence = ["choices", 0, "message", "content"]
     usage_sequence = ["usage"]
     input_token_name = "prompt_tokens"
     output_token_name = "completion_tokens"

+    available_models_url = "https://platform.openai.com/docs/models/gp"
+
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
-        # so subclasses have to create their own instances of the clients
-        cls._sync_client_instance = None
-        cls._async_client_instance = None
+        # so subclasses that use the OpenAI api key have to create their own instances of the clients
+        cls._sync_client_instances = {}
+        cls._async_client_instances = {}

     @classmethod
-    def sync_client(cls):
-        if cls._sync_client_instance is None:
-            cls._sync_client_instance = cls._sync_client_(
-                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+    def sync_client(cls, api_key):
+        if api_key not in cls._sync_client_instances:
+            client = cls._sync_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
             )
-        return cls._sync_client_instance
+            cls._sync_client_instances[api_key] = client
+        client = cls._sync_client_instances[api_key]
+        return client

     @classmethod
-    def async_client(cls):
-        if cls._async_client_instance is None:
-            cls._async_client_instance = cls._async_client_(
-                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+    def async_client(cls, api_key):
+        if api_key not in cls._async_client_instances:
+            client = cls._async_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
             )
-        return cls._async_client_instance
+            cls._async_client_instances[api_key] = client
+        client = cls._async_client_instances[api_key]
+        return client

     model_exclude_list = [
         "whisper-1",
@@ -72,20 +83,24 @@ class OpenAIService(InferenceServiceABC):
     _models_list_cache: List[str] = []

     @classmethod
-    def get_model_list(cls):
-        raw_list = cls.sync_client().models.list()
+    def get_model_list(cls, api_key=None):
+        if api_key is None:
+            api_key = os.getenv(cls._env_key_name_)
+        raw_list = cls.sync_client(api_key).models.list()
         if hasattr(raw_list, "data"):
             return raw_list.data
         else:
             return raw_list

     @classmethod
-    def available(cls) -> List[str]:
+    def available(cls, api_token=None) -> List[str]:
+        if api_token is None:
+            api_token = os.getenv(cls._env_key_name_)
         if not cls._models_list_cache:
             try:
                 cls._models_list_cache = [
                     m.id
-                    for m in cls.get_model_list()
+                    for m in cls.get_model_list(api_key=api_token)
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
@@ -120,10 +135,10 @@ class OpenAIService(InferenceServiceABC):
             }

         def sync_client(self):
-            return cls.sync_client()
+            return cls.sync_client(api_key=self.api_token)

         def async_client(self):
-            return cls.async_client()
+            return cls.async_client(api_key=self.api_token)

         @classmethod
         def available(cls) -> list[str]:
@@ -172,16 +187,16 @@ class OpenAIService(InferenceServiceABC):
        ) -> dict[str, Any]:
            """Calls the OpenAI API and returns the API response."""
            if files_list:
-                encoded_image = files_list[0].base64_string
                content = [{"type": "text", "text": user_prompt}]
-                content.append(
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/jpeg;base64,{encoded_image}"
-                        },
-                    }
-                )
+                for file_entry in files_list:
+                    content.append(
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:{file_entry.mime_type};base64,{file_entry.base64_string}"
+                            },
+                        }
+                    )
            else:
                content = user_prompt
            client = self.async_client()
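The client-caching change swaps a single memoized client for one cached client per API token, so models created with different keys no longer share a connection. A self-contained sketch of the pattern (Client is a stand-in for openai.OpenAI, not the real class):

    from typing import Dict

    class Client:
        """Stand-in for openai.OpenAI; real clients also hold connections."""
        def __init__(self, api_key: str):
            self.api_key = api_key

    _instances: Dict[str, Client] = {}

    def client_for(api_key: str) -> Client:
        # one cached client per distinct token, mirroring _sync_client_instances
        if api_key not in _instances:
            _instances[api_key] = Client(api_key)
        return _instances[api_key]

    assert client_for("key-a") is client_for("key-a")      # reused for same token
    assert client_for("key-a") is not client_for("key-b")  # separate per token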