edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. edsl/Base.py +169 -116
  2. edsl/__init__.py +14 -6
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +358 -146
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +88 -36
  7. edsl/agents/InvigilatorBase.py +59 -70
  8. edsl/agents/PromptConstructor.py +117 -219
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionOptionProcessor.py +172 -0
  11. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  12. edsl/agents/__init__.py +0 -1
  13. edsl/agents/prompt_helpers.py +3 -3
  14. edsl/config.py +22 -2
  15. edsl/conversation/car_buying.py +2 -1
  16. edsl/coop/CoopFunctionsMixin.py +15 -0
  17. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  18. edsl/coop/PriceFetcher.py +1 -1
  19. edsl/coop/coop.py +104 -42
  20. edsl/coop/utils.py +14 -14
  21. edsl/data/Cache.py +21 -14
  22. edsl/data/CacheEntry.py +12 -15
  23. edsl/data/CacheHandler.py +33 -12
  24. edsl/data/__init__.py +4 -3
  25. edsl/data_transfer_models.py +2 -1
  26. edsl/enums.py +20 -0
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +12 -0
  29. edsl/exceptions/inference_services.py +5 -0
  30. edsl/exceptions/questions.py +24 -6
  31. edsl/exceptions/scenarios.py +7 -0
  32. edsl/inference_services/AnthropicService.py +0 -3
  33. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  34. edsl/inference_services/AvailableModelFetcher.py +209 -0
  35. edsl/inference_services/AwsBedrock.py +0 -2
  36. edsl/inference_services/AzureAI.py +0 -2
  37. edsl/inference_services/GoogleService.py +2 -11
  38. edsl/inference_services/InferenceServiceABC.py +18 -85
  39. edsl/inference_services/InferenceServicesCollection.py +105 -80
  40. edsl/inference_services/MistralAIService.py +0 -3
  41. edsl/inference_services/OpenAIService.py +1 -4
  42. edsl/inference_services/PerplexityService.py +0 -3
  43. edsl/inference_services/ServiceAvailability.py +135 -0
  44. edsl/inference_services/TestService.py +11 -8
  45. edsl/inference_services/data_structures.py +62 -0
  46. edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
  47. edsl/jobs/Answers.py +1 -14
  48. edsl/jobs/FetchInvigilator.py +40 -0
  49. edsl/jobs/InterviewTaskManager.py +98 -0
  50. edsl/jobs/InterviewsConstructor.py +48 -0
  51. edsl/jobs/Jobs.py +102 -243
  52. edsl/jobs/JobsChecks.py +35 -10
  53. edsl/jobs/JobsComponentConstructor.py +189 -0
  54. edsl/jobs/JobsPrompts.py +5 -3
  55. edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
  56. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  57. edsl/jobs/RequestTokenEstimator.py +30 -0
  58. edsl/jobs/buckets/BucketCollection.py +44 -3
  59. edsl/jobs/buckets/TokenBucket.py +53 -21
  60. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  61. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  62. edsl/jobs/decorators.py +35 -0
  63. edsl/jobs/interviews/Interview.py +77 -380
  64. edsl/jobs/jobs_status_enums.py +9 -0
  65. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  66. edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
  67. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  68. edsl/jobs/tasks/TaskHistory.py +14 -15
  69. edsl/jobs/tasks/task_status_enum.py +0 -2
  70. edsl/language_models/ComputeCost.py +63 -0
  71. edsl/language_models/LanguageModel.py +137 -234
  72. edsl/language_models/ModelList.py +11 -13
  73. edsl/language_models/PriceManager.py +127 -0
  74. edsl/language_models/RawResponseHandler.py +106 -0
  75. edsl/language_models/ServiceDataSources.py +0 -0
  76. edsl/language_models/__init__.py +0 -1
  77. edsl/language_models/key_management/KeyLookup.py +63 -0
  78. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  79. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  80. edsl/language_models/key_management/__init__.py +0 -0
  81. edsl/language_models/key_management/models.py +131 -0
  82. edsl/language_models/registry.py +49 -59
  83. edsl/language_models/repair.py +2 -2
  84. edsl/language_models/utilities.py +5 -4
  85. edsl/notebooks/Notebook.py +19 -14
  86. edsl/notebooks/NotebookToLaTeX.py +142 -0
  87. edsl/prompts/Prompt.py +29 -39
  88. edsl/questions/AnswerValidatorMixin.py +47 -2
  89. edsl/questions/ExceptionExplainer.py +77 -0
  90. edsl/questions/HTMLQuestion.py +103 -0
  91. edsl/questions/LoopProcessor.py +149 -0
  92. edsl/questions/QuestionBase.py +37 -192
  93. edsl/questions/QuestionBaseGenMixin.py +52 -48
  94. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  95. edsl/questions/QuestionCheckBox.py +1 -1
  96. edsl/questions/QuestionExtract.py +1 -1
  97. edsl/questions/QuestionFreeText.py +1 -2
  98. edsl/questions/QuestionList.py +3 -5
  99. edsl/questions/QuestionMatrix.py +265 -0
  100. edsl/questions/QuestionMultipleChoice.py +66 -22
  101. edsl/questions/QuestionNumerical.py +1 -3
  102. edsl/questions/QuestionRank.py +6 -16
  103. edsl/questions/ResponseValidatorABC.py +37 -11
  104. edsl/questions/ResponseValidatorFactory.py +28 -0
  105. edsl/questions/SimpleAskMixin.py +4 -3
  106. edsl/questions/__init__.py +1 -0
  107. edsl/questions/derived/QuestionLinearScale.py +6 -3
  108. edsl/questions/derived/QuestionTopK.py +1 -1
  109. edsl/questions/descriptors.py +17 -3
  110. edsl/questions/question_registry.py +1 -1
  111. edsl/questions/templates/matrix/__init__.py +1 -0
  112. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  113. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  114. edsl/results/CSSParameterizer.py +1 -1
  115. edsl/results/Dataset.py +170 -7
  116. edsl/results/DatasetExportMixin.py +224 -302
  117. edsl/results/DatasetTree.py +28 -8
  118. edsl/results/MarkdownToDocx.py +122 -0
  119. edsl/results/MarkdownToPDF.py +111 -0
  120. edsl/results/Result.py +192 -206
  121. edsl/results/Results.py +120 -113
  122. edsl/results/ResultsExportMixin.py +2 -0
  123. edsl/results/Selector.py +23 -13
  124. edsl/results/TableDisplay.py +98 -171
  125. edsl/results/TextEditor.py +50 -0
  126. edsl/results/__init__.py +1 -1
  127. edsl/results/smart_objects.py +96 -0
  128. edsl/results/table_data_class.py +12 -0
  129. edsl/results/table_renderers.py +118 -0
  130. edsl/scenarios/ConstructDownloadLink.py +109 -0
  131. edsl/scenarios/DirectoryScanner.py +96 -0
  132. edsl/scenarios/DocumentChunker.py +102 -0
  133. edsl/scenarios/DocxScenario.py +16 -0
  134. edsl/scenarios/FileStore.py +118 -239
  135. edsl/scenarios/PdfExtractor.py +40 -0
  136. edsl/scenarios/Scenario.py +90 -193
  137. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  138. edsl/scenarios/ScenarioJoin.py +10 -6
  139. edsl/scenarios/ScenarioList.py +383 -240
  140. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  141. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  142. edsl/scenarios/ScenarioSelector.py +156 -0
  143. edsl/scenarios/__init__.py +1 -2
  144. edsl/scenarios/file_methods.py +85 -0
  145. edsl/scenarios/handlers/__init__.py +13 -0
  146. edsl/scenarios/handlers/csv.py +38 -0
  147. edsl/scenarios/handlers/docx.py +76 -0
  148. edsl/scenarios/handlers/html.py +37 -0
  149. edsl/scenarios/handlers/json.py +111 -0
  150. edsl/scenarios/handlers/latex.py +5 -0
  151. edsl/scenarios/handlers/md.py +51 -0
  152. edsl/scenarios/handlers/pdf.py +68 -0
  153. edsl/scenarios/handlers/png.py +39 -0
  154. edsl/scenarios/handlers/pptx.py +105 -0
  155. edsl/scenarios/handlers/py.py +294 -0
  156. edsl/scenarios/handlers/sql.py +313 -0
  157. edsl/scenarios/handlers/sqlite.py +149 -0
  158. edsl/scenarios/handlers/txt.py +33 -0
  159. edsl/study/ObjectEntry.py +1 -1
  160. edsl/study/SnapShot.py +1 -1
  161. edsl/study/Study.py +5 -12
  162. edsl/surveys/ConstructDAG.py +92 -0
  163. edsl/surveys/EditSurvey.py +221 -0
  164. edsl/surveys/InstructionHandler.py +100 -0
  165. edsl/surveys/MemoryManagement.py +72 -0
  166. edsl/surveys/Rule.py +5 -4
  167. edsl/surveys/RuleCollection.py +25 -27
  168. edsl/surveys/RuleManager.py +172 -0
  169. edsl/surveys/Simulator.py +75 -0
  170. edsl/surveys/Survey.py +199 -771
  171. edsl/surveys/SurveyCSS.py +20 -8
  172. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  173. edsl/surveys/SurveyToApp.py +141 -0
  174. edsl/surveys/__init__.py +4 -2
  175. edsl/surveys/descriptors.py +6 -2
  176. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  177. edsl/surveys/instructions/Instruction.py +4 -13
  178. edsl/surveys/instructions/InstructionCollection.py +11 -6
  179. edsl/templates/error_reporting/interview_details.html +1 -1
  180. edsl/templates/error_reporting/report.html +1 -1
  181. edsl/tools/plotting.py +1 -1
  182. edsl/utilities/PrettyList.py +56 -0
  183. edsl/utilities/is_notebook.py +18 -0
  184. edsl/utilities/is_valid_variable_name.py +11 -0
  185. edsl/utilities/remove_edsl_version.py +24 -0
  186. edsl/utilities/utilities.py +35 -23
  187. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
  188. edsl-0.1.39.dev2.dist-info/RECORD +352 -0
  189. edsl/language_models/KeyLookup.py +0 -30
  190. edsl/language_models/unused/ReplicateBase.py +0 -83
  191. edsl/results/ResultsDBMixin.py +0 -238
  192. edsl-0.1.39.dev1.dist-info/RECORD +0 -277
  193. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
  194. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
