edsl 0.1.31.dev4__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. edsl/Base.py +9 -3
  2. edsl/TemplateLoader.py +24 -0
  3. edsl/__init__.py +8 -3
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +40 -8
  6. edsl/agents/AgentList.py +43 -0
  7. edsl/agents/Invigilator.py +136 -221
  8. edsl/agents/InvigilatorBase.py +148 -59
  9. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +154 -85
  10. edsl/agents/__init__.py +1 -0
  11. edsl/auto/AutoStudy.py +117 -0
  12. edsl/auto/StageBase.py +230 -0
  13. edsl/auto/StageGenerateSurvey.py +178 -0
  14. edsl/auto/StageLabelQuestions.py +125 -0
  15. edsl/auto/StagePersona.py +61 -0
  16. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  17. edsl/auto/StagePersonaDimensionValues.py +74 -0
  18. edsl/auto/StagePersonaDimensions.py +69 -0
  19. edsl/auto/StageQuestions.py +73 -0
  20. edsl/auto/SurveyCreatorPipeline.py +21 -0
  21. edsl/auto/utilities.py +224 -0
  22. edsl/config.py +48 -47
  23. edsl/conjure/Conjure.py +6 -0
  24. edsl/coop/PriceFetcher.py +58 -0
  25. edsl/coop/coop.py +50 -7
  26. edsl/data/Cache.py +35 -1
  27. edsl/data/CacheHandler.py +3 -4
  28. edsl/data_transfer_models.py +73 -38
  29. edsl/enums.py +8 -0
  30. edsl/exceptions/general.py +10 -8
  31. edsl/exceptions/language_models.py +25 -1
  32. edsl/exceptions/questions.py +62 -5
  33. edsl/exceptions/results.py +4 -0
  34. edsl/inference_services/AnthropicService.py +13 -11
  35. edsl/inference_services/AwsBedrock.py +112 -0
  36. edsl/inference_services/AzureAI.py +214 -0
  37. edsl/inference_services/DeepInfraService.py +4 -3
  38. edsl/inference_services/GoogleService.py +16 -12
  39. edsl/inference_services/GroqService.py +5 -4
  40. edsl/inference_services/InferenceServiceABC.py +58 -3
  41. edsl/inference_services/InferenceServicesCollection.py +13 -8
  42. edsl/inference_services/MistralAIService.py +120 -0
  43. edsl/inference_services/OllamaService.py +18 -0
  44. edsl/inference_services/OpenAIService.py +55 -56
  45. edsl/inference_services/TestService.py +80 -0
  46. edsl/inference_services/TogetherAIService.py +170 -0
  47. edsl/inference_services/models_available_cache.py +25 -0
  48. edsl/inference_services/registry.py +19 -1
  49. edsl/jobs/Answers.py +10 -12
  50. edsl/jobs/FailedQuestion.py +78 -0
  51. edsl/jobs/Jobs.py +137 -41
  52. edsl/jobs/buckets/BucketCollection.py +24 -15
  53. edsl/jobs/buckets/TokenBucket.py +105 -18
  54. edsl/jobs/interviews/Interview.py +393 -83
  55. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +22 -18
  56. edsl/jobs/interviews/InterviewExceptionEntry.py +167 -0
  57. edsl/jobs/runners/JobsRunnerAsyncio.py +152 -160
  58. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  59. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  60. edsl/jobs/tasks/TaskCreators.py +1 -1
  61. edsl/jobs/tasks/TaskHistory.py +205 -126
  62. edsl/language_models/LanguageModel.py +297 -177
  63. edsl/language_models/ModelList.py +2 -2
  64. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  65. edsl/language_models/fake_openai_call.py +15 -0
  66. edsl/language_models/fake_openai_service.py +61 -0
  67. edsl/language_models/registry.py +25 -8
  68. edsl/language_models/repair.py +0 -19
  69. edsl/language_models/utilities.py +61 -0
  70. edsl/notebooks/Notebook.py +20 -2
  71. edsl/prompts/Prompt.py +52 -2
  72. edsl/questions/AnswerValidatorMixin.py +23 -26
  73. edsl/questions/QuestionBase.py +330 -249
  74. edsl/questions/QuestionBaseGenMixin.py +133 -0
  75. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  76. edsl/questions/QuestionBudget.py +99 -42
  77. edsl/questions/QuestionCheckBox.py +227 -36
  78. edsl/questions/QuestionExtract.py +98 -28
  79. edsl/questions/QuestionFreeText.py +47 -31
  80. edsl/questions/QuestionFunctional.py +7 -0
  81. edsl/questions/QuestionList.py +141 -23
  82. edsl/questions/QuestionMultipleChoice.py +159 -66
  83. edsl/questions/QuestionNumerical.py +88 -47
  84. edsl/questions/QuestionRank.py +182 -25
  85. edsl/questions/Quick.py +41 -0
  86. edsl/questions/RegisterQuestionsMeta.py +31 -12
  87. edsl/questions/ResponseValidatorABC.py +170 -0
  88. edsl/questions/__init__.py +3 -4
  89. edsl/questions/decorators.py +21 -0
  90. edsl/questions/derived/QuestionLikertFive.py +10 -5
  91. edsl/questions/derived/QuestionLinearScale.py +15 -2
  92. edsl/questions/derived/QuestionTopK.py +10 -1
  93. edsl/questions/derived/QuestionYesNo.py +24 -3
  94. edsl/questions/descriptors.py +43 -7
  95. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  96. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  97. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  98. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  99. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  100. edsl/questions/prompt_templates/question_list.jinja +17 -0
  101. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  102. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  103. edsl/questions/question_registry.py +6 -2
  104. edsl/questions/templates/__init__.py +0 -0
  105. edsl/questions/templates/budget/__init__.py +0 -0
  106. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  107. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  108. edsl/questions/templates/checkbox/__init__.py +0 -0
  109. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  110. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  111. edsl/questions/templates/extract/__init__.py +0 -0
  112. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  113. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  114. edsl/questions/templates/free_text/__init__.py +0 -0
  115. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  116. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  117. edsl/questions/templates/likert_five/__init__.py +0 -0
  118. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  119. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  120. edsl/questions/templates/linear_scale/__init__.py +0 -0
  121. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  122. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  123. edsl/questions/templates/list/__init__.py +0 -0
  124. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  125. edsl/questions/templates/list/question_presentation.jinja +5 -0
  126. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  127. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  128. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  129. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  130. edsl/questions/templates/numerical/__init__.py +0 -0
  131. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  132. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  133. edsl/questions/templates/rank/__init__.py +0 -0
  134. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  135. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  136. edsl/questions/templates/top_k/__init__.py +0 -0
  137. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  138. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  139. edsl/questions/templates/yes_no/__init__.py +0 -0
  140. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  141. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  142. edsl/results/Dataset.py +20 -0
  143. edsl/results/DatasetExportMixin.py +58 -30
  144. edsl/results/DatasetTree.py +145 -0
  145. edsl/results/Result.py +32 -5
  146. edsl/results/Results.py +135 -46
  147. edsl/results/ResultsDBMixin.py +3 -3
  148. edsl/results/Selector.py +118 -0
  149. edsl/results/tree_explore.py +115 -0
  150. edsl/scenarios/FileStore.py +71 -10
  151. edsl/scenarios/Scenario.py +109 -24
  152. edsl/scenarios/ScenarioImageMixin.py +2 -2
  153. edsl/scenarios/ScenarioList.py +546 -21
  154. edsl/scenarios/ScenarioListExportMixin.py +24 -4
  155. edsl/scenarios/ScenarioListPdfMixin.py +153 -4
  156. edsl/study/SnapShot.py +8 -1
  157. edsl/study/Study.py +32 -0
  158. edsl/surveys/Rule.py +15 -3
  159. edsl/surveys/RuleCollection.py +21 -5
  160. edsl/surveys/Survey.py +707 -298
  161. edsl/surveys/SurveyExportMixin.py +71 -9
  162. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  163. edsl/surveys/SurveyQualtricsImport.py +284 -0
  164. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  165. edsl/surveys/instructions/Instruction.py +34 -0
  166. edsl/surveys/instructions/InstructionCollection.py +77 -0
  167. edsl/surveys/instructions/__init__.py +0 -0
  168. edsl/templates/error_reporting/base.html +24 -0
  169. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  170. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  171. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  172. edsl/templates/error_reporting/interview_details.html +116 -0
  173. edsl/templates/error_reporting/interviews.html +10 -0
  174. edsl/templates/error_reporting/overview.html +5 -0
  175. edsl/templates/error_reporting/performance_plot.html +2 -0
  176. edsl/templates/error_reporting/report.css +74 -0
  177. edsl/templates/error_reporting/report.html +118 -0
  178. edsl/templates/error_reporting/report.js +25 -0
  179. edsl/utilities/utilities.py +40 -1
  180. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/METADATA +8 -2
  181. edsl-0.1.33.dist-info/RECORD +295 -0
  182. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -271
  183. edsl/jobs/interviews/retry_management.py +0 -37
  184. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -303
  185. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  186. edsl-0.1.31.dev4.dist-info/RECORD +0 -204
  187. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
  188. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
