edsl 0.1.32__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. edsl/Base.py +9 -3
  2. edsl/TemplateLoader.py +24 -0
  3. edsl/__init__.py +8 -3
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +40 -8
  6. edsl/agents/AgentList.py +43 -0
  7. edsl/agents/Invigilator.py +135 -219
  8. edsl/agents/InvigilatorBase.py +148 -59
  9. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +138 -89
  10. edsl/agents/__init__.py +1 -0
  11. edsl/auto/AutoStudy.py +117 -0
  12. edsl/auto/StageBase.py +230 -0
  13. edsl/auto/StageGenerateSurvey.py +178 -0
  14. edsl/auto/StageLabelQuestions.py +125 -0
  15. edsl/auto/StagePersona.py +61 -0
  16. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  17. edsl/auto/StagePersonaDimensionValues.py +74 -0
  18. edsl/auto/StagePersonaDimensions.py +69 -0
  19. edsl/auto/StageQuestions.py +73 -0
  20. edsl/auto/SurveyCreatorPipeline.py +21 -0
  21. edsl/auto/utilities.py +224 -0
  22. edsl/config.py +47 -56
  23. edsl/coop/PriceFetcher.py +58 -0
  24. edsl/coop/coop.py +50 -7
  25. edsl/data/Cache.py +35 -1
  26. edsl/data_transfer_models.py +73 -38
  27. edsl/enums.py +4 -0
  28. edsl/exceptions/language_models.py +25 -1
  29. edsl/exceptions/questions.py +62 -5
  30. edsl/exceptions/results.py +4 -0
  31. edsl/inference_services/AnthropicService.py +13 -11
  32. edsl/inference_services/AwsBedrock.py +19 -17
  33. edsl/inference_services/AzureAI.py +37 -20
  34. edsl/inference_services/GoogleService.py +16 -12
  35. edsl/inference_services/GroqService.py +2 -0
  36. edsl/inference_services/InferenceServiceABC.py +58 -3
  37. edsl/inference_services/MistralAIService.py +120 -0
  38. edsl/inference_services/OpenAIService.py +48 -54
  39. edsl/inference_services/TestService.py +80 -0
  40. edsl/inference_services/TogetherAIService.py +170 -0
  41. edsl/inference_services/models_available_cache.py +0 -6
  42. edsl/inference_services/registry.py +6 -0
  43. edsl/jobs/Answers.py +10 -12
  44. edsl/jobs/FailedQuestion.py +78 -0
  45. edsl/jobs/Jobs.py +37 -22
  46. edsl/jobs/buckets/BucketCollection.py +24 -15
  47. edsl/jobs/buckets/TokenBucket.py +93 -14
  48. edsl/jobs/interviews/Interview.py +366 -78
  49. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +14 -68
  50. edsl/jobs/interviews/InterviewExceptionEntry.py +85 -19
  51. edsl/jobs/runners/JobsRunnerAsyncio.py +146 -175
  52. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  53. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  54. edsl/jobs/tasks/TaskHistory.py +148 -213
  55. edsl/language_models/LanguageModel.py +261 -156
  56. edsl/language_models/ModelList.py +2 -2
  57. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  58. edsl/language_models/fake_openai_call.py +15 -0
  59. edsl/language_models/fake_openai_service.py +61 -0
  60. edsl/language_models/registry.py +23 -6
  61. edsl/language_models/repair.py +0 -19
  62. edsl/language_models/utilities.py +61 -0
  63. edsl/notebooks/Notebook.py +20 -2
  64. edsl/prompts/Prompt.py +52 -2
  65. edsl/questions/AnswerValidatorMixin.py +23 -26
  66. edsl/questions/QuestionBase.py +330 -249
  67. edsl/questions/QuestionBaseGenMixin.py +133 -0
  68. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  69. edsl/questions/QuestionBudget.py +99 -41
  70. edsl/questions/QuestionCheckBox.py +227 -35
  71. edsl/questions/QuestionExtract.py +98 -27
  72. edsl/questions/QuestionFreeText.py +52 -29
  73. edsl/questions/QuestionFunctional.py +7 -0
  74. edsl/questions/QuestionList.py +141 -22
  75. edsl/questions/QuestionMultipleChoice.py +159 -65
  76. edsl/questions/QuestionNumerical.py +88 -46
  77. edsl/questions/QuestionRank.py +182 -24
  78. edsl/questions/Quick.py +41 -0
  79. edsl/questions/RegisterQuestionsMeta.py +31 -12
  80. edsl/questions/ResponseValidatorABC.py +170 -0
  81. edsl/questions/__init__.py +3 -4
  82. edsl/questions/decorators.py +21 -0
  83. edsl/questions/derived/QuestionLikertFive.py +10 -5
  84. edsl/questions/derived/QuestionLinearScale.py +15 -2
  85. edsl/questions/derived/QuestionTopK.py +10 -1
  86. edsl/questions/derived/QuestionYesNo.py +24 -3
  87. edsl/questions/descriptors.py +43 -7
  88. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  89. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  90. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  91. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  92. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  93. edsl/questions/prompt_templates/question_list.jinja +17 -0
  94. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  95. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  96. edsl/questions/question_registry.py +6 -2
  97. edsl/questions/templates/__init__.py +0 -0
  98. edsl/questions/templates/budget/__init__.py +0 -0
  99. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  100. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  101. edsl/questions/templates/checkbox/__init__.py +0 -0
  102. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  103. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  104. edsl/questions/templates/extract/__init__.py +0 -0
  105. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  106. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  107. edsl/questions/templates/free_text/__init__.py +0 -0
  108. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  109. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  110. edsl/questions/templates/likert_five/__init__.py +0 -0
  111. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  112. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  113. edsl/questions/templates/linear_scale/__init__.py +0 -0
  114. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  115. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  116. edsl/questions/templates/list/__init__.py +0 -0
  117. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  118. edsl/questions/templates/list/question_presentation.jinja +5 -0
  119. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  120. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  121. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  122. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  123. edsl/questions/templates/numerical/__init__.py +0 -0
  124. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  125. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  126. edsl/questions/templates/rank/__init__.py +0 -0
  127. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  128. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  129. edsl/questions/templates/top_k/__init__.py +0 -0
  130. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  131. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  132. edsl/questions/templates/yes_no/__init__.py +0 -0
  133. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  134. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  135. edsl/results/Dataset.py +20 -0
  136. edsl/results/DatasetExportMixin.py +46 -48
  137. edsl/results/DatasetTree.py +145 -0
  138. edsl/results/Result.py +32 -5
  139. edsl/results/Results.py +135 -46
  140. edsl/results/ResultsDBMixin.py +3 -3
  141. edsl/results/Selector.py +118 -0
  142. edsl/results/tree_explore.py +115 -0
  143. edsl/scenarios/FileStore.py +71 -10
  144. edsl/scenarios/Scenario.py +96 -25
  145. edsl/scenarios/ScenarioImageMixin.py +2 -2
  146. edsl/scenarios/ScenarioList.py +361 -39
  147. edsl/scenarios/ScenarioListExportMixin.py +9 -0
  148. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  149. edsl/study/SnapShot.py +8 -1
  150. edsl/study/Study.py +32 -0
  151. edsl/surveys/Rule.py +10 -1
  152. edsl/surveys/RuleCollection.py +21 -5
  153. edsl/surveys/Survey.py +637 -311
  154. edsl/surveys/SurveyExportMixin.py +71 -9
  155. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  156. edsl/surveys/SurveyQualtricsImport.py +75 -4
  157. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  158. edsl/surveys/instructions/Instruction.py +34 -0
  159. edsl/surveys/instructions/InstructionCollection.py +77 -0
  160. edsl/surveys/instructions/__init__.py +0 -0
  161. edsl/templates/error_reporting/base.html +24 -0
  162. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  163. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  164. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  165. edsl/templates/error_reporting/interview_details.html +116 -0
  166. edsl/templates/error_reporting/interviews.html +10 -0
  167. edsl/templates/error_reporting/overview.html +5 -0
  168. edsl/templates/error_reporting/performance_plot.html +2 -0
  169. edsl/templates/error_reporting/report.css +74 -0
  170. edsl/templates/error_reporting/report.html +118 -0
  171. edsl/templates/error_reporting/report.js +25 -0
  172. edsl/utilities/utilities.py +9 -1
  173. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/METADATA +5 -2
  174. edsl-0.1.33.dist-info/RECORD +295 -0
  175. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
  176. edsl/jobs/interviews/retry_management.py +0 -37
  177. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  178. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  179. edsl-0.1.32.dist-info/RECORD +0 -209
  180. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
  181. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