@@ -1,54 +1,54 @@
1
- from .agents import (
2
- # AgentAttributeLookupCallbackError,
3
- AgentCombinationError,
4
- # AgentLacksLLMError,
5
- # AgentRespondedWithBadJSONError,
6
- )
7
- from .configuration import (
8
- InvalidEnvironmentVariableError,
9
- MissingEnvironmentVariableError,
10
- )
11
- from .data import (
12
- DatabaseConnectionError,
13
- DatabaseCRUDError,
14
- DatabaseIntegrityError,
15
- )
1
+ # from .agents import (
2
+ # # AgentAttributeLookupCallbackError,
3
+ # AgentCombinationError,
4
+ # # AgentLacksLLMError,
5
+ # # AgentRespondedWithBadJSONError,
6
+ # )
7
+ # from .configuration import (
8
+ # InvalidEnvironmentVariableError,
9
+ # MissingEnvironmentVariableError,
10
+ # )
11
+ # from .data import (
12
+ # DatabaseConnectionError,
13
+ # DatabaseCRUDError,
14
+ # DatabaseIntegrityError,
15
+ # )
16
16
 
17
- from .scenarios import (
18
- ScenarioError,
19
- )
17
+ # from .scenarios import (
18
+ # ScenarioError,
19
+ # )
20
20
 