edsl/config.py CHANGED
@@ -1,74 +1,66 @@
1
1
  """This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
2
2
 
3
3
  import os
4
+ from dotenv import load_dotenv, find_dotenv
4
5
  from edsl.exceptions import (
5
6
  InvalidEnvironmentVariableError,
6
7
  MissingEnvironmentVariableError,
7
8
  )
8
- from dotenv import load_dotenv, find_dotenv
9
9
 
10
10
  # valid values for EDSL_RUN_MODE
11
- EDSL_RUN_MODES = ["development", "development-testrun", "production"]
11
+ EDSL_RUN_MODES = [
12
+ "development",
13
+ "development-testrun",
14
+ "production",
15
+ ]
12
16
 
13
17
  # `default` is used to impute values only in "production" mode
14
18
  # `info` gives a brief description of the env var
15
19
  CONFIG_MAP = {
16
20
  "EDSL_RUN_MODE": {
17
21
  "default": "production",
18
- "info": "This env var determines the run mode of the application.",
19
- },
20
- "EDSL_DATABASE_PATH": {
21
- "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
22
- "info": "This env var determines the path to the cache file.",
23
- },
24
- "EDSL_LOGGING_PATH": {
25
- "default": f"{os.path.join(os.getcwd(), 'interview.log')}",
26
- "info": "This env var determines the path to the log file.",
22
+ "info": "This config var determines the run mode of the application.",
27
23
  },
28
24
  "EDSL_API_TIMEOUT": {
29
25
  "default": "60",
30
- "info": "This env var determines the maximum number of seconds to wait for an API call to return.",
26
+ "info": "This config var determines the maximum number of seconds to wait for an API call to return.",
31
27
  },
32
28
  "EDSL_BACKOFF_START_SEC": {
33
29
  "default": "1",
34
- "info": "This env var determines the number of seconds to wait before retrying a failed API call.",
30
+ "info": "This config var determines the number of seconds to wait before retrying a failed API call.",
35
31
  },
36
- "EDSL_MAX_BACKOFF_SEC": {
32
+ "EDSL_BACKOFF_MAX_SEC": {
37
33
  "default": "60",
38
- "info": "This env var determines the maximum number of seconds to wait before retrying a failed API call.",
34
+ "info": "This config var determines the maximum number of seconds to wait before retrying a failed API call.",
35
+ },
36
+ "EDSL_DATABASE_PATH": {
37
+ "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
38
+ "info": "This config var determines the path to the cache file.",
39
+ },
40
+ "EDSL_DEFAULT_MODEL": {
41
+ "default": "gpt-4o",
42
+ "info": "This config var holds the default model that will be used if a model is not explicitly passed.",
43
+ },
44
+ "EDSL_FETCH_TOKEN_PRICES": {
45
+ "default": "True",
46
+ "info": "This config var determines whether to fetch prices for tokens used in remote inference",
39
47
  },
40
48
  "EDSL_MAX_ATTEMPTS": {
41
49
  "default": "5",
42
- "info": "This env var determines the maximum number of times to retry a failed API call.",
50
+ "info": "This config var determines the maximum number of times to retry a failed API call.",
51
+ },
52
+ "EDSL_SERVICE_RPM_BASELINE": {
53
+ "default": "100",
54
+ "info": "This config var holds the maximum number of requests per minute. Model-specific values provided in env vars such as EDSL_SERVICE_RPM_OPENAI will override this. value for the corresponding model",
55
+ },
56
+ "EDSL_SERVICE_TPM_BASELINE": {
57
+ "default": "2000000",
58
+ "info": "This config var holds the maximum number of tokens per minute for all models. Model-specific values provided in env vars such as EDSL_SERVICE_TPM_OPENAI will override this value for the corresponding model.",
43
59
  },
44
60
  "EXPECTED_PARROT_URL": {
45
61
  "default": "https://www.expectedparrot.com",
46
- "info": "This env var holds the URL of the Expected Parrot API.",
62
+ "info": "This config var holds the URL of the Expected Parrot API.",
47
63
  },
48
- # "EXPECTED_PARROT_API_KEY": {
49
- # "default": None,
50
- # "info": "This env var holds your Expected Parrot API key (https://www.expectedparrot.com/).",
51
- # },
52
- # "OPENAI_API_KEY": {
53
- # "default": None,
54
- # "info": "This env var holds your OpenAI API key (https://platform.openai.com/api-keys).",
55
- # },
56
- # "DEEP_INFRA_API_KEY": {
57
- # "default": None,
58
- # "info": "This env var holds your DeepInfra API key (https://deepinfra.com/).",
59
- # },
60
- # "GOOGLE_API_KEY": {
61
- # "default": None,
62
- # "info": "This env var holds your Google API key (https://console.cloud.google.com/apis/credentials).",
63
- # },
64
- # "ANTHROPIC_API_KEY": {
65
- # "default": None,
66
- # "info": "This env var holds your Anthropic API key (https://www.anthropic.com/).",
67
- # },
68
- # "GROQ_API_KEY": {
69
- # "default": None,
70
- # "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
71
- # },
72
64
  }
73
65
 
74
66
 
@@ -83,7 +75,7 @@ class Config:
83
75
 
84
76
  def _set_run_mode(self) -> None:
85
77
  """