edsl/config.py CHANGED
@@ -1,83 +1,66 @@
1
1
  """This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
2
2
 
3
3
  import os
4
+ from dotenv import load_dotenv, find_dotenv
4
5
  from edsl.exceptions import (
5
6
  InvalidEnvironmentVariableError,
6
7
  MissingEnvironmentVariableError,
7
8
  )
8
- from dotenv import load_dotenv, find_dotenv
9
9
 
# valid values for EDSL_RUN_MODE
EDSL_RUN_MODES = [
    "development",
    "development-testrun",
    "production",
]

# Each CONFIG_MAP entry describes one configuration variable:
# - `default` is used to impute values only in "production" mode
# - `info` gives a brief, user-facing description of the variable
# Every default is a string (the same form a value read from the environment
# has), so code consuming these values has to convert them itself, e.g.
# "60" -> int or "True" -> bool.
CONFIG_MAP = {
    "EDSL_RUN_MODE": {
        "default": "production",
        "info": "This config var determines the run mode of the application.",
    },
    "EDSL_API_TIMEOUT": {
        "default": "60",
        "info": "This config var determines the maximum number of seconds to wait for an API call to return.",
    },
    "EDSL_BACKOFF_START_SEC": {
        "default": "1",
        "info": "This config var determines the number of seconds to wait before retrying a failed API call.",
    },
    "EDSL_BACKOFF_MAX_SEC": {
        "default": "60",
        "info": "This config var determines the maximum number of seconds to wait before retrying a failed API call.",
    },
    "EDSL_DATABASE_PATH": {
        # Evaluated once, at import time, relative to the working directory then.
        "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
        "info": "This config var determines the path to the cache file.",
    },
    "EDSL_DEFAULT_MODEL": {
        "default": "gpt-4o",
        "info": "This config var holds the default model that will be used if a model is not explicitly passed.",
    },
    "EDSL_FETCH_TOKEN_PRICES": {
        # String, not bool: bool("False") is True, so readers must parse it.
        "default": "True",
        "info": "This config var determines whether to fetch prices for tokens used in remote inference.",
    },
    "EDSL_MAX_ATTEMPTS": {
        "default": "5",
        "info": "This config var determines the maximum number of times to retry a failed API call.",
    },
    "EDSL_SERVICE_RPM_BASELINE": {
        "default": "100",
        "info": "This config var holds the maximum number of requests per minute. Model-specific values provided in env vars such as EDSL_SERVICE_RPM_OPENAI will override this value for the corresponding model.",
    },
    "EDSL_SERVICE_TPM_BASELINE": {
        "default": "2000000",
        "info": "This config var holds the maximum number of tokens per minute for all models. Model-specific values provided in env vars such as EDSL_SERVICE_TPM_OPENAI will override this value for the corresponding model.",
    },
    "EXPECTED_PARROT_URL": {
        "default": "https://www.expectedparrot.com",
        "info": "This config var holds the URL of the Expected Parrot API.",
    },
}
82
65
 
83
66
 
@@ -92,7 +75,7 @@ class Config:
92
75
 
93
76
  def _set_run_mode(self) -> None:
94
77
  """