21
- from .general import MissingAPIKeyError
21
+ # from .general import MissingAPIKeyError
22
22
 
23
- from .jobs import JobsRunError, InterviewErrorPriorTaskCanceled, InterviewTimeoutError
23
+ # from .jobs import JobsRunError, InterviewErrorPriorTaskCanceled, InterviewTimeoutError
24
24
 
25
- from .language_models import (
26
- LanguageModelResponseNotJSONError,
27
- LanguageModelMissingAttributeError,
28
- LanguageModelAttributeTypeError,
29
- LanguageModelDoNotAddError,
30
- )
31
- from .questions import (
32
- QuestionAnswerValidationError,
33
- QuestionAttributeMissing,
34
- QuestionCreationValidationError,
35
- QuestionResponseValidationError,
36
- QuestionSerializationError,
37
- QuestionScenarioRenderError,
38
- )
39
- from .results import (
40
- ResultsBadMutationstringError,
41
- ResultsColumnNotFoundError,
42
- ResultsInvalidNameError,
43
- ResultsMutateError,
44
- )
45
- from .surveys import (
46
- SurveyCreationError,
47
- SurveyHasNoRulesError,
48
- SurveyRuleCannotEvaluateError,
49
- SurveyRuleCollectionHasNoRulesAtNodeError,
50
- SurveyRuleReferenceInRuleToUnknownQuestionError,
51
- SurveyRuleRefersToFutureStateError,
52
- SurveyRuleSendsYouBackwardsError,
53
- SurveyRuleSkipLogicSyntaxError,
54
- )
25
+ # from .language_models import (
26
+ # LanguageModelResponseNotJSONError,
27
+ # LanguageModelMissingAttributeError,
28
+ # LanguageModelAttributeTypeError,
29
+ # LanguageModelDoNotAddError,
30
+ # )
31
+ # from .questions import (
32
+ # QuestionAnswerValidationError,
33
+ # QuestionAttributeMissing,
34
+ # QuestionCreationValidationError,
35
+ # QuestionResponseValidationError,
36
+ # QuestionSerializationError,
37
+ # QuestionScenarioRenderError,
38
+ # )
39
+ # from .results import (
40
+ # ResultsBadMutationstringError,
41
+ # ResultsColumnNotFoundError,
42
+ # ResultsInvalidNameError,
43
+ # ResultsMutateError,
44
+ # )
45
+ # from .surveys import (
46
+ # SurveyCreationError,
47
+ # SurveyHasNoRulesError,
48
+ # SurveyRuleCannotEvaluateError,
49
+ # SurveyRuleCollectionHasNoRulesAtNodeError,
50
+ # SurveyRuleReferenceInRuleToUnknownQuestionError,
51
+ # SurveyRuleRefersToFutureStateError,
52
+ # SurveyRuleSendsYouBackwardsError,
53
+ # SurveyRuleSkipLogicSyntaxError,
54
+ # )
edsl/exceptions/agents.py CHANGED
@@ -1,6 +1,18 @@
1
1
  from edsl.exceptions.BaseException import BaseException