86
- Checks the validity and sets EDSL_RUN_MODE.
78
+ Sets EDSL_RUN_MODE as a class attribute.
87
79
  """
88
80
  run_mode = os.getenv("EDSL_RUN_MODE")
89
81
  default = CONFIG_MAP.get("EDSL_RUN_MODE").get("default")
@@ -98,26 +90,35 @@ class Config:
98
90
  def _load_dotenv(self) -> None:
99
91
  """
100
92
  Loads the .env
101
- - Overrides existing env vars unless EDSL_RUN_MODE=="development-testrun"
93
+ - The .env will override existing env vars **unless** EDSL_RUN_MODE=="development-testrun"
102
94
  """
103
95
 
104
- override = True
105
96
  if self.EDSL_RUN_MODE == "development-testrun":
106
97
  override = False
98
+ else:
99
+ override = True
107
100
  _ = load_dotenv(dotenv_path=find_dotenv(usecwd=True), override=override)
108
101
 
102
def __contains__(self, env_var: str) -> bool:
    """Return True if ``env_var`` is stored as an attribute on this Config.

    Enables ``env_var in config`` checks. Only the instance's own attribute
    dict is consulted (that is where ``setattr(self, ...)`` stores values),
    not class-level attributes or methods.
    """
    instance_attrs = vars(self)
    return env_var in instance_attrs
107
+
109
108
  def _set_env_vars(self) -> None:
110
109
  """