95
- Checks the validity and sets EDSL_RUN_MODE.
78
+ Sets EDSL_RUN_MODE as a class attribute.
96
79
  """
97
80
  run_mode = os.getenv("EDSL_RUN_MODE")
98
81
  default = CONFIG_MAP.get("EDSL_RUN_MODE").get("default")
@@ -107,27 +90,35 @@ class Config:
def _load_dotenv(self) -> None:
    """
    Load variables from the .env file located by ``find_dotenv(usecwd=True)``.

    - Values from the .env replace env vars that are already set, except when
      EDSL_RUN_MODE is "development-testrun"; then existing env vars are kept.
    - Expects ``self.EDSL_RUN_MODE`` to have been set already (by _set_run_mode).
    """
    keep_existing_env = self.EDSL_RUN_MODE == "development-testrun"
    dotenv_file = find_dotenv(usecwd=True)
    _ = load_dotenv(dotenv_path=dotenv_file, override=not keep_existing_env)
117
101
 
def __contains__(self, env_var: str) -> bool:
    """
    Report whether ``env_var`` has been stored on this Config instance.

    Only the instance's own ``__dict__`` is consulted (values assigned via
    ``setattr(self, ...)``), so names defined on the class itself do not count.
    """
    instance_attributes = self.__dict__
    return env_var in instance_attributes
107
+
118
108
  def _set_env_vars(self) -> None:
119
109
  """