2
2
 
3
3
 
4
+ # from edsl.utilities.utilities import is_notebook
5
+
6
+ # from IPython.core.error import UsageError
7
+
8
+ # class AgentListErrorAlternative(UsageError):
9
+ # def __init__(self, message):
10
+ # super().__init__(message)
11
+
12
+ import sys
13
+ from edsl.utilities.is_notebook import is_notebook
14
+
15
+
4
16
  class AgentListError(BaseException):
5
17
  relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html#agent-lists"
6
18
 
@@ -0,0 +1,5 @@
1
+ from edsl.exceptions.BaseException import BaseException
2
+
3
+
4
class InferenceServiceError(BaseException):
    """Exception type for inference-service problems; points users at the EDSL docs."""

    # Documentation link surfaced alongside the error.
    relevant_doc = "https://docs.expectedparrot.com/"
@@ -1,6 +1,6 @@
1
1
  from typing import Any, SupportsIndex
2
- from jinja2 import Template
3
2
  import json
3
+ from pydantic import ValidationError
4
4
 
5
5
 
6
6
  class QuestionErrors(Exception):
@@ -20,17 +20,35 @@ class QuestionAnswerValidationError(QuestionErrors):
20
20
  For example, if the question is a multiple choice question, the answer should be drawn from the list of options provided.
21
21
  """
22
22
 
23
- def __init__(self, message="Invalid answer.", data=None, model=None):
23
+ def __init__(
24
+ self,
25
+ message="Invalid answer.",
26
+ pydantic_error: ValidationError = None,
27
+ data: dict = None,
28
+ model=None,
29
+ ):
24
30
  self.message = message
31
+ self.pydantic_error = pydantic_error
25
32
  self.data = data
26
33
  self.model = model
27
34
  super().__init__(self.message)
28
35
 
29
36
  def __str__(self):
30
- return f"""{repr(self)}
31
- Data being validated: {self.data}
32
- Pydnantic Model: {self.model}.
33
- Reported error: {self.message}."""
37
+ if isinstance(self.message, ValidationError):
38
+ # If it's a ValidationError, just return the core error message
39
+ return str(self.message)
40
+ elif hasattr(self.message, "errors"):
41
+ # Handle the case where it's already been converted to a string but has errors
42
+ error_list = self.message.errors()
43
+ if error_list:
44
+ return str(error_list[0].get("msg", "Unknown error"))
45
+ return str(self.message)
46
+
47
+ # def __str__(self):
48
+ # return f"""{repr(self)}
49
+ # Data being validated: {self.data}
50
+ # Pydnantic Model: {self.model}.
51
+ # Reported error: {self.message}."""
34
52
 