111
- Sets env vars as Config class attributes.
110
+ Sets env vars as class attributes.
111
+ - EDSL_RUN_MODE is not set my this method, but by _set_run_mode
112
112
  - If an env var is not set and has a default value in the CONFIG_MAP, sets it to the default value.
113
113
  """
114
114
  # for each env var in the CONFIG_MAP
115
115
  for env_var, config in CONFIG_MAP.items():
116
+ # EDSL_RUN_MODE is already set by _set_run_mode
116
117
  if env_var == "EDSL_RUN_MODE":
117
- continue # we've set it already in _set_run_mode
118
+ continue
118
119
  value = os.getenv(env_var)
119
120
  default_value = config.get("default")
120
- # if the env var is set, set it as a CONFIG attribute
121
+ # if an env var exists, set it as a class attribute
121
122
  if value:
122
123
  setattr(self, env_var, value)
123
124
  # otherwise, if EDSL_RUN_MODE == "production" set it to its default value
edsl/conjure/Conjure.py CHANGED
@@ -35,6 +35,12 @@ class Conjure:
35
35
  # The __init__ method in Conjure won't be called because __new__ returns a different class instance.
36
36
  pass
37
37
 
38
@classmethod
def example(cls):
    """Return an example input-data object via ``InputDataABC.example()``.

    The import is deferred to call time (presumably to avoid an import
    cycle or import-time cost; unverified).
    """
    from edsl.conjure.InputData import InputDataABC as input_data_base

    example_instance = input_data_base.example()
    return example_instance
43
+
38
44
 
39
45
  if __name__ == "__main__":
40
46
  pass
@@ -0,0 +1,58 @@
1
import requests
import csv
from io import StringIO


class PriceFetcher:
    """Download and cache model token prices from a public Google Sheet.

    ``PriceFetcher()`` always returns the same instance (see ``__new__``), so
    a successful download is reused for the rest of the process. Failed or
    empty downloads are not cached, so a later call retries.
    """

    _instance = None

    # ID of the Google Sheet holding the price table; it is exported as CSV.
    SHEET_ID = "1SAO3Bhntefl0XQHJv27rMxpvu6uzKDWNXFHRa7jrUDs"

    # Seconds to wait for the sheet download. Without a timeout, requests
    # can block indefinitely on an unresponsive server.
    REQUEST_TIMEOUT = 20

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(PriceFetcher, cls).__new__(cls)
            cls._instance._cached_prices = None
        return cls._instance

    @staticmethod
    def _build_price_lookup(rows):
        """Group price rows by ``(service, model)`` and then by ``token_type``.

        :param rows: iterable of dicts mapping CSV column names to cell values.
        :returns: ``{(service, model): {token_type: row}}``. Rows lacking a
            ``service`` or ``model`` value are ignored; a later row with the
            same ``(service, model, token_type)`` replaces an earlier one.
        """
        price_lookup = {}
        for entry in rows:
            service = entry.get("service")
            model = entry.get("model")
            if service and model:
                token_type = entry.get("token_type")
                price_lookup.setdefault((service, model), {})[token_type] = entry
        return price_lookup

    def fetch_prices(self):
        """Return the price lookup, downloading it on first successful use.

        :returns: the dict built by ``_build_price_lookup``, or ``{}`` if the
            download fails or the exported sheet is empty.
        """
        if self._cached_prices is not None:
            return self._cached_prices

        url = f"https://docs.google.com/spreadsheets/d/{self.SHEET_ID}/export?format=csv"

        try:
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()  # turn HTTP error statuses into exceptions
        except requests.RequestException:
            return {}

        reader = csv.reader(StringIO(response.text))
        headers = next(reader, None)
        if headers is None:
            # An empty export has no header row; nothing to look up.
            return {}

        rows = (dict(zip(headers, row)) for row in reader)
        self._cached_prices = self._build_price_lookup(rows)
        return self._cached_prices
edsl/coop/coop.py CHANGED
@@ -53,25 +53,39 @@ class Coop:
53
53
  method: str,
54
54
  payload: Optional[dict[str, Any]] = None,
55
55
  params: Optional[dict[str, Any]] = None,
56
+ timeout: Optional[float] = 5,
56
57
  ) -> requests.Response:
57
58
  """