120
- Sets env vars as Config class attributes.
110
+ Sets env vars as class attributes.
111
+ - EDSL_RUN_MODE is not set my this method, but by _set_run_mode
121
112
  - If an env var is not set and has a default value in the CONFIG_MAP, sets it to the default value.
122
113
  """
123
114
  # for each env var in the CONFIG_MAP
124
115
  for env_var, config in CONFIG_MAP.items():
125
- # we've set it already in _set_run_mode
116
+ # EDSL_RUN_MODE is already set by _set_run_mode
126
117
  if env_var == "EDSL_RUN_MODE":
127
118
  continue
128
119
  value = os.getenv(env_var)
129
120
  default_value = config.get("default")
130
- # if the env var is set, set it as a CONFIG attribute
121
+ # if an env var exists, set it as a class attribute
131
122
  if value:
132
123
  setattr(self, env_var, value)
133
124
  # otherwise, if EDSL_RUN_MODE == "production" set it to its default value
edsl/coop/PriceFetcher.py ADDED
@@ -0,0 +1,58 @@
1
+ import requests
2
+ import csv
3
+ from io import StringIO
4
+
5
+
class PriceFetcher:
    """Process-wide singleton that downloads the token price table and caches it.

    Prices come from a public Google Sheet exported as CSV. The parsed table is
    a dict keyed by ``(service, model)``; each value maps a ``token_type`` to
    the full CSV row (column name -> string) for that token type.
    """

    _instance = None

    # ID of the Google Sheet holding the price table.
    _SHEET_ID = "1SAO3Bhntefl0XQHJv27rMxpvu6uzKDWNXFHRa7jrUDs"

    def __new__(cls):
        # Every PriceFetcher() call returns the same object, so the cache is shared.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._cached_prices = None
        return cls._instance

    @staticmethod
    def _build_price_lookup(csv_text: str) -> dict:
        """Parse exported CSV text into ``{(service, model): {token_type: row}}``.

        Rows with an empty/missing ``service`` or ``model`` are skipped. Text with
        no header row (e.g. an empty body) yields an empty dict.
        """
        reader = csv.reader(StringIO(csv_text))
        # A bare next(reader) would raise StopIteration on an empty document.
        headers = next(reader, None)
        if headers is None:
            return {}

        price_lookup = {}
        for row in reader:
            entry = dict(zip(headers, row))
            service = entry.get("service")
            model = entry.get("model")
            if service and model:
                per_model = price_lookup.setdefault((service, model), {})
                per_model[entry.get("token_type")] = entry
        return price_lookup

    def fetch_prices(self, timeout: float = 20) -> dict:
        """Return the price lookup, downloading it on the first successful call.

        :param timeout: seconds passed to ``requests.get``; irrelevant once cached.
        :returns: the lookup built by ``_build_price_lookup``, or ``{}`` if the
            download fails. Failed downloads are not cached, so later calls retry.
        """
        if self._cached_prices is not None:
            return self._cached_prices

        url = f"https://docs.google.com/spreadsheets/d/{self._SHEET_ID}/export?format=csv"
        try:
            # Without a timeout an unresponsive host would block the caller forever.
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()  # HTTP error statuses count as failures
        except requests.RequestException:
            # Best effort: report "no prices" rather than raising.
            return {}

        self._cached_prices = self._build_price_lookup(response.text)
        return self._cached_prices
edsl/coop/coop.py CHANGED
@@ -53,25 +53,39 @@ class Coop:
53
53
  method: str,
54
54
  payload: Optional[dict[str, Any]] = None,
55
55
  params: Optional[dict[str, Any]] = None,
56
+ timeout: Optional[float] = 5,
56
57
  ) -> requests.Response:
57
58
  """