35
53
  def to_html_dict(self):
36
54
  return {
@@ -1,6 +1,13 @@
1
1
  import re
2
2
  import textwrap
3
3
 
4
+ # from IPython.core.error import UsageError
5
+
6
+
7
class AgentListError(Exception):
    """Exception that carries exactly one message argument."""

    def __init__(self, message):
        # Single required argument: constructing without a message is a TypeError.
        Exception.__init__(self, message)
10
+
4
11
 
5
12
  class ScenarioError(Exception):
6
13
  documentation = "https://docs.expectedparrot.com/en/latest/scenarios.html#module-edsl.scenarios.Scenario"
@@ -56,9 +56,6 @@ class AnthropicService(InferenceServiceABC):
56
56
  "top_logprobs": 3,
57
57
  }
58
58
 
59
- _tpm = cls.get_tpm(cls)
60
- _rpm = cls.get_rpm(cls)
61
-
62
59
  async def async_execute_model_call(
63
60
  self,
64
61
  user_prompt: str,
@@ -0,0 +1,184 @@
1
+ from typing import List, Optional, get_args, Union
2
+ from pathlib import Path
3
+ import sqlite3
4
+ from datetime import datetime
5
+ import tempfile
6
+ from platformdirs import user_cache_dir
7
+ from dataclasses import dataclass
8
+ import os
9
+
10
+ from edsl.inference_services.data_structures import LanguageModelInfo, AvailableModels
11
+ from edsl.enums import InferenceServiceLiteral
12
+
13
+
14
class AvailableModelCacheHandler:
    """SQLite cache of which models each inference service offers.

    Rows are ``(timestamp, model_name, service_name)`` with ``(model_name,
    service_name)`` unique. :meth:`models` only returns rows written within the
    last ``cache_validity_hours`` hours.
    """

    MAX_ROWS = 1000
    CACHE_VALIDITY_HOURS = 48

    def __init__(
        self,
        cache_validity_hours: int = 48,
        verbose: bool = False,
        testing_db_name: str = None,
    ):
        """Open (creating if needed) the cache database.

        :param cache_validity_hours: Rows older than this many hours are ignored.
        :param verbose: Print diagnostic messages.
        :param testing_db_name: If given, use a DB with this name inside a fresh
            temporary directory instead of the per-user cache directory.
        """
        self.cache_validity_hours = cache_validity_hours
        self.verbose = verbose

        if testing_db_name:
            self.cache_dir = Path(tempfile.mkdtemp())
            self.db_path = self.cache_dir / testing_db_name
        else:
            self.cache_dir = Path(user_cache_dir("edsl", "model_availability"))
            self.db_path = self.cache_dir / "available_models.db"
        # Harmless for the mkdtemp() directory, required for the user cache dir.
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if os.path.exists(self.db_path):
            if self.verbose:
                print(f"Using existing cache DB: {self.db_path}")
        else:
            self._initialize_db()

    @property
    def path_to_db(self):
        """Filesystem path of the SQLite cache file."""
        return self.db_path

    @staticmethod
    def _format_timestamp(moment: datetime) -> str:
        """Serialize *moment* exactly as sqlite3's legacy datetime adapter did.

        Writing and comparing explicit ISO strings (``YYYY-MM-DD HH:MM:SS[.ffffff]``)
        keeps rows written by older versions comparable and lets SQLite order
        timestamps lexicographically. Comparing these TEXT values against a
        float would always succeed, because SQLite sorts every REAL before
        every TEXT value.
        """
        return moment.isoformat(" ")

    def _connect(self):
        """Return a context manager yielding a connection that is closed on exit.

        ``sqlite3.Connection`` used directly in ``with`` only commits/rolls back;
        it does not close the connection.
        """
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def _initialize_db(self):
        """Initialize the SQLite database with the required schema."""
        with self._connect() as conn:
            cursor = conn.cursor()
            # Drop the old table if it exists (for migration)
            cursor.execute("DROP TABLE IF EXISTS model_cache")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS model_cache (
                    timestamp DATETIME NOT NULL,
                    model_name TEXT NOT NULL,
                    service_name TEXT NOT NULL,
                    UNIQUE(model_name, service_name)
                )
                """
            )
            conn.commit()

    def _prune_old_entries(self, conn: sqlite3.Connection):
        """Delete the oldest rows on *conn* when more than MAX_ROWS are stored."""
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM model_cache")
        count = cursor.fetchone()[0]

        if count > self.MAX_ROWS:
            cursor.execute(
                """
                DELETE FROM model_cache
                WHERE rowid IN (
                    SELECT rowid
                    FROM model_cache
                    ORDER BY timestamp ASC
                    LIMIT ?
                )
                """,
                (count - self.MAX_ROWS,),
            )
            conn.commit()

    @classmethod
    def example_models(cls) -> List[LanguageModelInfo]:
        """Return a couple of sample entries for manual testing."""
        return [
            LanguageModelInfo(
                "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "deep_infra"
            ),
            LanguageModelInfo("openai/gpt-4", "openai"),
        ]

    def add_models_to_cache(self, models_data: List[LanguageModelInfo]):
        """Add new models to the cache, updating timestamps for existing entries."""
        current_time = self._format_timestamp(datetime.now())

        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.executemany(
                """
                INSERT INTO model_cache (timestamp, model_name, service_name)
                VALUES (?, ?, ?)
                ON CONFLICT(model_name, service_name)
                DO UPDATE SET timestamp = excluded.timestamp
                """,
                [
                    (current_time, model.model_name, model.service_name)
                    for model in models_data
                ],
            )
            # NOTE: row pruning (self._prune_old_entries(conn)) is currently disabled.
            conn.commit()

    def reset_cache(self):
        """Clear all entries from the cache."""
        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM model_cache")
            conn.commit()

    @property
    def num_cache_entries(self):
        """Return the number of entries in the cache."""
        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM model_cache")
            return cursor.fetchone()[0]

    def models(
        self,
        service: Optional[InferenceServiceLiteral],
    ) -> Union[None, AvailableModels]:
        """Return the cached models written within the cache validity period.

        :param service: If truthy, keep only models whose ``service_name`` matches.
        :return: ``None`` if no fresh rows exist at all; otherwise an
            ``AvailableModels`` (possibly empty after the service filter).
        """
        cutoff = self._format_timestamp(
            datetime.fromtimestamp(
                datetime.now().timestamp() - self.cache_validity_hours * 3600
            )
        )

        if self.verbose:
            print(f"Fetching all with timestamp greater than {cutoff}")

        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT DISTINCT model_name, service_name
                FROM model_cache
                WHERE timestamp > ?
                ORDER BY timestamp DESC
                """,
                (cutoff,),
            )
            results = cursor.fetchall()

        if not results:
            if self.verbose:
                print("No results found in cache DB.")
            return None

        matching_models = [
            LanguageModelInfo(model_name=row[0], service_name=row[1])
            for row in results
        ]

        if self.verbose:
            print(f"Found {len(matching_models)} models in cache DB.")
        if service:
            matching_models = [
                model for model in matching_models if model.service_name == service
            ]

        return AvailableModels(matching_models)