58
59
  Send a request to the server and return the response.
59
60
  """
60
61
  url = f"{self.url}/{uri}"
62
+ method = method.upper()
63
+ if payload is None:
64
+ timeout = 20
65
+ elif (
66
+ method.upper() == "POST"
67
+ and "json_string" in payload
68
+ and payload.get("json_string") is not None
69
+ ):
70
+ timeout = max(20, (len(payload.get("json_string", "")) // (1024 * 1024)))
61
71
  try:
62
- method = method.upper()
63
72
  if method in ["GET", "DELETE"]:
64
73
  response = requests.request(
65
- method, url, params=params, headers=self.headers
74
+ method, url, params=params, headers=self.headers, timeout=timeout
66
75
  )
67
76
  elif method in ["POST", "PATCH"]:
68
77
  response = requests.request(
69
- method, url, params=params, json=payload, headers=self.headers
78
+ method,
79
+ url,
80
+ params=params,
81
+ json=payload,
82
+ headers=self.headers,
83
+ timeout=timeout,
70
84
  )
71
85
  else:
72
86
  raise Exception(f"Invalid {method=}.")
73
87
  except requests.ConnectionError:
74
- raise requests.ConnectionError("Could not connect to the server.")
88
+ raise requests.ConnectionError(f"Could not connect to the server at {url}.")
75
89
 
76
90
  return response
77
91
 
@@ -81,6 +95,7 @@ class Coop:
81
95
  """
82
96
  if response.status_code >= 400:
83
97
  message = response.json().get("detail")
98
+ # print(response.text)
84
99
  if "Authorization" in message:
85
100
  print(message)
86
101
  message = "Please provide an Expected Parrot API key."