58
59
  Send a request to the server and return the response.
59
60
  """
60
61
  url = f"{self.url}/{uri}"
62
+ method = method.upper()
63
+ if payload is None:
64
+ timeout = 20
65
+ elif (
66
+ method.upper() == "POST"
67
+ and "json_string" in payload
68
+ and payload.get("json_string") is not None
69
+ ):
70
+ timeout = max(20, (len(payload.get("json_string", "")) // (1024 * 1024)))
61
71
  try:
62
- method = method.upper()
63
72
  if method in ["GET", "DELETE"]:
64
73
  response = requests.request(
65
- method, url, params=params, headers=self.headers
74
+ method, url, params=params, headers=self.headers, timeout=timeout
66
75
  )
67
76
  elif method in ["POST", "PATCH"]:
68
77
  response = requests.request(
69
- method, url, params=params, json=payload, headers=self.headers
78
+ method,
79
+ url,
80
+ params=params,
81
+ json=payload,
82
+ headers=self.headers,
83
+ timeout=timeout,
70
84
  )
71
85
  else:
72
86
  raise Exception(f"Invalid {method=}.")
73
87
  except requests.ConnectionError:
74
- raise requests.ConnectionError("Could not connect to the server.")
88
+ raise requests.ConnectionError(f"Could not connect to the server at {url}.")
75
89
 
76
90
  return response
77
91
 
@@ -81,6 +95,7 @@ class Coop:
81
95
  """
82
96
  if response.status_code >= 400:
83
97
  message = response.json().get("detail")
98
+ # print(response.text)
84
99
  if "Authorization" in message:
85
100
  print(message)
86
101
  message = "Please provide an Expected Parrot API key."
@@ -110,10 +125,18 @@ class Coop:
def edsl_settings(self) -> dict:
    """
    Fetch the EDSL settings stored on Coop and return them as a dict.

    The request is issued with ``timeout=5``; if it times out, an empty dict is
    returned instead of raising. Anything else (e.g. the ConnectionError that
    _send_server_request re-raises, or whatever _resolve_server_response raises)
    propagates to the caller.

    NOTE(review): _send_server_request replaces ``timeout`` with 20 whenever no
    payload is passed, which is the case for this GET, so the effective timeout
    is likely 20s rather than 5s — confirm and fix there if 5s is intended.
    """
    from requests.exceptions import Timeout

    request_args = dict(uri="api/v0/edsl-settings", method="GET", timeout=5)
    try:
        response = self._send_server_request(**request_args)
        self._resolve_server_response(response)
        settings = response.json()
    except Timeout:
        # Settings are optional; a slow server yields "no settings".
        return {}
    return settings
117
140
 
118
141
  ################
119
142
  # Objects
@@ -625,6 +648,26 @@ class Coop:
625
648
 
626
649
  return response_json
627
650
 
651
+ def fetch_prices(self) -> dict:
652
+ from edsl.coop.PriceFetcher import PriceFetcher
653
+
654
+ from edsl.config import CONFIG
655
+
656
+ if bool(CONFIG.get("EDSL_FETCH_TOKEN_PRICES")):
657
+ price_fetcher = PriceFetcher()
658
+ return price_fetcher.fetch_prices()
659
+ else:
660
+ return {}
661
+
662
+
if __name__ == "__main__":
    # Manual smoke test: download the token price table and summarize it.
    # PriceFetcher.fetch_prices() returns a dict keyed by (service, model)
    # tuples, so it is summarized by key rather than indexed like a list.
    from edsl.coop.PriceFetcher import PriceFetcher

    price_lookup = PriceFetcher().fetch_prices()
    if price_lookup:
        print(
            f"Successfully fetched prices for {len(price_lookup)} (service, model) pairs."
        )
        first_key = next(iter(price_lookup))
        print("First entry:", first_key, price_lookup[first_key])
    else:
        print("Failed to fetch sheet data.")
670
+
628
671
 
629
672
  def main():