+
174
+
175
if __name__ == "__main__":
    import doctest

    doctest.testmod()
    # Manual smoke test (uses the real per-user cache DB):
    #   handler = AvailableModelCacheHandler(verbose=True)
    #   handler.add_models_to_cache(handler.example_models())
    #   print(handler.models(service=None))
    #   handler.reset_cache()
    #   print(handler.models(service=None))
@@ -0,0 +1,209 @@
1
+ from typing import Any, List, Tuple, Optional, Dict, TYPE_CHECKING, Union, Generator
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ from collections import UserList
4
+
5
+ from edsl.inference_services.ServiceAvailability import ServiceAvailability
6
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
7
+ from edsl.inference_services.data_structures import ModelNamesList
8
+ from edsl.enums import InferenceServiceLiteral
9
+
10
+ from edsl.inference_services.data_structures import LanguageModelInfo
11
+ from edsl.inference_services.AvailableModelCacheHandler import (
12
+ AvailableModelCacheHandler,
13
+ )
14
+
15
+
16
+ from edsl.inference_services.data_structures import AvailableModels
17
+
18
+
19
class AvailableModelFetcher:
    """Fetch available models from the various services, with optional SQLite caching."""

    service_availability = ServiceAvailability()
    CACHE_VALIDITY_HOURS = 48  # Cache validity period in hours

    def __init__(
        self,
        services: List["InferenceServiceABC"],
        added_models: Dict[str, List[str]],
        verbose: bool = False,
        use_cache: bool = True,
    ):
        """
        :param services: Inference service objects to query.
        :param added_models: Extra model names keyed by service name; appended
            when listing models for all services.
        :param verbose: Print diagnostic messages.
        :param use_cache: Read/write the on-disk model cache. When False every
            lookup fetches fresh data and nothing is written to disk.
        """
        self.services = services
        self.added_models = added_models
        self._service_map = {
            service._inference_service_: service for service in services
        }
        self.verbose = verbose
        self.cache_handler = AvailableModelCacheHandler() if use_cache else None

    @property
    def num_cache_entries(self):
        """Number of rows in the cache DB, or 0 when caching is disabled."""
        if self.cache_handler is None:
            return 0
        return self.cache_handler.num_cache_entries

    @property
    def path_to_db(self):
        """Path of the cache DB, or None when caching is disabled."""
        if self.cache_handler is None:
            return None
        return self.cache_handler.path_to_db

    def reset_cache(self):
        """Empty the cache DB (no-op when caching is disabled)."""
        if self.cache_handler is not None:
            self.cache_handler.reset_cache()

    def available(
        self,
        service: Optional[InferenceServiceABC] = None,
        force_refresh: bool = False,
    ) -> List[LanguageModelInfo]:
        """
        Get available models from all services, using cached data when available.

        :param service: Optional service object or service name (e.g. "openai");
            if given, only that service's models are returned.
        :param force_refresh: Bypass the cache and query the service(s) directly.

        >>> from edsl.inference_services.OpenAIService import OpenAIService
        >>> af = AvailableModelFetcher([OpenAIService()], {})
        >>> af.available(service="openai")
        [LanguageModelInfo(model_name='...', service_name='openai'), ...]

        Returns an AvailableModels collection of LanguageModelInfo entries.
        """
        if service:  # they passed a specific service
            matching_models, _ = self.get_available_models_by_service(
                service=service, force_refresh=force_refresh
            )
            return matching_models

        # Nope, we need to fetch them all (honouring force_refresh as well).
        return self._get_all_models(force_refresh=force_refresh)

    def get_available_models_by_service(
        self,
        service: Union["InferenceServiceABC", InferenceServiceLiteral],
        force_refresh: bool = False,
    ) -> Tuple[AvailableModels, InferenceServiceLiteral]:
        """Get models for a single service, preferring the cache.

        :param service: InferenceServiceABC - e.g., OpenAIService or "openai"
        :param force_refresh: Skip the cache and fetch fresh data.
        :return: Tuple[AvailableModels, InferenceServiceLiteral]
        """
        if isinstance(service, str):
            service = self._fetch_service_by_service_name(service)

        models_from_cache = None
        if not force_refresh and self.cache_handler is not None:
            models_from_cache = self.cache_handler.models(
                service=service._inference_service_
            )
            if self.verbose:
                print(
                    "Searching cache for models with service name:",
                    service._inference_service_,
                )
                print("Got models from cache:", models_from_cache)

        # None (no fresh rows) and an empty collection both count as a miss.
        if models_from_cache:
            return models_from_cache, service._inference_service_
        return self.get_available_models_by_service_fresh(service)

    def get_available_models_by_service_fresh(
        self, service: Union["InferenceServiceABC", InferenceServiceLiteral]
    ) -> Tuple[AvailableModels, InferenceServiceLiteral]:
        """Get models for a single service. This method always fetches fresh data.

        The result is written to the cache when caching is enabled.

        :param service: InferenceServiceABC - e.g., OpenAIService or "openai"
        :return: Tuple[AvailableModels, InferenceServiceLiteral]
        """
        if isinstance(service, str):
            service = self._fetch_service_by_service_name(service)

        service_models: ModelNamesList = (
            self.service_availability.get_service_available(service, warn=False)
        )
        service_name = service._inference_service_

        models_list = AvailableModels(
            [
                LanguageModelInfo(
                    model_name=model_name,
                    service_name=service_name,
                )
                for model_name in service_models
            ]
        )
        if self.cache_handler is not None:
            self.cache_handler.add_models_to_cache(models_list)  # update the cache
        return models_list, service_name

    def _fetch_service_by_service_name(
        self, service_name: InferenceServiceLiteral
    ) -> "InferenceServiceABC":
        """The service name is the _inference_service_ attribute of the service."""
        if service_name in self._service_map:
            return self._service_map[service_name]
        raise ValueError(f"Service {service_name} not found")

    def _get_all_models(self, force_refresh=False) -> List[LanguageModelInfo]:
        """Query every service concurrently and merge the results.

        Services whose query raises are reported and skipped (best effort);
        ``added_models`` for each successful service are appended.
        """
        all_models = []
        if not self.services:
            # ThreadPoolExecutor rejects max_workers=0.
            return AvailableModels(all_models)

        with ThreadPoolExecutor(max_workers=min(len(self.services), 10)) as executor:
            future_to_service = {
                executor.submit(
                    self.get_available_models_by_service, service, force_refresh
                ): service
                for service in self.services
            }

            for future in as_completed(future_to_service):
                try:
                    models, service_name = future.result()
                except Exception as exc:
                    print(f"Service query failed: {exc}")
                    continue

                all_models.extend(models)
                # Add any additional models for this service
                all_models.extend(
                    LanguageModelInfo(model_name=model, service_name=service_name)
                    for model in self.added_models.get(service_name, [])
                )

        return AvailableModels(all_models)