@@ -110,10 +125,18 @@ class Coop:
110
125
  def edsl_settings(self) -> dict:
111
126
  """
112
127
  Retrieve and return the EDSL settings stored on Coop.
128
+ If no response is received within 5 seconds, return an empty dict.
113
129
  """
114
- response = self._send_server_request(uri="api/v0/edsl-settings", method="GET")
115
- self._resolve_server_response(response)
116
- return response.json()
130
+ from requests.exceptions import Timeout
131
+
132
+ try:
133
+ response = self._send_server_request(
134
+ uri="api/v0/edsl-settings", method="GET", timeout=5
135
+ )
136
+ self._resolve_server_response(response)
137
+ return response.json()
138
+ except Timeout:
139
+ return {}
117
140
 
118
141
  ################
119
142
  # Objects
@@ -625,6 +648,26 @@ class Coop:
625
648
 
626
649
  return response_json
627
650
 
651
def fetch_prices(self) -> dict:
    """Return token prices from PriceFetcher, or ``{}`` when fetching is disabled.

    Fetching is governed by the EDSL_FETCH_TOKEN_PRICES config var, whose
    CONFIG_MAP default is the string "True". Config values come from env-var
    strings, and ``bool("False")`` is True, so the value is parsed explicitly:
    "false"/"0"/"no"/"off" (any case, surrounding whitespace ignored) disable
    fetching. Falsy values such as None or "" disable it as well.
    """
    from edsl.coop.PriceFetcher import PriceFetcher

    from edsl.config import CONFIG

    flag = CONFIG.get("EDSL_FETCH_TOKEN_PRICES")
    if not flag or str(flag).strip().lower() in ("false", "0", "no", "off"):
        return {}
    return PriceFetcher().fetch_prices()
661
+
662
+
663
if __name__ == "__main__":
    # Manual smoke test: fetch the token-price table through PriceFetcher and
    # report how many (service, model) entries came back.
    from edsl.coop.PriceFetcher import PriceFetcher

    prices = PriceFetcher().fetch_prices()
    if prices:
        print(f"Successfully fetched prices for {len(prices)} (service, model) pairs.")
        print("First entry:", next(iter(prices.items())))
    else:
        print("Failed to fetch price data.")
670
+
628
671
 
629
672
  def main():
630
673
  """
edsl/data/Cache.py CHANGED
@@ -6,6 +6,7 @@ from __future__ import annotations
6
6
  import json
7
7
  import os
8
8
  import warnings
9
+ import copy
9
10
  from typing import Optional, Union
10
11
  from edsl.Base import Base
11
12
  from edsl.data.CacheEntry import CacheEntry
@@ -88,11 +89,24 @@ class Cache(Base):
88
89
  # raise NotImplementedError("This method is not implemented yet.")
89
90
 
90
91
  def keys(self):
92
+ """
93
+ >>> from edsl import Cache
94
+ >>> Cache.example().keys()
95
+ ['5513286eb6967abc0511211f0402587d']
96
+ """
91
97
  return list(self.data.keys())
92
98
 
93
99
  def values(self):
100
+ """
101
+ >>> from edsl import Cache
102
+ >>> Cache.example().values()
103
+ [CacheEntry(...)]
104
+ """
94
105
  return list(self.data.values())
95
106
 
107
def items(self):
    """Return ``(key, entry)`` pairs as a one-shot iterator.

    The pairs are produced by zipping ``keys()`` with ``values()``, so the
    result is a ``zip`` iterator rather than a dict view.
    """
    cache_keys = self.keys()
    cache_entries = self.values()
    return zip(cache_keys, cache_entries)
109
+
96
110
  def new_entries_cache(self) -> Cache:
97
111
  """Return a new Cache object with the new entries."""
98
112
  return Cache(data={**self.new_entries, **self.fetched_data})
@@ -160,7 +174,7 @@ class Cache(Base):
160
174
  parameters: str,
161
175
  system_prompt: str,
162
176
  user_prompt: str,
163
- response: str,
177
+ response: dict,
164
178
  iteration: int,
165
179
  ) -> str:
166
180
  """
@@ -174,6 +188,15 @@ class Cache(Base):
174
188
  * The key-value pair is added to `self.new_entries`
175
189
  * If `immediate_write` is True , the key-value pair is added to `self.data`
176
190
  * If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
