edsl 0.1.39__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. edsl/Base.py +0 -28
  2. edsl/__init__.py +1 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +17 -9
  5. edsl/agents/Invigilator.py +14 -13
  6. edsl/agents/InvigilatorBase.py +1 -4
  7. edsl/agents/PromptConstructor.py +22 -42
  8. edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
  9. edsl/auto/AutoStudy.py +5 -18
  10. edsl/auto/StageBase.py +40 -53
  11. edsl/auto/StageQuestions.py +1 -2
  12. edsl/auto/utilities.py +6 -0
  13. edsl/coop/coop.py +5 -21
  14. edsl/data/Cache.py +18 -29
  15. edsl/data/CacheHandler.py +2 -0
  16. edsl/data/RemoteCacheSync.py +46 -154
  17. edsl/enums.py +0 -7
  18. edsl/inference_services/AnthropicService.py +16 -38
  19. edsl/inference_services/AvailableModelFetcher.py +1 -7
  20. edsl/inference_services/GoogleService.py +1 -5
  21. edsl/inference_services/InferenceServicesCollection.py +2 -18
  22. edsl/inference_services/OpenAIService.py +31 -46
  23. edsl/inference_services/TestService.py +3 -1
  24. edsl/inference_services/TogetherAIService.py +3 -5
  25. edsl/inference_services/data_structures.py +2 -74
  26. edsl/jobs/AnswerQuestionFunctionConstructor.py +113 -148
  27. edsl/jobs/FetchInvigilator.py +3 -10
  28. edsl/jobs/InterviewsConstructor.py +4 -6
  29. edsl/jobs/Jobs.py +233 -299
  30. edsl/jobs/JobsChecks.py +2 -2
  31. edsl/jobs/JobsPrompts.py +1 -1
  32. edsl/jobs/JobsRemoteInferenceHandler.py +136 -160
  33. edsl/jobs/interviews/Interview.py +42 -80
  34. edsl/jobs/runners/JobsRunnerAsyncio.py +358 -88
  35. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  36. edsl/jobs/tasks/TaskHistory.py +3 -24
  37. edsl/language_models/LanguageModel.py +4 -59
  38. edsl/language_models/ModelList.py +8 -19
  39. edsl/language_models/__init__.py +1 -1
  40. edsl/language_models/registry.py +180 -0
  41. edsl/language_models/repair.py +1 -1
  42. edsl/questions/QuestionBase.py +26 -35
  43. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +49 -52
  44. edsl/questions/QuestionBasePromptsMixin.py +1 -1
  45. edsl/questions/QuestionBudget.py +1 -1
  46. edsl/questions/QuestionCheckBox.py +2 -2
  47. edsl/questions/QuestionExtract.py +7 -5
  48. edsl/questions/QuestionFreeText.py +1 -1
  49. edsl/questions/QuestionList.py +15 -9
  50. edsl/questions/QuestionMatrix.py +1 -1
  51. edsl/questions/QuestionMultipleChoice.py +1 -1
  52. edsl/questions/QuestionNumerical.py +1 -1
  53. edsl/questions/QuestionRank.py +1 -1
  54. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +18 -6
  55. edsl/questions/{response_validator_factory.py → ResponseValidatorFactory.py} +1 -7
  56. edsl/questions/SimpleAskMixin.py +1 -1
  57. edsl/questions/__init__.py +1 -1
  58. edsl/results/DatasetExportMixin.py +119 -60
  59. edsl/results/Result.py +3 -109
  60. edsl/results/Results.py +39 -50
  61. edsl/scenarios/FileStore.py +0 -32
  62. edsl/scenarios/ScenarioList.py +7 -35
  63. edsl/scenarios/handlers/csv.py +0 -11
  64. edsl/surveys/Survey.py +20 -71
  65. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +1 -1
  66. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/RECORD +78 -84
  67. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +1 -1
  68. edsl/jobs/async_interview_runner.py +0 -138
  69. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  70. edsl/jobs/data_structures.py +0 -120
  71. edsl/jobs/results_exceptions_handler.py +0 -98
  72. edsl/language_models/model.py +0 -256
  73. edsl/questions/data_structures.py +0 -20
  74. edsl/results/file_exports.py +0 -252
  75. /edsl/agents/{question_option_processor.py → QuestionOptionProcessor.py} +0 -0
  76. /edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +0 -0
  77. /edsl/questions/{loop_processor.py → LoopProcessor.py} +0 -0
  78. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  79. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  80. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  81. /edsl/results/{results_selector.py → Selector.py} +0 -0
  82. /edsl/scenarios/{directory_scanner.py → DirectoryScanner.py} +0 -0
  83. /edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +0 -0
  84. /edsl/scenarios/{scenario_selector.py → ScenarioSelector.py} +0 -0
  85. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
edsl/data/Cache.py CHANGED
@@ -6,9 +6,11 @@ from __future__ import annotations
 import json
 import os
 import warnings
-from typing import Optional, Union, TYPE_CHECKING
+from typing import Optional, Union
 from edsl.Base import Base
 
+
+# from edsl.utilities.decorators import remove_edsl_version
 from edsl.utilities.remove_edsl_version import remove_edsl_version
 from edsl.exceptions.cache import CacheError
 
@@ -81,6 +83,10 @@ class Cache(Base):
 
         self._perform_checks()
 
+    # def rich_print(sefl):
+    #     pass
+    #     # raise NotImplementedError("This method is not implemented yet.")
+
     def code(sefl):
         pass
         # raise NotImplementedError("This method is not implemented yet.")
@@ -287,8 +293,8 @@ class Cache(Base):
 
         CACHE_PATH = CONFIG.get("EDSL_DATABASE_PATH")
         path = CACHE_PATH.replace("sqlite:///", "")
-        # db_path = os.path.join(os.path.dirname(path), "data.db")
-        return cls.from_sqlite_db(path)
+        db_path = os.path.join(os.path.dirname(path), "data.db")
+        return cls.from_sqlite_db(db_path=db_path)
 
     @classmethod
     def from_jsonl(cls, jsonlfile: str, db_path: Optional[str] = None) -> Cache:
@@ -362,32 +368,12 @@ class Cache(Base):
             scenarios.append(s)
         return ScenarioList(scenarios)
 
-    def __floordiv__(self, other: "Cache") -> "Cache":
-        """
-        Return a new Cache containing entries that are in self but not in other.
-        Uses // operator as alternative to subtraction.
-
-        :param other: Another Cache object to compare against
-        :return: A new Cache object containing unique entries
-
-        >>> from edsl.data.CacheEntry import CacheEntry
-        >>> ce1 = CacheEntry.example(randomize = True)
-        >>> ce2 = CacheEntry.example(randomize = True)
-        >>> ce2 = CacheEntry.example(randomize = True)
-        >>> c1 = Cache(data={ce1.key: ce1, ce2.key: ce2})
-        >>> c2 = Cache(data={ce1.key: ce1})
-        >>> c3 = c1 // c2
-        >>> len(c3)
-        1
-        >>> c3.data[ce2.key] == ce2
-        True
-        """
-        if not isinstance(other, Cache):
-            raise CacheError("Can only compare two caches")
-
-        diff_data = {k: v for k, v in self.data.items() if k not in other.data}
-        return Cache(data=diff_data, immediate_write=self.immediate_write)
-
+    ####################
+    # REMOTE
+    ####################
+    # TODO: Make this work
+    # - Need to decide whether the cache belongs to a user and what can be shared
+    # - I.e., some cache entries? all or nothing?
     @classmethod
     def from_url(cls, db_path=None) -> Cache:
         """
@@ -413,6 +399,9 @@ class Cache(Base):
         if self.filename:
             self.write(self.filename)
 
+    ####################
+    # DUNDER / USEFUL
+    ####################
     def __hash__(self):
         """Return the hash of the Cache."""
         from edsl.utilities.utilities import dict_hash
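A note on the default-cache change above: 0.1.39 opened the file named by EDSL_DATABASE_PATH directly, while 0.1.39.dev2 derives a sibling data.db in the same directory. A minimal sketch of the difference, using a hypothetical config value:

    import os

    # hypothetical value; edsl stores the path as a sqlite:/// URL
    CACHE_PATH = "sqlite:///" + "/home/user/.edsl_cache/my_cache.db"
    path = CACHE_PATH.replace("sqlite:///", "")

    # 0.1.39: cls.from_sqlite_db(path)  -> opens my_cache.db itself
    # 0.1.39.dev2: opens data.db next to the configured file instead
    db_path = os.path.join(os.path.dirname(path), "data.db")
    print(db_path)  # /home/user/.edsl_cache/data.db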
edsl/data/CacheHandler.py CHANGED
@@ -67,6 +67,8 @@ class CacheHandler:
         if self.test:
             return Cache(data={})
 
+        # if self.CACHE_PATH is not None:
+        #     return self.CACHE_PATH
         from edsl.config import CONFIG
 
         if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
edsl/data/RemoteCacheSync.py CHANGED
@@ -1,166 +1,71 @@
-from typing import List, Dict, Any, Optional, TYPE_CHECKING, Callable
-from dataclasses import dataclass
-from contextlib import AbstractContextManager
-from collections import UserList
-
-if TYPE_CHECKING:
-    from .Cache import Cache
-    from edsl.coop.coop import Coop
-    from .CacheEntry import CacheEntry
-
-    from logging import Logger
-
-
-class CacheKeyList(UserList):
-    def __init__(self, data: List[str]):
-        super().__init__(data)
-        self.data = data
-
-    def __repr__(self):
-        import reprlib
-
-        keys_repr = reprlib.repr(self.data)
-        return f"CacheKeyList({keys_repr})"
-
-
-class CacheEntriesList(UserList):
-    def __init__(self, data: List["CacheEntry"]):
-        super().__init__(data)
-        self.data = data
-
-    def __repr__(self):
-        import reprlib
-
-        entries_repr = reprlib.repr(self.data)
-        return f"CacheEntries({entries_repr})"
-
-    def to_cache(self) -> "Cache":
-        from edsl.data.Cache import Cache
-
-        return Cache({entry.key: entry for entry in self.data})
-
-
-@dataclass
-class CacheDifference:
-    client_missing_entries: CacheEntriesList
-    server_missing_keys: List[str]
-
-    def __repr__(self):
-        """Returns a string representation of the CacheDifference object."""
-        import reprlib
-
-        missing_entries_repr = reprlib.repr(self.client_missing_entries)
-        missing_keys_repr = reprlib.repr(self.server_missing_keys)
-        return f"CacheDifference(client_missing_entries={missing_entries_repr}, server_missing_keys={missing_keys_repr})"
-
-
-class RemoteCacheSync(AbstractContextManager):
-    """Synchronizes a local cache with a remote cache.
-
-    Handles bidirectional synchronization:
-    - Downloads missing entries from remote to local cache
-    - Uploads new local entries to remote cache
-    """
-
+class RemoteCacheSync:
     def __init__(
-        self,
-        coop: "Coop",
-        cache: "Cache",
-        output_func: Callable,
-        remote_cache: bool = True,
-        remote_cache_description: str = "",
+        self, coop, cache, output_func, remote_cache=True, remote_cache_description=""
     ):
-        """
-        Initializes a RemoteCacheSync object.
-
-        :param coop: Coop object for interacting with the remote cache
-        :param cache: Cache object for local cache
-        :param output_func: Function for outputting messages
-        :param remote_cache: Whether to enable remote cache synchronization
-        :param remote_cache_description: Description for remote cache entries
-
-        """
         self.coop = coop
         self.cache = cache
         self._output = output_func
-        self.remote_cache_enabled = remote_cache
+        self.remote_cache = remote_cache
+        self.old_entry_keys = []
+        self.new_cache_entries = []
         self.remote_cache_description = remote_cache_description
-        self.initial_cache_keys = []
 
-    def __enter__(self) -> "RemoteCacheSync":
-        if self.remote_cache_enabled:
+    def __enter__(self):
+        if self.remote_cache:
             self._sync_from_remote()
-        self.initial_cache_keys = list(self.cache.keys())
+        self.old_entry_keys = list(self.cache.keys())
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
-        if self.remote_cache_enabled:
+        if self.remote_cache:
             self._sync_to_remote()
         return False  # Propagate exceptions
 
-    def _get_cache_difference(self) -> CacheDifference:
-        """Retrieves differences between local and remote caches."""
-        diff = self.coop.remote_cache_get_diff(self.cache.keys())
-        return CacheDifference(
-            client_missing_entries=diff.get("client_missing_cacheentries", []),
-            server_missing_keys=diff.get("server_missing_cacheentry_keys", []),
+    def _sync_from_remote(self):
+        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
+        client_missing_cacheentries = cache_difference.get(
+            "client_missing_cacheentries", []
         )
+        missing_entry_count = len(client_missing_cacheentries)
 
-    def _sync_from_remote(self) -> None:
-        """Downloads missing entries from remote cache to local cache."""
-        diff: CacheDifference = self._get_cache_difference()
-        missing_count = len(diff.client_missing_entries)
-
-        if missing_count == 0:
+        if missing_entry_count > 0:
+            self._output(
+                f"Updating local cache with {missing_entry_count:,} new "
+                f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
+            )
+            self.cache.add_from_dict(
+                {entry.key: entry for entry in client_missing_cacheentries}
+            )
+            self._output("Local cache updated!")
+        else:
             self._output("No new entries to add to local cache.")
-            return
 
-        self._output(
-            f"Updating local cache with {missing_count:,} new "
-            f"{'entry' if missing_count == 1 else 'entries'} from remote..."
+    def _sync_to_remote(self):
+        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
+        server_missing_cacheentry_keys = cache_difference.get(
+            "server_missing_cacheentry_keys", []
         )
-
-        self.cache.add_from_dict(
-            {entry.key: entry for entry in diff.client_missing_entries}
-        )
-        self._output("Local cache updated!")
-
-    def _get_entries_to_upload(self, diff: CacheDifference) -> CacheEntriesList:
-        """Determines which entries need to be uploaded to remote cache."""
-        # Get entries for keys missing from server
-        server_missing_entries = CacheEntriesList(
-            [
-                entry
-                for key in diff.server_missing_keys
-                if (entry := self.cache.data.get(key)) is not None
-            ]
-        )
-
-        # Get newly added entries since sync started
-        new_entries = CacheEntriesList(
-            [
-                entry
-                for entry in self.cache.values()
-                if entry.key not in self.initial_cache_keys
-            ]
-        )
-
-        return server_missing_entries + new_entries
-
-    def _sync_to_remote(self) -> None:
-        """Uploads new local entries to remote cache."""
-        diff: CacheDifference = self._get_cache_difference()
-        entries_to_upload: CacheEntriesList = self._get_entries_to_upload(diff)
-        upload_count = len(entries_to_upload)
-
-        if upload_count > 0:
+        server_missing_cacheentries = [
+            entry
+            for key in server_missing_cacheentry_keys
+            if (entry := self.cache.data.get(key)) is not None
+        ]
+
+        new_cache_entries = [
+            entry
+            for entry in self.cache.values()
+            if entry.key not in self.old_entry_keys
+        ]
+        server_missing_cacheentries.extend(new_cache_entries)
+        new_entry_count = len(server_missing_cacheentries)
+
+        if new_entry_count > 0:
             self._output(
-                f"Updating remote cache with {upload_count:,} new "
-                f"{'entry' if upload_count == 1 else 'entries'}..."
+                f"Updating remote cache with {new_entry_count:,} new "
+                f"{'entry' if new_entry_count == 1 else 'entries'}..."
            )
-
             self.coop.remote_cache_create_many(
-                entries_to_upload,
+                server_missing_cacheentries,
                 visibility="private",
                 description=self.remote_cache_description,
             )
@@ -171,16 +76,3 @@ class RemoteCacheSync(AbstractContextManager):
             self._output(
                 f"There are {len(self.cache.keys()):,} entries in the local cache."
             )
-
-
-if __name__ == "__main__":
-    import doctest
-
-    doctest.testmod()
-
-    from edsl.coop.coop import Coop
-    from edsl.data.Cache import Cache
-    from edsl.data.CacheEntry import CacheEntry
-
-    r = RemoteCacheSync(Coop(), Cache(), print)
-    diff = r._get_cache_difference()
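Both versions keep the same context-manager shape: __enter__ pulls missing entries down from the remote cache and snapshots the local keys, and __exit__ uploads whatever was added since. A minimal usage sketch (the setup mirrors the removed __main__ block; the description string is illustrative):

    from edsl.coop.coop import Coop
    from edsl.data.Cache import Cache

    cache = Cache()
    with RemoteCacheSync(Coop(), cache, print, remote_cache_description="my run"):
        # remote entries have been merged into `cache` at this point;
        # any entries added here are pushed up on exit via
        # coop.remote_cache_create_many(..., visibility="private")
        pass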
edsl/enums.py CHANGED
@@ -86,13 +86,6 @@ InferenceServiceLiteral = Literal[
     "perplexity",
 ]
 
-available_models_urls = {
-    "anthropic": "https://docs.anthropic.com/en/docs/about-claude/models",
-    "openai": "https://platform.openai.com/docs/models/gp",
-    "groq": "https://console.groq.com/docs/models",
-    "google": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models",
-}
-
 
 service_to_api_keyname = {
     InferenceServiceType.BEDROCK.value: "TBD",
edsl/inference_services/AnthropicService.py CHANGED
@@ -11,27 +11,21 @@ class AnthropicService(InferenceServiceABC):
 
     _inference_service_ = "anthropic"
     _env_key_name_ = "ANTHROPIC_API_KEY"
-    key_sequence = ["content", 0, "text"]
+    key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
     usage_sequence = ["usage"]
     input_token_name = "input_tokens"
     output_token_name = "output_tokens"
     model_exclude_list = []
 
-    @classmethod
-    def get_model_list(cls, api_key: str = None):
-
-        import requests
-
-        if api_key is None:
-            api_key = os.environ.get("ANTHROPIC_API_KEY")
-        headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
-        response = requests.get("https://api.anthropic.com/v1/models", headers=headers)
-        model_names = [m["id"] for m in response.json()["data"]]
-        return model_names
-
     @classmethod
     def available(cls):
-        return cls.get_model_list()
+        # TODO - replace with an API call
+        return [
+            "claude-3-5-sonnet-20240620",
+            "claude-3-opus-20240229",
+            "claude-3-sonnet-20240229",
+            "claude-3-haiku-20240307",
+        ]
 
     @classmethod
     def create_model(
@@ -68,36 +62,20 @@ class AnthropicService(InferenceServiceABC):
         system_prompt: str = "",
         files_list: Optional[List["Files"]] = None,
     ) -> dict[str, Any]:
-        """Calls the Anthropic API and returns the API response."""
+        """Calls the OpenAI API and returns the API response."""
 
-        messages = [
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": user_prompt}],
-            }
-        ]
-        if files_list:
-            for file_entry in files_list:
-                encoded_image = file_entry.base64_string
-                messages[0]["content"].append(
-                    {
-                        "type": "image",
-                        "source": {
-                            "type": "base64",
-                            "media_type": file_entry.mime_type,
-                            "data": encoded_image,
-                        },
-                    }
-                )
-        # breakpoint()
-        client = AsyncAnthropic(api_key=self.api_token)
+        api_key = os.environ.get("ANTHROPIC_API_KEY")
+        client = AsyncAnthropic(api_key=api_key)
 
         response = await client.messages.create(
             model=model_name,
             max_tokens=self.max_tokens,
             temperature=self.temperature,
-            system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
-            messages=messages,
+            system=system_prompt,
+            messages=[
+                # {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
         )
         return response.model_dump()
 
edsl/inference_services/AvailableModelFetcher.py CHANGED
@@ -133,12 +133,6 @@ class AvailableModelFetcher:
         )
         service_name = service._inference_service_
 
-        if not service_models:
-            import warnings
-
-            warnings.warn(f"No models found for service {service_name}")
-            return [], service_name
-
         models_list = AvailableModels(
             [
                 LanguageModelInfo(
@@ -183,7 +177,7 @@ class AvailableModelFetcher:
             )
 
         except Exception as exc:
-            print(f"Service query failed for service {service_name}: {exc}")
+            print(f"Service query failed: {exc}")
             continue
 
         return AvailableModels(all_models)
edsl/inference_services/GoogleService.py CHANGED
@@ -40,17 +40,13 @@ class GoogleService(InferenceServiceABC):
     model_exclude_list = []
 
     @classmethod
-    def get_model_list(cls):
+    def available(cls) -> List[str]:
         model_list = []
         for m in genai.list_models():
             if "generateContent" in m.supported_generation_methods:
                 model_list.append(m.name.split("/")[-1])
         return model_list
 
-    @classmethod
-    def available(cls) -> List[str]:
-        return cls.get_model_list()
-
     @classmethod
     def create_model(
         cls, model_name: str = "gemini-pro", model_class_name=None
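Both versions enumerate Gemini models the same way; dev2 simply folds get_model_list into available. The listing idiom, as a standalone sketch (assumes google.generativeai is configured with a valid key):

    import google.generativeai as genai

    genai.configure(api_key="...")  # hypothetical key setup
    names = [
        m.name.split("/")[-1]  # "models/gemini-pro" -> "gemini-pro"
        for m in genai.list_models()
        if "generateContent" in m.supported_generation_methods
    ]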
edsl/inference_services/InferenceServicesCollection.py CHANGED
@@ -71,12 +71,7 @@ class ModelResolver:
                 self._models_to_services[model_name] = service
                 return service
 
-        raise InferenceServiceError(
-            f"""Model {model_name} not found in any services.
-            If you know the service that has this model, use the service_name parameter directly.
-            E.g., Model("gpt-4o", service_name="openai")
-            """
-        )
+        raise InferenceServiceError(f"Model {model_name} not found in any services")
 
 
 class InferenceServicesCollection:
@@ -98,9 +93,6 @@ class InferenceServicesCollection:
         if service_name not in cls.added_models:
             cls.added_models[service_name].append(model_name)
 
-    def service_names_to_classes(self) -> Dict[str, InferenceServiceABC]:
-        return {service._inference_service_: service for service in self.services}
-
     def available(
         self,
         service: Optional[str] = None,
@@ -120,15 +112,7 @@ class InferenceServicesCollection:
     def create_model_factory(
         self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
     ) -> "LanguageModel":
-
-        if service_name is None:  # we try to find the right service
-            service = self.resolver.resolve_model(model_name, service_name)
-        else:  # if they passed a service, we'll use that
-            service = self.service_names_to_classes().get(service_name)
-
-        if not service:  # but if we can't find it, we'll raise an error
-            raise InferenceServiceError(f"Service {service_name} not found")
-
+        service = self.resolver.resolve_model(model_name, service_name)
         return service.create_model(model_name)
 
edsl/inference_services/OpenAIService.py CHANGED
@@ -1,8 +1,7 @@
 from __future__ import annotations
-from typing import Any, List, Optional, Dict, NewType
+from typing import Any, List, Optional
 import os
 
-
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -12,8 +11,6 @@ from edsl.utilities.utilities import fix_partial_correct_response
 
 from edsl.config import CONFIG
 
-APIToken = NewType("APIToken", str)
-
 
 class OpenAIService(InferenceServiceABC):
     """OpenAI service class."""
@@ -25,43 +22,35 @@ class OpenAIService(InferenceServiceABC):
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI
 
-    _sync_client_instances: Dict[APIToken, openai.OpenAI] = {}
-    _async_client_instances: Dict[APIToken, openai.AsyncOpenAI] = {}
+    _sync_client_instance = None
+    _async_client_instance = None
 
     key_sequence = ["choices", 0, "message", "content"]
     usage_sequence = ["usage"]
     input_token_name = "prompt_tokens"
     output_token_name = "completion_tokens"
 
-    available_models_url = "https://platform.openai.com/docs/models/gp"
-
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
-        # so subclasses that use the OpenAI api key have to create their own instances of the clients
-        cls._sync_client_instances = {}
-        cls._async_client_instances = {}
+        # so subclasses have to create their own instances of the clients
+        cls._sync_client_instance = None
+        cls._async_client_instance = None
 
     @classmethod
-    def sync_client(cls, api_key):
-        if api_key not in cls._sync_client_instances:
-            client = cls._sync_client_(
-                api_key=api_key,
-                base_url=cls._base_url_,
+    def sync_client(cls):
+        if cls._sync_client_instance is None:
+            cls._sync_client_instance = cls._sync_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
             )
-            cls._sync_client_instances[api_key] = client
-        client = cls._sync_client_instances[api_key]
-        return client
+        return cls._sync_client_instance
 
     @classmethod
-    def async_client(cls, api_key):
-        if api_key not in cls._async_client_instances:
-            client = cls._async_client_(
-                api_key=api_key,
-                base_url=cls._base_url_,
+    def async_client(cls):
+        if cls._async_client_instance is None:
+            cls._async_client_instance = cls._async_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
             )
-            cls._async_client_instances[api_key] = client
-        client = cls._async_client_instances[api_key]
-        return client
+        return cls._async_client_instance
 
     model_exclude_list = [
         "whisper-1",
@@ -83,24 +72,20 @@ class OpenAIService(InferenceServiceABC):
     _models_list_cache: List[str] = []
 
     @classmethod
-    def get_model_list(cls, api_key=None):
-        if api_key is None:
-            api_key = os.getenv(cls._env_key_name_)
-        raw_list = cls.sync_client(api_key).models.list()
+    def get_model_list(cls):
+        raw_list = cls.sync_client().models.list()
         if hasattr(raw_list, "data"):
             return raw_list.data
         else:
             return raw_list
 
     @classmethod
-    def available(cls, api_token=None) -> List[str]:
-        if api_token is None:
-            api_token = os.getenv(cls._env_key_name_)
+    def available(cls) -> List[str]:
         if not cls._models_list_cache:
             try:
                 cls._models_list_cache = [
                     m.id
-                    for m in cls.get_model_list(api_key=api_token)
+                    for m in cls.get_model_list()
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
@@ -135,10 +120,10 @@ class OpenAIService(InferenceServiceABC):
         }
 
     def sync_client(self):
-        return cls.sync_client(api_key=self.api_token)
+        return cls.sync_client()
 
     def async_client(self):
-        return cls.async_client(api_key=self.api_token)
+        return cls.async_client()
 
     @classmethod
     def available(cls) -> list[str]:
@@ -187,16 +172,16 @@ class OpenAIService(InferenceServiceABC):
     ) -> dict[str, Any]:
         """Calls the OpenAI API and returns the API response."""
         if files_list:
+            encoded_image = files_list[0].base64_string
             content = [{"type": "text", "text": user_prompt}]
-            for file_entry in files_list:
-                content.append(
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": f"data:{file_entry.mime_type};base64,{file_entry.base64_string}"
-                        },
-                    }
-                )
+            content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{encoded_image}"
+                    },
+                }
+            )
         else:
             content = user_prompt
         client = self.async_client()
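The client-construction change above swaps 0.1.39's per-API-key client pool for dev2's single memoized client read from the environment. The two patterns, reduced to a standalone sketch (class and variable names here are illustrative, not edsl's):

    import os
    from typing import Dict, Optional

    import openai

    class PerKeyClients:
        """0.1.39 pattern: one client per key, so callers can mix API keys."""
        _instances: Dict[str, openai.OpenAI] = {}

        @classmethod
        def sync_client(cls, api_key: str) -> openai.OpenAI:
            if api_key not in cls._instances:
                cls._instances[api_key] = openai.OpenAI(api_key=api_key)
            return cls._instances[api_key]

    class SingletonClient:
        """dev2 pattern: one client, key read once from the environment."""
        _instance: Optional[openai.OpenAI] = None

        @classmethod
        def sync_client(cls) -> openai.OpenAI:
            if cls._instance is None:
                cls._instance = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return cls._instance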
edsl/inference_services/TestService.py CHANGED
@@ -51,7 +51,6 @@ class TestService(InferenceServiceABC):
     @property
     def _canned_response(self):
         if hasattr(self, "canned_response"):
-
             return self.canned_response
         else:
             return "Hello, world"
@@ -64,6 +63,9 @@ class TestService(InferenceServiceABC):
         files_list: Optional[List["File"]] = None,
     ) -> dict[str, Any]:
         await asyncio.sleep(0.1)
+        # return {"message": """{"answer": "Hello, world"}"""}
+
+        # breakpoint()
 
         if hasattr(self, "throw_exception") and self.throw_exception:
             if hasattr(self, "exception_probability"):