184
+
185
+
186
def main():
    """Ad-hoc smoke check: list OpenAI models, bypassing the cache, and print them."""
    from edsl.inference_services.OpenAIService import OpenAIService

    # Constructed for its side effects only (opens/creates the cache DB); unused.
    verbose_fetcher = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
    fetcher = AvailableModelFetcher([OpenAIService()], {})
    fresh_models = fetcher._get_all_models(force_refresh=True)
    print(fresh_models)
195
+
196
+
197
if __name__ == "__main__":
    from doctest import ELLIPSIS, testmod

    # Run this module's doctests; call main() instead for a live smoke test.
    testmod(optionflags=ELLIPSIS)
@@ -69,8 +69,6 @@ class AwsBedrockService(InferenceServiceABC):
69
69
  }
70
70
  input_token_name = cls.input_token_name
71
71
  output_token_name = cls.output_token_name
72
- _rpm = cls.get_rpm(cls)
73
- _tpm = cls.get_tpm(cls)
74
72
 
75
73
  async def async_execute_model_call(
76
74
  self,
@@ -118,8 +118,6 @@ class AzureAIService(InferenceServiceABC):
118
118
  "max_tokens": 512,
119
119
  "top_p": 0.9,
120
120
  }
121
- _rpm = cls.get_rpm(cls)
122
- _tpm = cls.get_tpm(cls)
123
121
 