191
+
192
+ >>> from edsl import Cache, Model, Question
193
+ >>> m = Model("test")
194
+ >>> c = Cache()
195
+ >>> len(c)
196
+ 0
197
+ >>> results = Question.example("free_text").by(m).run(cache = c)
198
+ >>> len(c)
199
+ 1
177
200
  """
178
201
 
179
202
  entry = CacheEntry(
@@ -326,6 +349,17 @@ class Cache(Base):
326
349
  for key, value in self.data.items():
327
350
  f.write(json.dumps({key: value.to_dict()}) + "\n")
328
351
 
352
def to_scenario_list(self):
    """Return a ScenarioList with one Scenario per cache entry.

    Each Scenario holds the fields of the entry's ``to_dict()`` plus a
    ``"cache_key"`` field set to the entry's key in ``self.data``.
    """
    from edsl import ScenarioList, Scenario

    def _entry_to_scenario(cache_key, entry):
        # Attach the key to the entry's dict form, then wrap it.
        entry_fields = entry.to_dict()
        entry_fields["cache_key"] = cache_key
        return Scenario(entry_fields)

    return ScenarioList(
        [_entry_to_scenario(k, v) for k, v in self.data.items()]
    )
362
+
329
363
  ####################
330
364
  # REMOTE
331
365
  ####################
edsl/data/CacheHandler.py CHANGED
@@ -41,7 +41,7 @@ class CacheHandler:
41
41
  old_data = self.from_old_sqlite_cache()
42
42
  self.cache.add_from_dict(old_data)
43
43
 
44
- def create_cache_directory(self) -> None:
44
+ def create_cache_directory(self, notify=False) -> None:
45
45
  """
46
46
  Create the cache directory if one is required and it does not exist.
47
47
  """
@@ -49,9 +49,8 @@ class CacheHandler:
49
49
  dir_path = os.path.dirname(path)
50
50
  if dir_path and not os.path.exists(dir_path):
51
51
  os.makedirs(dir_path)
52
- import warnings
53
-
54
- warnings.warn(f"Created cache directory: {dir_path}")
52
+ if notify:
53
+ print(f"Created cache directory: {dir_path}")
55
54
 
56
55
  def gen_cache(self) -> Cache:
57
56
  """
@@ -1,38 +1,73 @@
1
- """This module contains the data transfer models for the application."""
2
-
3
- from collections import UserDict
4
-
5
-
6
- class AgentResponseDict(UserDict):
7
- """A dictionary to store the response of the agent to a question."""
8
-
9
- def __init__(
10
- self,
11
- *,
12
- question_name,
13
- answer,
14
- prompts,
15
- usage=None,
16
- comment=None,
17
- cached_response=None,
18
- raw_model_response=None,
19
- simple_model_raw_response=None,
20
- cache_used=None,
21
- cache_key=None,
22
- ):
23
- """Initialize the AgentResponseDict object."""
24
- usage = usage or {"prompt_tokens": 0, "completion_tokens": 0}
25
- super().__init__(
26
- {
27
- "answer": answer,
28
- "comment": comment,
29
- "question_name": question_name,
30
- "prompts": prompts,
31
- "usage": usage,
32
- "cached_response": cached_response,
33
- "raw_model_response": raw_model_response,
34
- "simple_model_raw_response": simple_model_raw_response,
35
- "cache_used": cache_used,
36
- "cache_key": cache_key,
37
- }
38
- )
1
from typing import NamedTuple, Dict, List, Optional, Any
from dataclasses import dataclass, fields
import reprlib


class ModelInputs(NamedTuple):
    """What was sent by the agent to the model."""

    user_prompt: str
    system_prompt: str
    encoded_image: Optional[str] = None  # omitted when no image is attached


class EDSLOutput(NamedTuple):
    """The EDSL answer dictionary derived from the model's output."""

    answer: Any
    generated_tokens: str
    comment: Optional[str] = None


class ModelResponse(NamedTuple):
    """Metadata returned by the model, including cache information."""

    response: dict
    cache_used: bool
    cache_key: str
    cached_response: Optional[Dict[str, Any]] = None
    cost: Optional[float] = None


class AgentResponseDict(NamedTuple):
    """Bundle of the EDSL output, the model inputs, and the model response."""

    edsl_dict: EDSLOutput
    model_inputs: ModelInputs
    model_outputs: ModelResponse