630
673
  """
edsl/data/Cache.py CHANGED
@@ -6,6 +6,7 @@ from __future__ import annotations
6
6
  import json
7
7
  import os
8
8
  import warnings
9
+ import copy
9
10
  from typing import Optional, Union
10
11
  from edsl.Base import Base
11
12
  from edsl.data.CacheEntry import CacheEntry
@@ -88,11 +89,24 @@ class Cache(Base):
88
89
  # raise NotImplementedError("This method is not implemented yet.")
89
90
 
def keys(self):
    """
    Return the cache keys as a new list, in the order ``self.data.keys()`` yields them.

    >>> from edsl import Cache
    >>> Cache.example().keys()
    ['5513286eb6967abc0511211f0402587d']
    """
    return [cache_key for cache_key in self.data.keys()]
92
98
 
def values(self):
    """
    Return the cache entries as a new list, in the order ``self.data.values()`` yields them.

    >>> from edsl import Cache
    >>> Cache.example().values()
    [CacheEntry(...)]
    """
    return [entry for entry in self.data.values()]
95
106
 
def items(self):
    """
    Pair each key from ``self.keys()`` with the value at the same position in ``self.values()``.

    Returns a one-shot ``zip`` iterator of ``(key, entry)`` tuples, not a dict view.
    """
    cache_keys = self.keys()
    cache_values = self.values()
    return zip(cache_keys, cache_values)
109
+
96
110
  def new_entries_cache(self) -> Cache:
97
111
  """Return a new Cache object with the new entries."""
98
112
  return Cache(data={**self.new_entries, **self.fetched_data})
@@ -160,7 +174,7 @@ class Cache(Base):
160
174
  parameters: str,
161
175
  system_prompt: str,
162
176
  user_prompt: str,
163
- response: str,
177
+ response: dict,
164
178
  iteration: int,
165
179
  ) -> str:
166
180
  """
@@ -174,6 +188,15 @@ class Cache(Base):
174
188
  * The key-value pair is added to `self.new_entries`
175
189
  * If `immediate_write` is True , the key-value pair is added to `self.data`
176
190
  * If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