124
122
  async def async_execute_model_call(
125
123
  self,
@@ -1,11 +1,11 @@
1
- import os
1
+ # import os
2
2
  from typing import Any, Dict, List, Optional
3
3
  import google
4
4
  import google.generativeai as genai
5
5
  from google.generativeai.types import GenerationConfig
6
6
  from google.api_core.exceptions import InvalidArgument
7
7
 
8
- from edsl.exceptions import MissingAPIKeyError
8
+ # from edsl.exceptions.general import MissingAPIKeyError
9
9
  from edsl.language_models.LanguageModel import LanguageModel
10
10
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
11
11
  from edsl.coop import Coop
@@ -39,10 +39,6 @@ class GoogleService(InferenceServiceABC):
39
39
 
40
40
  model_exclude_list = []
41
41
 
42
- # @classmethod
43
- # def available(cls) -> List[str]:
44
- # return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
45
-
46
42
  @classmethod
47
43
  def available(cls) -> List[str]:
48
44
  model_list = []
@@ -66,9 +62,6 @@ class GoogleService(InferenceServiceABC):
66
62
  output_token_name = cls.output_token_name
67
63
  _inference_service_ = cls._inference_service_
68
64
 
69
- _tpm = cls.get_tpm(cls)
70
- _rpm = cls.get_rpm(cls)
71
-
72
65
  _parameters_ = {
73
66
  "temperature": 0.5,
74
67
  "topP": 1,
@@ -77,7 +70,6 @@ class GoogleService(InferenceServiceABC):
77
70
  "stopSequences": [],
78
71
  }
79
72
 
80
- api_token = None
81
73
  model = None
82
74
 
83
75
  def __init__(self, *args, **kwargs):
@@ -102,7 +94,6 @@ class GoogleService(InferenceServiceABC):
102
94
 
103
95
  if files_list is None:
104
96
  files_list = []
105
-
106
97
  genai.configure(api_key=self.api_token)
107
98
  if (
108
99
  system_prompt is not None