class EDSLResultObjectInput(NamedTuple):
    """Flat record describing one answered question."""

    generated_tokens: str
    question_name: str
    prompts: dict
    cached_response: str
    raw_model_response: str
    cache_used: bool
    cache_key: str
    answer: Any
    comment: str
    validated: bool = False
    # None when no exception occurred, hence Optional.
    exception_occurred: Optional[Exception] = None
    cost: Optional[float] = None


@dataclass
class ImageInfo:
    """File metadata plus the encoded image payload."""

    file_path: str
    file_name: str
    image_format: str
    file_size: int
    encoded_image: str

    def __repr__(self):
        """Repr that truncates ``encoded_image`` so it stays readable."""
        reprlib_instance = reprlib.Repr()
        reprlib_instance.maxstring = 30  # Limit the string length for the encoded image

        # Get all fields except encoded_image
        field_reprs = [
            f"{f.name}={getattr(self, f.name)!r}"
            for f in fields(self)
            if f.name != "encoded_image"
        ]

        # Add the reprlib-restricted encoded_image field
        field_reprs.append(f"encoded_image={reprlib_instance.repr(self.encoded_image)}")

        # Join everything to create the repr
        return f"{self.__class__.__name__}({', '.join(field_reprs)})"
edsl/enums.py CHANGED
@@ -60,6 +60,11 @@ class InferenceServiceType(EnumWithChecks):
60
60
  TEST = "test"
61
61
  ANTHROPIC = "anthropic"
62
62
  GROQ = "groq"
63
+ AZURE = "azure"
64
+ OLLAMA = "ollama"
65
+ MISTRAL = "mistral"
66
+ TOGETHER = "together"
67
+
63
68
 
64
69
  service_to_api_keyname = {
65
70
  InferenceServiceType.BEDROCK.value: "TBD",
@@ -70,6 +75,9 @@ service_to_api_keyname = {
70
75
  InferenceServiceType.TEST.value: "TBD",
71
76
  InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
72
77
  InferenceServiceType.GROQ.value: "GROQ_API_KEY",
78
+ InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
79
+ InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
80
+ InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
73
81
  }
74
82
 
75
83
 
@@ -21,12 +21,14 @@ class GeneralErrors(Exception):
21
21
 
22
22
 
23
23
  class MissingAPIKeyError(GeneralErrors):
24
- def __init__(self, model_name, inference_service):
25
- full_message = dedent(
26
- f"""
27
- An API Key for model `{model_name}` is missing from the .env file.
28
- This key is associated with the inference service `{inference_service}`.
29
- Please see https://docs.expectedparrot.com/en/latest/api_keys.html for more information.
30
- """
31
- )
24
+ def __init__(self, full_message=None, model_name=None, inference_service=None):
25
+ if model_name and inference_service:
26
+ full_message = dedent(
27
+ f"""
28
+ An API Key for model `{model_name}` is missing from the .env file.
29
+ This key is associated with the inference service `{inference_service}`.
30
+ Please see https://docs.expectedparrot.com/en/latest/api_keys.html for more information.
31
+ """
32
+ )
33
+
32
34
  super().__init__(full_message)
@@ -1,8 +1,32 @@
1
1
  from textwrap import dedent
2
+ from typing import Optional
2
3
 
3
4
 
4
5
  class LanguageModelExceptions(Exception):
5
- pass
6
+ explanation = (
7
+ "This is the base class for all exceptions in the LanguageModel class."
8
+ )
9
+
10
+ def __init__(self, message):
11
+ super().__init__(message)
12
+ self.message = message
13
+
14
+
15
class LanguageModelNoResponseError(LanguageModelExceptions):
    """Raised when the LLM API cannot be reached and/or does not respond.

    :param message: description of the failure, forwarded to the base class.
    """

    explanation = "This happens when the LLM API cannot be reached and/or does not respond."

    def __init__(self, message):
        super().__init__(message)
22
+
23
+
24
class LanguageModelBadResponseError(LanguageModelExceptions):
    """Raised when the LLM API responds but the response has no usable answer.

    :param message: description of the failure, forwarded to the base class.
    :param response_json: the raw response payload, if available; stored on
        ``self.response_json``.
    """

    explanation = "This happens when the LLM API can be reached and responds, but does not return a usable answer."

    def __init__(self, message, response_json: Optional[dict] = None):
        super().__init__(message)
        self.response_json = response_json
6
30
 
7
31
 
8
32
  class LanguageModelNotFound(LanguageModelExceptions):