191
+
192
+ >>> from edsl import Cache, Model, Question
193
+ >>> m = Model("test")
194
+ >>> c = Cache()
195
+ >>> len(c)
196
+ 0
197
+ >>> results = Question.example("free_text").by(m).run(cache = c)
198
+ >>> len(c)
199
+ 1
177
200
  """
178
201
 
179
202
  entry = CacheEntry(
@@ -326,6 +349,17 @@ class Cache(Base):
326
349
  for key, value in self.data.items():
327
350
  f.write(json.dumps({key: value.to_dict()}) + "\n")
328
351
 
def to_scenario_list(self):
    """
    Convert the cache into a ScenarioList holding one Scenario per entry.

    Each Scenario is built from the entry's ``to_dict()`` output with an extra
    ``cache_key`` field holding the key the entry is stored under.
    """
    from edsl import ScenarioList, Scenario

    def _entry_to_scenario(cache_key, cache_entry):
        # Tag the dict returned by to_dict() with its key, then wrap it.
        entry_fields = cache_entry.to_dict()
        entry_fields["cache_key"] = cache_key
        return Scenario(entry_fields)

    return ScenarioList(
        [_entry_to_scenario(key, entry) for key, entry in self.data.items()]
    )
362
+
329
363
  ####################
330
364
  # REMOTE
331
365
  ####################
edsl/data_transfer_models.py CHANGED
@@ -1,38 +1,73 @@
1
- """This module contains the data transfer models for the application."""
2
-
3
- from collections import UserDict
4
-
5
-
6
- class AgentResponseDict(UserDict):
7
- """A dictionary to store the response of the agent to a question."""
8
-
9
- def __init__(
10
- self,
11
- *,
12
- question_name,
13
- answer,
14
- prompts,
15
- usage=None,
16
- comment=None,
17
- cached_response=None,
18
- raw_model_response=None,
19
- simple_model_raw_response=None,
20
- cache_used=None,
21
- cache_key=None,
22
- ):
23
- """Initialize the AgentResponseDict object."""
24
- usage = usage or {"prompt_tokens": 0, "completion_tokens": 0}
25
- super().__init__(
26
- {
27
- "answer": answer,
28
- "comment": comment,
29
- "question_name": question_name,
30
- "prompts": prompts,
31
- "usage": usage,
32
- "cached_response": cached_response,
33
- "raw_model_response": raw_model_response,
34
- "simple_model_raw_response": simple_model_raw_response,
35
- "cache_used": cache_used,
36
- "cache_key": cache_key,
37
- }
38
- )
1
+ from typing import NamedTuple, Dict, List, Optional, Any
2
+ from dataclasses import dataclass, fields
3
+ import reprlib
4
+
5
+
class ModelInputs(NamedTuple):
    """Prompts (and an optional encoded image) that the agent sent to the model."""

    user_prompt: str
    system_prompt: str
    encoded_image: Optional[str] = None  # None when no image accompanies the prompts
11
+
12
+
class EDSLOutput(NamedTuple):
    """EDSL-level answer data returned by the model.

    Holds the answer, the generated tokens, and an optional comment.
    """

    answer: Any
    generated_tokens: str
    comment: Optional[str] = None  # absent comments default to None
18
+
19
+
class ModelResponse(NamedTuple):
    """Metadata returned by the model call, including cache information.

    ``cache_used`` and ``cache_key`` describe the cache for this call;
    ``cached_response`` and ``cost`` are optional and default to None.
    """

    response: dict
    cache_used: bool
    cache_key: str
    cached_response: Optional[Dict[str, Any]] = None
    cost: Optional[float] = None
27
+
28
+
class AgentResponseDict(NamedTuple):
    """Everything about one agent/model exchange, grouped in three parts.

    Despite the name, this is a NamedTuple rather than a dict:
    ``edsl_dict`` is the EDSL-level output, ``model_inputs`` is what was sent to
    the model, and ``model_outputs`` is the response metadata.
    """

    edsl_dict: EDSLOutput
    model_inputs: ModelInputs
    model_outputs: ModelResponse
33
+
34
+
class EDSLResultObjectInput(NamedTuple):
    """Flat record of one question's answer, prompts, cache info, and validation status.

    Field order matters: instances can be built positionally.
    """

    generated_tokens: str
    question_name: str
    prompts: dict
    # NOTE(review): ModelResponse.cached_response is Optional[Dict]; confirm this annotation.
    cached_response: str
    raw_model_response: str
    cache_used: bool
    cache_key: str
    answer: Any
    comment: Optional[str]  # EDSLOutput.comment may be None
    validated: bool = False
    # Exception captured for this answer, if any; defaults to None, hence Optional.
    exception_occurred: Optional[Exception] = None
    cost: Optional[float] = None
48
+
49
+
@dataclass
class ImageInfo:
    """Details about an image file plus its encoded contents.

    The custom ``__repr__`` shortens ``encoded_image`` so printing an instance
    does not dump the whole encoded payload.
    """

    file_path: str
    file_name: str
    image_format: str
    file_size: int
    encoded_image: str

    def __repr__(self):
        # reprlib caps string reprs at `maxstring` characters.
        shortener = reprlib.Repr()
        shortener.maxstring = 30

        parts = []
        for field_def in fields(self):
            if field_def.name == "encoded_image":
                continue
            parts.append(f"{field_def.name}={getattr(self, field_def.name)!r}")
        # encoded_image is always rendered last, through the length-limited repr.
        parts.append(f"encoded_image={shortener.repr(self.encoded_image)}")

        return "{}({})".format(self.__class__.__name__, ", ".join(parts))
edsl/enums.py CHANGED
@@ -62,6 +62,8 @@ class InferenceServiceType(EnumWithChecks):
62
62
  GROQ = "groq"
63
63
  AZURE = "azure"
64
64
  OLLAMA = "ollama"
65
+ MISTRAL = "mistral"
66
+ TOGETHER = "together"
65
67
 
66
68
 
67
69
  service_to_api_keyname = {
@@ -74,6 +76,8 @@ service_to_api_keyname = {
74
76
  InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
75
77
  InferenceServiceType.GROQ.value: "GROQ_API_KEY",
76
78
  InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
79
+ InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
80
+ InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
77
81
  }
78
82
 
79
83
 
edsl/exceptions/language_models.py CHANGED
@@ -1,8 +1,32 @@
1
1
  from textwrap import dedent
2
+ from typing import Optional
2
3
 
3
4
 
class LanguageModelExceptions(Exception):
    """Base of the language-model exception hierarchy; also keeps the text on ``.message``."""

    explanation = (
        "This is the base class for all exceptions in the LanguageModel class."
    )

    def __init__(self, message):
        super().__init__(message)
        self.message = message


class LanguageModelNoResponseError(LanguageModelExceptions):
    """Error for when the LLM API cannot be reached and/or does not respond.

    The constructor is inherited unchanged: it takes a single ``message``.
    """

    explanation = (
        """This happens when the LLM API cannot be reached and/or does not respond."""
    )


class LanguageModelBadResponseError(LanguageModelExceptions):
    """Error for when the LLM API responds but the response yields no usable answer."""

    explanation = (
        "This happens when the LLM API can be reached and responds, "
        "but does not return a usable answer."
    )

    def __init__(self, message, response_json: Optional[dict] = None):
        super().__init__(message)
        # The offending payload, kept for debugging; it is not part of ``args``.
        self.response_json = response_json
6
30
 
7
31
 
8
32
  class LanguageModelNotFound(LanguageModelExceptions):
edsl/exceptions/questions.py CHANGED
@@ -1,16 +1,73 @@
1
+ from typing import Any, SupportsIndex
2
+ from jinja2 import Template
3
+ import json
4
+
5
+
class QuestionErrors(Exception):
    """
    Base exception class for question-related errors.

    The message is also stored on ``.message``.
    """

    def __init__(self, message="An error occurred with the question"):
        self.message = message
        super().__init__(self.message)


class QuestionAnswerValidationError(QuestionErrors):
    """Raised when an answer from the language model does not validate for its question type.

    ``data`` holds the value that was validated and ``model`` the Pydantic model
    used to validate it; both default to None.
    """

    documentation = "https://docs.expectedparrot.com/en/latest/exceptions.html"

    explanation = (
        "This happens when the answer coming from the Language Model does not "
        "conform to the expectation for that question type.\n"
        "For example, if the question is a multiple choice question, the answer "
        "should be drawn from the list of options provided."
    )

    def __init__(self, message="Invalid answer.", data=None, model=None):
        self.message = message
        self.data = data
        self.model = model
        super().__init__(self.message)

    def __str__(self):
        return (
            f"{self!r}\n"
            f"Data being validated: {self.data}\n"
            f"Pydantic Model: {self.model}.\n"
            f"Reported error: {self.message}."
        )

    def to_html_dict(self):
        """Return the fields shown in the HTML error report.

        Each value is a tuple ``(label, opening-tag text, closing-tag text, content)``.
        """
        # default=str keeps the report renderable when the answer contains values
        # JSON cannot encode (sets, arbitrary objects) instead of raising TypeError.
        answer_text = json.dumps(self.data, indent=2, default=str)
        if self.model is None:
            # `model` defaults to None; calling model_json_schema() on it would fail.
            schema_text = "No Pydantic model was provided."
        else:
            schema_text = json.dumps(self.model.model_json_schema(), indent=2)

        return {
            "error_type": ("Name of the exception", "p", "/p", self.__class__.__name__),
            # NOTE: the key is misspelled ("explaination") but left unchanged
            # because report templates may look it up by name.
            "explaination": ("Explanation", "p", "/p", self.explanation),
            "edsl answer": ("What model returned", "pre", "/pre", answer_text),
            "validating_model": ("Pydantic model for answers", "pre", "/pre", schema_text),
            "error_message": ("Error message Pydantic returned", "p", "/p", self.message),
            "documentation_url": (
                "URL to EDSL docs",
                f"a href='{self.documentation}'",
                "/a",
                self.documentation,
            ),
        }


class QuestionCreationValidationError(QuestionErrors):
    pass


class QuestionResponseValidationError(QuestionErrors):
    pass
15
72
 
16
73
 
edsl/exceptions/results.py CHANGED
@@ -2,6 +2,10 @@ class ResultsErrors(Exception):
2
2
  pass
3
3
 
4
4
 
class ResultsDeserializationError(ResultsErrors):
    """Results error type for deserialization failures.

    NOTE(review): meaning taken from the class name; no raise site is visible here.
    """
7
+
8
+
class ResultsBadMutationstringError(ResultsErrors):
    """Results error type for an unusable mutation string.

    NOTE(review): meaning taken from the class name; no raise site is visible here.
    """
7
11