edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (212)
  1. edsl/Base.py +197 -116
  2. edsl/__init__.py +15 -7
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +351 -147
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +101 -50
  7. edsl/agents/InvigilatorBase.py +62 -70
  8. edsl/agents/PromptConstructor.py +143 -225
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  11. edsl/agents/__init__.py +0 -1
  12. edsl/agents/prompt_helpers.py +3 -3
  13. edsl/agents/question_option_processor.py +172 -0
  14. edsl/auto/AutoStudy.py +18 -5
  15. edsl/auto/StageBase.py +53 -40
  16. edsl/auto/StageQuestions.py +2 -1
  17. edsl/auto/utilities.py +0 -6
  18. edsl/config.py +22 -2
  19. edsl/conversation/car_buying.py +2 -1
  20. edsl/coop/CoopFunctionsMixin.py +15 -0
  21. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  22. edsl/coop/PriceFetcher.py +1 -1
  23. edsl/coop/coop.py +125 -47
  24. edsl/coop/utils.py +14 -14
  25. edsl/data/Cache.py +45 -27
  26. edsl/data/CacheEntry.py +12 -15
  27. edsl/data/CacheHandler.py +31 -12
  28. edsl/data/RemoteCacheSync.py +154 -46
  29. edsl/data/__init__.py +4 -3
  30. edsl/data_transfer_models.py +2 -1
  31. edsl/enums.py +27 -0
  32. edsl/exceptions/__init__.py +50 -50
  33. edsl/exceptions/agents.py +12 -0
  34. edsl/exceptions/inference_services.py +5 -0
  35. edsl/exceptions/questions.py +24 -6
  36. edsl/exceptions/scenarios.py +7 -0
  37. edsl/inference_services/AnthropicService.py +38 -19
  38. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  39. edsl/inference_services/AvailableModelFetcher.py +215 -0
  40. edsl/inference_services/AwsBedrock.py +0 -2
  41. edsl/inference_services/AzureAI.py +0 -2
  42. edsl/inference_services/GoogleService.py +7 -12
  43. edsl/inference_services/InferenceServiceABC.py +18 -85
  44. edsl/inference_services/InferenceServicesCollection.py +120 -79
  45. edsl/inference_services/MistralAIService.py +0 -3
  46. edsl/inference_services/OpenAIService.py +47 -35
  47. edsl/inference_services/PerplexityService.py +0 -3
  48. edsl/inference_services/ServiceAvailability.py +135 -0
  49. edsl/inference_services/TestService.py +11 -10
  50. edsl/inference_services/TogetherAIService.py +5 -3
  51. edsl/inference_services/data_structures.py +134 -0
  52. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  53. edsl/jobs/Answers.py +1 -14
  54. edsl/jobs/FetchInvigilator.py +47 -0
  55. edsl/jobs/InterviewTaskManager.py +98 -0
  56. edsl/jobs/InterviewsConstructor.py +50 -0
  57. edsl/jobs/Jobs.py +356 -431
  58. edsl/jobs/JobsChecks.py +35 -10
  59. edsl/jobs/JobsComponentConstructor.py +189 -0
  60. edsl/jobs/JobsPrompts.py +6 -4
  61. edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
  62. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  63. edsl/jobs/RequestTokenEstimator.py +30 -0
  64. edsl/jobs/async_interview_runner.py +138 -0
  65. edsl/jobs/buckets/BucketCollection.py +44 -3
  66. edsl/jobs/buckets/TokenBucket.py +53 -21
  67. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  68. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  69. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  70. edsl/jobs/data_structures.py +120 -0
  71. edsl/jobs/decorators.py +35 -0
  72. edsl/jobs/interviews/Interview.py +143 -408
  73. edsl/jobs/jobs_status_enums.py +9 -0
  74. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  75. edsl/jobs/results_exceptions_handler.py +98 -0
  76. edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
  77. edsl/jobs/runners/JobsRunnerStatus.py +133 -165
  78. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  79. edsl/jobs/tasks/TaskHistory.py +38 -18
  80. edsl/jobs/tasks/task_status_enum.py +0 -2
  81. edsl/language_models/ComputeCost.py +63 -0
  82. edsl/language_models/LanguageModel.py +194 -236
  83. edsl/language_models/ModelList.py +28 -19
  84. edsl/language_models/PriceManager.py +127 -0
  85. edsl/language_models/RawResponseHandler.py +106 -0
  86. edsl/language_models/ServiceDataSources.py +0 -0
  87. edsl/language_models/__init__.py +1 -2
  88. edsl/language_models/key_management/KeyLookup.py +63 -0
  89. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  90. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  91. edsl/language_models/key_management/__init__.py +0 -0
  92. edsl/language_models/key_management/models.py +131 -0
  93. edsl/language_models/model.py +256 -0
  94. edsl/language_models/repair.py +2 -2
  95. edsl/language_models/utilities.py +5 -4
  96. edsl/notebooks/Notebook.py +19 -14
  97. edsl/notebooks/NotebookToLaTeX.py +142 -0
  98. edsl/prompts/Prompt.py +29 -39
  99. edsl/questions/ExceptionExplainer.py +77 -0
  100. edsl/questions/HTMLQuestion.py +103 -0
  101. edsl/questions/QuestionBase.py +68 -214
  102. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  103. edsl/questions/QuestionBudget.py +1 -1
  104. edsl/questions/QuestionCheckBox.py +3 -3
  105. edsl/questions/QuestionExtract.py +5 -7
  106. edsl/questions/QuestionFreeText.py +2 -3
  107. edsl/questions/QuestionList.py +10 -18
  108. edsl/questions/QuestionMatrix.py +265 -0
  109. edsl/questions/QuestionMultipleChoice.py +67 -23
  110. edsl/questions/QuestionNumerical.py +2 -4
  111. edsl/questions/QuestionRank.py +7 -17
  112. edsl/questions/SimpleAskMixin.py +4 -3
  113. edsl/questions/__init__.py +2 -1
  114. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
  115. edsl/questions/data_structures.py +20 -0
  116. edsl/questions/derived/QuestionLinearScale.py +6 -3
  117. edsl/questions/derived/QuestionTopK.py +1 -1
  118. edsl/questions/descriptors.py +17 -3
  119. edsl/questions/loop_processor.py +149 -0
  120. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
  121. edsl/questions/question_registry.py +1 -1
  122. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
  123. edsl/questions/response_validator_factory.py +34 -0
  124. edsl/questions/templates/matrix/__init__.py +1 -0
  125. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  126. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  127. edsl/results/CSSParameterizer.py +1 -1
  128. edsl/results/Dataset.py +170 -7
  129. edsl/results/DatasetExportMixin.py +168 -305
  130. edsl/results/DatasetTree.py +28 -8
  131. edsl/results/MarkdownToDocx.py +122 -0
  132. edsl/results/MarkdownToPDF.py +111 -0
  133. edsl/results/Result.py +298 -206
  134. edsl/results/Results.py +149 -131
  135. edsl/results/ResultsExportMixin.py +2 -0
  136. edsl/results/TableDisplay.py +98 -171
  137. edsl/results/TextEditor.py +50 -0
  138. edsl/results/__init__.py +1 -1
  139. edsl/results/file_exports.py +252 -0
  140. edsl/results/{Selector.py → results_selector.py} +23 -13
  141. edsl/results/smart_objects.py +96 -0
  142. edsl/results/table_data_class.py +12 -0
  143. edsl/results/table_renderers.py +118 -0
  144. edsl/scenarios/ConstructDownloadLink.py +109 -0
  145. edsl/scenarios/DocumentChunker.py +102 -0
  146. edsl/scenarios/DocxScenario.py +16 -0
  147. edsl/scenarios/FileStore.py +150 -239
  148. edsl/scenarios/PdfExtractor.py +40 -0
  149. edsl/scenarios/Scenario.py +90 -193
  150. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  151. edsl/scenarios/ScenarioList.py +415 -244
  152. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  153. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  154. edsl/scenarios/__init__.py +1 -2
  155. edsl/scenarios/directory_scanner.py +96 -0
  156. edsl/scenarios/file_methods.py +85 -0
  157. edsl/scenarios/handlers/__init__.py +13 -0
  158. edsl/scenarios/handlers/csv.py +49 -0
  159. edsl/scenarios/handlers/docx.py +76 -0
  160. edsl/scenarios/handlers/html.py +37 -0
  161. edsl/scenarios/handlers/json.py +111 -0
  162. edsl/scenarios/handlers/latex.py +5 -0
  163. edsl/scenarios/handlers/md.py +51 -0
  164. edsl/scenarios/handlers/pdf.py +68 -0
  165. edsl/scenarios/handlers/png.py +39 -0
  166. edsl/scenarios/handlers/pptx.py +105 -0
  167. edsl/scenarios/handlers/py.py +294 -0
  168. edsl/scenarios/handlers/sql.py +313 -0
  169. edsl/scenarios/handlers/sqlite.py +149 -0
  170. edsl/scenarios/handlers/txt.py +33 -0
  171. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
  172. edsl/scenarios/scenario_selector.py +156 -0
  173. edsl/study/ObjectEntry.py +1 -1
  174. edsl/study/SnapShot.py +1 -1
  175. edsl/study/Study.py +5 -12
  176. edsl/surveys/ConstructDAG.py +92 -0
  177. edsl/surveys/EditSurvey.py +221 -0
  178. edsl/surveys/InstructionHandler.py +100 -0
  179. edsl/surveys/MemoryManagement.py +72 -0
  180. edsl/surveys/Rule.py +5 -4
  181. edsl/surveys/RuleCollection.py +25 -27
  182. edsl/surveys/RuleManager.py +172 -0
  183. edsl/surveys/Simulator.py +75 -0
  184. edsl/surveys/Survey.py +270 -791
  185. edsl/surveys/SurveyCSS.py +20 -8
  186. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  187. edsl/surveys/SurveyToApp.py +141 -0
  188. edsl/surveys/__init__.py +4 -2
  189. edsl/surveys/descriptors.py +6 -2
  190. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  191. edsl/surveys/instructions/Instruction.py +4 -13
  192. edsl/surveys/instructions/InstructionCollection.py +11 -6
  193. edsl/templates/error_reporting/interview_details.html +1 -1
  194. edsl/templates/error_reporting/report.html +1 -1
  195. edsl/tools/plotting.py +1 -1
  196. edsl/utilities/PrettyList.py +56 -0
  197. edsl/utilities/is_notebook.py +18 -0
  198. edsl/utilities/is_valid_variable_name.py +11 -0
  199. edsl/utilities/remove_edsl_version.py +24 -0
  200. edsl/utilities/utilities.py +35 -23
  201. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
  202. edsl-0.1.39.dist-info/RECORD +358 -0
  203. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  204. edsl/language_models/KeyLookup.py +0 -30
  205. edsl/language_models/registry.py +0 -190
  206. edsl/language_models/unused/ReplicateBase.py +0 -83
  207. edsl/results/ResultsDBMixin.py +0 -238
  208. edsl-0.1.38.dev4.dist-info/RECORD +0 -277
  209. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  210. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  211. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  212. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
edsl/inference_services/AnthropicService.py

@@ -11,21 +11,27 @@ class AnthropicService(InferenceServiceABC):
 
     _inference_service_ = "anthropic"
     _env_key_name_ = "ANTHROPIC_API_KEY"
-    key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
+    key_sequence = ["content", 0, "text"]
     usage_sequence = ["usage"]
     input_token_name = "input_tokens"
     output_token_name = "output_tokens"
     model_exclude_list = []
 
+    @classmethod
+    def get_model_list(cls, api_key: str = None):
+
+        import requests
+
+        if api_key is None:
+            api_key = os.environ.get("ANTHROPIC_API_KEY")
+        headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
+        response = requests.get("https://api.anthropic.com/v1/models", headers=headers)
+        model_names = [m["id"] for m in response.json()["data"]]
+        return model_names
+
     @classmethod
     def available(cls):
-        # TODO - replace with an API call
-        return [
-            "claude-3-5-sonnet-20240620",
-            "claude-3-opus-20240229",
-            "claude-3-sonnet-20240229",
-            "claude-3-haiku-20240307",
-        ]
+        return cls.get_model_list()
 
     @classmethod
     def create_model(
@@ -56,29 +62,42 @@ class AnthropicService(InferenceServiceABC):
                 "top_logprobs": 3,
             }
 
-            _tpm = cls.get_tpm(cls)
-            _rpm = cls.get_rpm(cls)
-
             async def async_execute_model_call(
                 self,
                 user_prompt: str,
                 system_prompt: str = "",
                 files_list: Optional[List["Files"]] = None,
             ) -> dict[str, Any]:
-                """Calls the OpenAI API and returns the API response."""
+                """Calls the Anthropic API and returns the API response."""
 
-                api_key = os.environ.get("ANTHROPIC_API_KEY")
-                client = AsyncAnthropic(api_key=api_key)
+                messages = [
+                    {
+                        "role": "user",
+                        "content": [{"type": "text", "text": user_prompt}],
+                    }
+                ]
+                if files_list:
+                    for file_entry in files_list:
+                        encoded_image = file_entry.base64_string
+                        messages[0]["content"].append(
+                            {
+                                "type": "image",
+                                "source": {
+                                    "type": "base64",
+                                    "media_type": file_entry.mime_type,
+                                    "data": encoded_image,
+                                },
+                            }
+                        )
+                # breakpoint()
+                client = AsyncAnthropic(api_key=self.api_token)
 
                 response = await client.messages.create(
                     model=model_name,
                     max_tokens=self.max_tokens,
                     temperature=self.temperature,
-                    system=system_prompt,
-                    messages=[
-                        # {"role": "system", "content": system_prompt},
-                        {"role": "user", "content": user_prompt},
-                    ],
+                    system=system_prompt,  # note that the Anthropic API uses "system" parameter rather than put it in the message
+                    messages=messages,
                 )
                 return response.model_dump()
 
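The net effect of the two hunks above: AnthropicService.available() no longer returns a hard-coded list of Claude models but delegates to the new get_model_list(), which queries Anthropic's /v1/models endpoint, and image attachments from files_list are now forwarded as base64 content blocks. A minimal standalone sketch of the model-list call (assuming requests is installed and ANTHROPIC_API_KEY is set):

    import os
    import requests

    # Same request the new get_model_list() issues: fetch model IDs from
    # Anthropic's models endpoint instead of relying on a hard-coded list.
    headers = {
        "x-api-key": os.environ["ANTHROPIC_API_KEY"],
        "anthropic-version": "2023-06-01",
    }
    resp = requests.get("https://api.anthropic.com/v1/models", headers=headers)
    resp.raise_for_status()
    print([m["id"] for m in resp.json()["data"]])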
edsl/inference_services/AvailableModelCacheHandler.py (new file)

@@ -0,0 +1,184 @@
+from typing import List, Optional, get_args, Union
+from pathlib import Path
+import sqlite3
+from datetime import datetime
+import tempfile
+from platformdirs import user_cache_dir
+from dataclasses import dataclass
+import os
+
+from edsl.inference_services.data_structures import LanguageModelInfo, AvailableModels
+from edsl.enums import InferenceServiceLiteral
+
+
+class AvailableModelCacheHandler:
+    MAX_ROWS = 1000
+    CACHE_VALIDITY_HOURS = 48
+
+    def __init__(
+        self,
+        cache_validity_hours: int = 48,
+        verbose: bool = False,
+        testing_db_name: str = None,
+    ):
+        self.cache_validity_hours = cache_validity_hours
+        self.verbose = verbose
+
+        if testing_db_name:
+            self.cache_dir = Path(tempfile.mkdtemp())
+            self.db_path = self.cache_dir / testing_db_name
+        else:
+            self.cache_dir = Path(user_cache_dir("edsl", "model_availability"))
+            self.db_path = self.cache_dir / "available_models.db"
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+        if os.path.exists(self.db_path):
+            if self.verbose:
+                print(f"Using existing cache DB: {self.db_path}")
+        else:
+            self._initialize_db()
+
+    @property
+    def path_to_db(self):
+        return self.db_path
+
+    def _initialize_db(self):
+        """Initialize the SQLite database with the required schema."""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            # Drop the old table if it exists (for migration)
+            cursor.execute("DROP TABLE IF EXISTS model_cache")
+            cursor.execute(
+                """
+                CREATE TABLE IF NOT EXISTS model_cache (
+                    timestamp DATETIME NOT NULL,
+                    model_name TEXT NOT NULL,
+                    service_name TEXT NOT NULL,
+                    UNIQUE(model_name, service_name)
+                )
+            """
+            )
+            conn.commit()
+
+    def _prune_old_entries(self, conn: sqlite3.Connection):
+        """Delete oldest entries when MAX_ROWS is exceeded."""
+        cursor = conn.cursor()
+        cursor.execute("SELECT COUNT(*) FROM model_cache")
+        count = cursor.fetchone()[0]
+
+        if count > self.MAX_ROWS:
+            cursor.execute(
+                """
+                DELETE FROM model_cache
+                WHERE rowid IN (
+                    SELECT rowid
+                    FROM model_cache
+                    ORDER BY timestamp ASC
+                    LIMIT ?
+                )
+            """,
+                (count - self.MAX_ROWS,),
+            )
+            conn.commit()
+
+    @classmethod
+    def example_models(cls) -> List[LanguageModelInfo]:
+        return [
+            LanguageModelInfo(
+                "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "deep_infra"
+            ),
+            LanguageModelInfo("openai/gpt-4", "openai"),
+        ]
+
+    def add_models_to_cache(self, models_data: List[LanguageModelInfo]):
+        """Add new models to the cache, updating timestamps for existing entries."""
+        current_time = datetime.now()
+
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            for model in models_data:
+                cursor.execute(
+                    """
+                    INSERT INTO model_cache (timestamp, model_name, service_name)
+                    VALUES (?, ?, ?)
+                    ON CONFLICT(model_name, service_name)
+                    DO UPDATE SET timestamp = excluded.timestamp
+                """,
+                    (current_time, model.model_name, model.service_name),
+                )
+
+            # self._prune_old_entries(conn)
+            conn.commit()
+
+    def reset_cache(self):
+        """Clear all entries from the cache."""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute("DELETE FROM model_cache")
+            conn.commit()
+
+    @property
+    def num_cache_entries(self):
+        """Return the number of entries in the cache."""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute("SELECT COUNT(*) FROM model_cache")
+            count = cursor.fetchone()[0]
+            return count
+
+    def models(
+        self,
+        service: Optional[InferenceServiceLiteral],
+    ) -> Union[None, AvailableModels]:
+        """Return the available models within the cache validity period."""
+        # if service is not None:
+        #     assert service in get_args(InferenceServiceLiteral)
+
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            valid_time = datetime.now().timestamp() - (self.cache_validity_hours * 3600)
+
+            if self.verbose:
+                print(f"Fetching all with timestamp greater than {valid_time}")
+
+            cursor.execute(
+                """
+                SELECT DISTINCT model_name, service_name
+                FROM model_cache
+                WHERE timestamp > ?
+                ORDER BY timestamp DESC
+            """,
+                (valid_time,),
+            )
+
+            results = cursor.fetchall()
+            if not results:
+                if self.verbose:
+                    print("No results found in cache DB.")
+                return None
+
+            matching_models = [
+                LanguageModelInfo(model_name=row[0], service_name=row[1])
+                for row in results
+            ]
+
+            if self.verbose:
+                print(f"Found {len(matching_models)} models in cache DB.")
+            if service:
+                matching_models = [
+                    model for model in matching_models if model.service_name == service
+                ]
+
+            return AvailableModels(matching_models)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()
+    # cache_handler = AvailableModelCacheHandler(verbose=True)
+    # models_data = cache_handler.example_models()
+    # cache_handler.add_models_to_cache(models_data)
+    # print(cache_handler.models())
+    # cache_handler.clear_cache()
+    # print(cache_handler.models())
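A usage sketch for the new cache handler, based only on the methods defined above; the testing_db_name argument points it at a throwaway database so the real per-user cache is untouched:

    from edsl.inference_services.AvailableModelCacheHandler import (
        AvailableModelCacheHandler,
    )

    # Seed a throwaway DB with the class's own example models, then read
    # back only the entries recorded for the "openai" service.
    handler = AvailableModelCacheHandler(testing_db_name="models_test.db")
    handler.add_models_to_cache(AvailableModelCacheHandler.example_models())
    print(handler.num_cache_entries)         # 2
    print(handler.models(service="openai"))  # just the openai/gpt-4 entry
    handler.reset_cache()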
edsl/inference_services/AvailableModelFetcher.py (new file)

@@ -0,0 +1,215 @@
+from typing import Any, List, Tuple, Optional, Dict, TYPE_CHECKING, Union, Generator
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from collections import UserList
+
+from edsl.inference_services.ServiceAvailability import ServiceAvailability
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.inference_services.data_structures import ModelNamesList
+from edsl.enums import InferenceServiceLiteral
+
+from edsl.inference_services.data_structures import LanguageModelInfo
+from edsl.inference_services.AvailableModelCacheHandler import (
+    AvailableModelCacheHandler,
+)
+
+
+from edsl.inference_services.data_structures import AvailableModels
+
+
+class AvailableModelFetcher:
+    """Fetches available models from the various services with JSON caching."""
+
+    service_availability = ServiceAvailability()
+    CACHE_VALIDITY_HOURS = 48  # Cache validity period in hours
+
+    def __init__(
+        self,
+        services: List["InferenceServiceABC"],
+        added_models: Dict[str, List[str]],
+        verbose: bool = False,
+        use_cache: bool = True,
+    ):
+        self.services = services
+        self.added_models = added_models
+        self._service_map = {
+            service._inference_service_: service for service in services
+        }
+        self.verbose = verbose
+        if use_cache:
+            self.cache_handler = AvailableModelCacheHandler()
+        else:
+            self.cache_handler = None
+
+    @property
+    def num_cache_entries(self):
+        return self.cache_handler.num_cache_entries
+
+    @property
+    def path_to_db(self):
+        return self.cache_handler.path_to_db
+
+    def reset_cache(self):
+        if self.cache_handler:
+            self.cache_handler.reset_cache()
+
+    def available(
+        self,
+        service: Optional[InferenceServiceABC] = None,
+        force_refresh: bool = False,
+    ) -> List[LanguageModelInfo]:
+        """
+        Get available models from all services, using cached data when available.
+
+        :param service: Optional[InferenceServiceABC] - If specified, only fetch models for this service.
+
+        >>> from edsl.inference_services.OpenAIService import OpenAIService
+        >>> af = AvailableModelFetcher([OpenAIService()], {})
+        >>> af.available(service="openai")
+        [LanguageModelInfo(model_name='...', service_name='openai'), ...]
+
+        Returns a list of [model, service_name, index] entries.
+        """
+
+        if service:  # they passed a specific service
+            matching_models, _ = self.get_available_models_by_service(
+                service=service, force_refresh=force_refresh
+            )
+            return matching_models
+
+        # Nope, we need to fetch them all
+        all_models = self._get_all_models()
+
+        # if self.cache_handler:
+        #     self.cache_handler.add_models_to_cache(all_models)
+
+        return all_models
+
+    def get_available_models_by_service(
+        self,
+        service: Union["InferenceServiceABC", InferenceServiceLiteral],
+        force_refresh: bool = False,
+    ) -> Tuple[AvailableModels, InferenceServiceLiteral]:
+        """Get models for a single service.
+
+        :param service: InferenceServiceABC - e.g., OpenAIService or "openai"
+        :return: Tuple[List[LanguageModelInfo], InferenceServiceLiteral]
+        """
+        if isinstance(service, str):
+            service = self._fetch_service_by_service_name(service)
+
+        if not force_refresh:
+            models_from_cache = self.cache_handler.models(
+                service=service._inference_service_
+            )
+            if self.verbose:
+                print(
+                    "Searching cache for models with service name:",
+                    service._inference_service_,
+                )
+                print("Got models from cache:", models_from_cache)
+        else:
+            models_from_cache = None
+
+        if models_from_cache:
+            # print(f"Models from cache for {service}: {models_from_cache}")
+            # print(hasattr(models_from_cache[0], "service_name"))
+            return models_from_cache, service._inference_service_
+        else:
+            return self.get_available_models_by_service_fresh(service)
+
+    def get_available_models_by_service_fresh(
+        self, service: Union["InferenceServiceABC", InferenceServiceLiteral]
+    ) -> Tuple[AvailableModels, InferenceServiceLiteral]:
+        """Get models for a single service. This method always fetches fresh data.
+
+        :param service: InferenceServiceABC - e.g., OpenAIService or "openai"
+        :return: Tuple[List[LanguageModelInfo], InferenceServiceLiteral]
+        """
+        if isinstance(service, str):
+            service = self._fetch_service_by_service_name(service)
+
+        service_models: ModelNamesList = (
+            self.service_availability.get_service_available(service, warn=False)
+        )
+        service_name = service._inference_service_
+
+        if not service_models:
+            import warnings
+
+            warnings.warn(f"No models found for service {service_name}")
+            return [], service_name
+
+        models_list = AvailableModels(
+            [
+                LanguageModelInfo(
+                    model_name=model_name,
+                    service_name=service_name,
+                )
+                for model_name in service_models
+            ]
+        )
+        self.cache_handler.add_models_to_cache(models_list)  # update the cache
+        return models_list, service_name
+
+    def _fetch_service_by_service_name(
+        self, service_name: InferenceServiceLiteral
+    ) -> "InferenceServiceABC":
+        """The service name is the _inference_service_ attribute of the service."""
+        if service_name in self._service_map:
+            return self._service_map[service_name]
+        raise ValueError(f"Service {service_name} not found")
+
+    def _get_all_models(self, force_refresh=False) -> List[LanguageModelInfo]:
+        all_models = []
+        with ThreadPoolExecutor(max_workers=min(len(self.services), 10)) as executor:
+            future_to_service = {
+                executor.submit(
+                    self.get_available_models_by_service, service, force_refresh
+                ): service
+                for service in self.services
+            }
+
+            for future in as_completed(future_to_service):
+                try:
+                    models, service_name = future.result()
+                    all_models.extend(models)
+
+                    # Add any additional models for this service
+                    for model in self.added_models.get(service_name, []):
+                        all_models.append(
+                            LanguageModelInfo(
+                                model_name=model, service_name=service_name
+                            )
+                        )
+
+                except Exception as exc:
+                    print(f"Service query failed for service {service_name}: {exc}")
+                    continue
+
+        return AvailableModels(all_models)
+
+
+def main():
+    from edsl.inference_services.OpenAIService import OpenAIService
+
+    af = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
+    # print(af.available(service="openai"))
+    all_models = AvailableModelFetcher([OpenAIService()], {})._get_all_models(
+        force_refresh=True
+    )
+    print(all_models)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
+    # main()
+
+    # from edsl.inference_services.OpenAIService import OpenAIService
+
+    # af = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
+    # # print(af.available(service="openai"))
+
+    # all_models = AvailableModelFetcher([OpenAIService()], {})._get_all_models()
+    # print(all_models)
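A usage sketch mirroring the module's own main() and doctest (assumes an OpenAI key is configured so the service can be queried):

    from edsl.inference_services.OpenAIService import OpenAIService
    from edsl.inference_services.AvailableModelFetcher import AvailableModelFetcher

    fetcher = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
    # Served from the SQLite cache when a fresh entry exists; otherwise
    # fetched from the service and written back to the cache.
    models = fetcher.available(service="openai")
    # force_refresh bypasses the cache and always queries the service.
    fresh = fetcher.available(service="openai", force_refresh=True)
    print(fetcher.num_cache_entries, fetcher.path_to_db)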
edsl/inference_services/AwsBedrock.py

@@ -69,8 +69,6 @@ class AwsBedrockService(InferenceServiceABC):
             }
             input_token_name = cls.input_token_name
             output_token_name = cls.output_token_name
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
 
             async def async_execute_model_call(
                 self,
edsl/inference_services/AzureAI.py

@@ -118,8 +118,6 @@ class AzureAIService(InferenceServiceABC):
                 "max_tokens": 512,
                 "top_p": 0.9,
             }
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
 
             async def async_execute_model_call(
                 self,
edsl/inference_services/GoogleService.py

@@ -1,11 +1,11 @@
-import os
+# import os
 from typing import Any, Dict, List, Optional
 import google
 import google.generativeai as genai
 from google.generativeai.types import GenerationConfig
 from google.api_core.exceptions import InvalidArgument
 
-from edsl.exceptions import MissingAPIKeyError
+# from edsl.exceptions.general import MissingAPIKeyError
 from edsl.language_models.LanguageModel import LanguageModel
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.coop import Coop
@@ -39,18 +39,18 @@ class GoogleService(InferenceServiceABC):
 
     model_exclude_list = []
 
-    # @classmethod
-    # def available(cls) -> List[str]:
-    #     return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
-
     @classmethod
-    def available(cls) -> List[str]:
+    def get_model_list(cls):
         model_list = []
         for m in genai.list_models():
             if "generateContent" in m.supported_generation_methods:
                 model_list.append(m.name.split("/")[-1])
         return model_list
 
+    @classmethod
+    def available(cls) -> List[str]:
+        return cls.get_model_list()
+
     @classmethod
     def create_model(
         cls, model_name: str = "gemini-pro", model_class_name=None
@@ -66,9 +66,6 @@ class GoogleService(InferenceServiceABC):
             output_token_name = cls.output_token_name
             _inference_service_ = cls._inference_service_
 
-            _tpm = cls.get_tpm(cls)
-            _rpm = cls.get_rpm(cls)
-
             _parameters_ = {
                 "temperature": 0.5,
                 "topP": 1,
@@ -77,7 +74,6 @@ class GoogleService(InferenceServiceABC):
                 "stopSequences": [],
             }
 
-            api_token = None
             model = None
 
             def __init__(self, *args, **kwargs):
@@ -102,7 +98,6 @@ class GoogleService(InferenceServiceABC):
 
                 if files_list is None:
                     files_list = []
-
                 genai.configure(api_key=self.api_token)
                 if (
                     system_prompt is not None
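As with AnthropicService above, GoogleService.available() now delegates to a get_model_list() that queries the provider. A standalone sketch of the same lookup (assuming google-generativeai is installed and the key is read from GOOGLE_API_KEY, the environment variable edsl uses for this service):

    import os
    import google.generativeai as genai

    # Mirror of GoogleService.get_model_list(): keep only models that
    # support generateContent, stripping the "models/" name prefix.
    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
    models = [
        m.name.split("/")[-1]
        for m in genai.list_models()
        if "generateContent" in m.supported_generation_methods
    ]
    print(models)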