edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. edsl/Base.py +3 -9
  2. edsl/__init__.py +3 -8
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -40
  5. edsl/agents/AgentList.py +0 -43
  6. edsl/agents/Invigilator.py +219 -135
  7. edsl/agents/InvigilatorBase.py +59 -148
  8. edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
  9. edsl/agents/__init__.py +0 -1
  10. edsl/config.py +56 -47
  11. edsl/coop/coop.py +7 -50
  12. edsl/data/Cache.py +1 -35
  13. edsl/data_transfer_models.py +38 -73
  14. edsl/enums.py +0 -4
  15. edsl/exceptions/language_models.py +1 -25
  16. edsl/exceptions/questions.py +5 -62
  17. edsl/exceptions/results.py +0 -4
  18. edsl/inference_services/AnthropicService.py +11 -13
  19. edsl/inference_services/AwsBedrock.py +17 -19
  20. edsl/inference_services/AzureAI.py +20 -37
  21. edsl/inference_services/GoogleService.py +12 -16
  22. edsl/inference_services/GroqService.py +0 -2
  23. edsl/inference_services/InferenceServiceABC.py +3 -58
  24. edsl/inference_services/OpenAIService.py +54 -48
  25. edsl/inference_services/models_available_cache.py +6 -0
  26. edsl/inference_services/registry.py +0 -6
  27. edsl/jobs/Answers.py +12 -10
  28. edsl/jobs/Jobs.py +21 -36
  29. edsl/jobs/buckets/BucketCollection.py +15 -24
  30. edsl/jobs/buckets/TokenBucket.py +14 -93
  31. edsl/jobs/interviews/Interview.py +78 -366
  32. edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
  33. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
  34. edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
  35. edsl/jobs/interviews/retry_management.py +37 -0
  36. edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
  37. edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
  38. edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
  39. edsl/jobs/tasks/TaskHistory.py +213 -148
  40. edsl/language_models/LanguageModel.py +156 -261
  41. edsl/language_models/ModelList.py +2 -2
  42. edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
  43. edsl/language_models/registry.py +6 -23
  44. edsl/language_models/repair.py +19 -0
  45. edsl/prompts/Prompt.py +2 -52
  46. edsl/questions/AnswerValidatorMixin.py +26 -23
  47. edsl/questions/QuestionBase.py +249 -329
  48. edsl/questions/QuestionBudget.py +41 -99
  49. edsl/questions/QuestionCheckBox.py +35 -227
  50. edsl/questions/QuestionExtract.py +27 -98
  51. edsl/questions/QuestionFreeText.py +29 -52
  52. edsl/questions/QuestionFunctional.py +0 -7
  53. edsl/questions/QuestionList.py +22 -141
  54. edsl/questions/QuestionMultipleChoice.py +65 -159
  55. edsl/questions/QuestionNumerical.py +46 -88
  56. edsl/questions/QuestionRank.py +24 -182
  57. edsl/questions/RegisterQuestionsMeta.py +12 -31
  58. edsl/questions/__init__.py +4 -3
  59. edsl/questions/derived/QuestionLikertFive.py +5 -10
  60. edsl/questions/derived/QuestionLinearScale.py +2 -15
  61. edsl/questions/derived/QuestionTopK.py +1 -10
  62. edsl/questions/derived/QuestionYesNo.py +3 -24
  63. edsl/questions/descriptors.py +7 -43
  64. edsl/questions/question_registry.py +2 -6
  65. edsl/results/Dataset.py +0 -20
  66. edsl/results/DatasetExportMixin.py +48 -46
  67. edsl/results/Result.py +5 -32
  68. edsl/results/Results.py +46 -135
  69. edsl/results/ResultsDBMixin.py +3 -3
  70. edsl/scenarios/FileStore.py +10 -71
  71. edsl/scenarios/Scenario.py +25 -96
  72. edsl/scenarios/ScenarioImageMixin.py +2 -2
  73. edsl/scenarios/ScenarioList.py +39 -361
  74. edsl/scenarios/ScenarioListExportMixin.py +0 -9
  75. edsl/scenarios/ScenarioListPdfMixin.py +4 -150
  76. edsl/study/SnapShot.py +1 -8
  77. edsl/study/Study.py +0 -32
  78. edsl/surveys/Rule.py +1 -10
  79. edsl/surveys/RuleCollection.py +5 -21
  80. edsl/surveys/Survey.py +310 -636
  81. edsl/surveys/SurveyExportMixin.py +9 -71
  82. edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
  83. edsl/surveys/SurveyQualtricsImport.py +4 -75
  84. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  85. edsl/utilities/utilities.py +1 -9
  86. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
  87. edsl-0.1.33.dev1.dist-info/RECORD +209 -0
  88. edsl/TemplateLoader.py +0 -24
  89. edsl/auto/AutoStudy.py +0 -117
  90. edsl/auto/StageBase.py +0 -230
  91. edsl/auto/StageGenerateSurvey.py +0 -178
  92. edsl/auto/StageLabelQuestions.py +0 -125
  93. edsl/auto/StagePersona.py +0 -61
  94. edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
  95. edsl/auto/StagePersonaDimensionValues.py +0 -74
  96. edsl/auto/StagePersonaDimensions.py +0 -69
  97. edsl/auto/StageQuestions.py +0 -73
  98. edsl/auto/SurveyCreatorPipeline.py +0 -21
  99. edsl/auto/utilities.py +0 -224
  100. edsl/coop/PriceFetcher.py +0 -58
  101. edsl/inference_services/MistralAIService.py +0 -120
  102. edsl/inference_services/TestService.py +0 -80
  103. edsl/inference_services/TogetherAIService.py +0 -170
  104. edsl/jobs/FailedQuestion.py +0 -78
  105. edsl/jobs/runners/JobsRunnerStatus.py +0 -331
  106. edsl/language_models/fake_openai_call.py +0 -15
  107. edsl/language_models/fake_openai_service.py +0 -61
  108. edsl/language_models/utilities.py +0 -61
  109. edsl/questions/QuestionBaseGenMixin.py +0 -133
  110. edsl/questions/QuestionBasePromptsMixin.py +0 -266
  111. edsl/questions/Quick.py +0 -41
  112. edsl/questions/ResponseValidatorABC.py +0 -170
  113. edsl/questions/decorators.py +0 -21
  114. edsl/questions/prompt_templates/question_budget.jinja +0 -13
  115. edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
  116. edsl/questions/prompt_templates/question_extract.jinja +0 -11
  117. edsl/questions/prompt_templates/question_free_text.jinja +0 -3
  118. edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
  119. edsl/questions/prompt_templates/question_list.jinja +0 -17
  120. edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
  121. edsl/questions/prompt_templates/question_numerical.jinja +0 -37
  122. edsl/questions/templates/__init__.py +0 -0
  123. edsl/questions/templates/budget/__init__.py +0 -0
  124. edsl/questions/templates/budget/answering_instructions.jinja +0 -7
  125. edsl/questions/templates/budget/question_presentation.jinja +0 -7
  126. edsl/questions/templates/checkbox/__init__.py +0 -0
  127. edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
  128. edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
  129. edsl/questions/templates/extract/__init__.py +0 -0
  130. edsl/questions/templates/extract/answering_instructions.jinja +0 -7
  131. edsl/questions/templates/extract/question_presentation.jinja +0 -1
  132. edsl/questions/templates/free_text/__init__.py +0 -0
  133. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  134. edsl/questions/templates/free_text/question_presentation.jinja +0 -1
  135. edsl/questions/templates/likert_five/__init__.py +0 -0
  136. edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
  137. edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
  138. edsl/questions/templates/linear_scale/__init__.py +0 -0
  139. edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
  140. edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
  141. edsl/questions/templates/list/__init__.py +0 -0
  142. edsl/questions/templates/list/answering_instructions.jinja +0 -4
  143. edsl/questions/templates/list/question_presentation.jinja +0 -5
  144. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  145. edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
  146. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  147. edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
  148. edsl/questions/templates/numerical/__init__.py +0 -0
  149. edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
  150. edsl/questions/templates/numerical/question_presentation.jinja +0 -7
  151. edsl/questions/templates/rank/__init__.py +0 -0
  152. edsl/questions/templates/rank/answering_instructions.jinja +0 -11
  153. edsl/questions/templates/rank/question_presentation.jinja +0 -15
  154. edsl/questions/templates/top_k/__init__.py +0 -0
  155. edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
  156. edsl/questions/templates/top_k/question_presentation.jinja +0 -22
  157. edsl/questions/templates/yes_no/__init__.py +0 -0
  158. edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
  159. edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
  160. edsl/results/DatasetTree.py +0 -145
  161. edsl/results/Selector.py +0 -118
  162. edsl/results/tree_explore.py +0 -115
  163. edsl/surveys/instructions/ChangeInstruction.py +0 -47
  164. edsl/surveys/instructions/Instruction.py +0 -34
  165. edsl/surveys/instructions/InstructionCollection.py +0 -77
  166. edsl/surveys/instructions/__init__.py +0 -0
  167. edsl/templates/error_reporting/base.html +0 -24
  168. edsl/templates/error_reporting/exceptions_by_model.html +0 -35
  169. edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
  170. edsl/templates/error_reporting/exceptions_by_type.html +0 -17
  171. edsl/templates/error_reporting/interview_details.html +0 -116
  172. edsl/templates/error_reporting/interviews.html +0 -10
  173. edsl/templates/error_reporting/overview.html +0 -5
  174. edsl/templates/error_reporting/performance_plot.html +0 -2
  175. edsl/templates/error_reporting/report.css +0 -74
  176. edsl/templates/error_reporting/report.html +0 -118
  177. edsl/templates/error_reporting/report.js +0 -25
  178. edsl-0.1.33.dist-info/RECORD +0 -295
  179. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
  180. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
edsl/config.py CHANGED
@@ -1,66 +1,83 @@
1
1
  """This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
2
2
 
3
3
  import os
4
- from dotenv import load_dotenv, find_dotenv
5
4
  from edsl.exceptions import (
6
5
  InvalidEnvironmentVariableError,
7
6
  MissingEnvironmentVariableError,
8
7
  )
8
+ from dotenv import load_dotenv, find_dotenv
9
9
 
10
10
  # valid values for EDSL_RUN_MODE
11
- EDSL_RUN_MODES = [
12
- "development",
13
- "development-testrun",
14
- "production",
15
- ]
11
+ EDSL_RUN_MODES = ["development", "development-testrun", "production"]
16
12
 
17
13
  # `default` is used to impute values only in "production" mode
18
14
  # `info` gives a brief description of the env var
19
15
  CONFIG_MAP = {
20
16
  "EDSL_RUN_MODE": {
21
17
  "default": "production",
22
- "info": "This config var determines the run mode of the application.",
18
+ "info": "This env var determines the run mode of the application.",
19
+ },
20
+ "EDSL_DATABASE_PATH": {
21
+ "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
22
+ "info": "This env var determines the path to the cache file.",
23
+ },
24
+ "EDSL_LOGGING_PATH": {
25
+ "default": f"{os.path.join(os.getcwd(), 'interview.log')}",
26
+ "info": "This env var determines the path to the log file.",
23
27
  },
24
28
  "EDSL_API_TIMEOUT": {
25
29
  "default": "60",
26
- "info": "This config var determines the maximum number of seconds to wait for an API call to return.",
30
+ "info": "This env var determines the maximum number of seconds to wait for an API call to return.",
27
31
  },
28
32
  "EDSL_BACKOFF_START_SEC": {
29
33
  "default": "1",
30
- "info": "This config var determines the number of seconds to wait before retrying a failed API call.",
34
+ "info": "This env var determines the number of seconds to wait before retrying a failed API call.",
31
35
  },
32
- "EDSL_BACKOFF_MAX_SEC": {
36
+ "EDSL_MAX_BACKOFF_SEC": {
33
37
  "default": "60",
34
- "info": "This config var determines the maximum number of seconds to wait before retrying a failed API call.",
35
- },
36
- "EDSL_DATABASE_PATH": {
37
- "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
38
- "info": "This config var determines the path to the cache file.",
39
- },
40
- "EDSL_DEFAULT_MODEL": {
41
- "default": "gpt-4o",
42
- "info": "This config var holds the default model that will be used if a model is not explicitly passed.",
43
- },
44
- "EDSL_FETCH_TOKEN_PRICES": {
45
- "default": "True",
46
- "info": "This config var determines whether to fetch prices for tokens used in remote inference",
38
+ "info": "This env var determines the maximum number of seconds to wait before retrying a failed API call.",
47
39
  },
48
40
  "EDSL_MAX_ATTEMPTS": {
49
41
  "default": "5",
50
- "info": "This config var determines the maximum number of times to retry a failed API call.",
51
- },
52
- "EDSL_SERVICE_RPM_BASELINE": {
53
- "default": "100",
54
- "info": "This config var holds the maximum number of requests per minute. Model-specific values provided in env vars such as EDSL_SERVICE_RPM_OPENAI will override this. value for the corresponding model",
55
- },
56
- "EDSL_SERVICE_TPM_BASELINE": {
57
- "default": "2000000",
58
- "info": "This config var holds the maximum number of tokens per minute for all models. Model-specific values provided in env vars such as EDSL_SERVICE_TPM_OPENAI will override this value for the corresponding model.",
42
+ "info": "This env var determines the maximum number of times to retry a failed API call.",
59
43
  },
60
44
  "EXPECTED_PARROT_URL": {
61
45
  "default": "https://www.expectedparrot.com",
62
- "info": "This config var holds the URL of the Expected Parrot API.",
46
+ "info": "This env var holds the URL of the Expected Parrot API.",
63
47
  },
48
+ # "EXPECTED_PARROT_API_KEY": {
49
+ # "default": None,
50
+ # "info": "This env var holds your Expected Parrot API key (https://www.expectedparrot.com/).",
51
+ # },
52
+ # "OPENAI_API_KEY": {
53
+ # "default": None,
54
+ # "info": "This env var holds your OpenAI API key (https://platform.openai.com/api-keys).",
55
+ # },
56
+ # "DEEP_INFRA_API_KEY": {
57
+ # "default": None,
58
+ # "info": "This env var holds your DeepInfra API key (https://deepinfra.com/).",
59
+ # },
60
+ # "GOOGLE_API_KEY": {
61
+ # "default": None,
62
+ # "info": "This env var holds your Google API key (https://console.cloud.google.com/apis/credentials).",
63
+ # },
64
+ # "ANTHROPIC_API_KEY": {
65
+ # "default": None,
66
+ # "info": "This env var holds your Anthropic API key (https://www.anthropic.com/).",
67
+ # },
68
+ # "GROQ_API_KEY": {
69
+ # "default": None,
70
+ # "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
71
+ # },
72
+ # "AWS_ACCESS_KEY_ID" :
73
+ # "default": None,
74
+ # "info": "This env var holds your AWS access key ID.",
75
+ # "AWS_SECRET_ACCESS_KEY:
76
+ # "default": None,
77
+ # "info": "This env var holds your AWS secret access key.",
78
+ # "AZURE_ENDPOINT_URL_AND_KEY":
79
+ # "default": None,
80
+ # "info": "This env var holds your Azure endpoint URL and key (URL:key). You can have several comma-separated URL-key pairs (URL1:key1,URL2:key2).",
64
81
  }
65
82
 
66
83
 
@@ -75,7 +92,7 @@ class Config:
75
92
 
76
93
  def _set_run_mode(self) -> None:
77
94
  """
78
- Sets EDSL_RUN_MODE as a class attribute.
95
+ Checks the validity and sets EDSL_RUN_MODE.
79
96
  """
80
97
  run_mode = os.getenv("EDSL_RUN_MODE")
81
98
  default = CONFIG_MAP.get("EDSL_RUN_MODE").get("default")
@@ -90,35 +107,27 @@ class Config:
90
107
  def _load_dotenv(self) -> None:
91
108
  """
92
109
  Loads the .env
93
- - The .env will override existing env vars **unless** EDSL_RUN_MODE=="development-testrun"
110
+ - Overrides existing env vars unless EDSL_RUN_MODE=="development-testrun"
94
111
  """
95
112
 
113
+ override = True
96
114
  if self.EDSL_RUN_MODE == "development-testrun":
97
115
  override = False
98
- else:
99
- override = True
100
116
  _ = load_dotenv(dotenv_path=find_dotenv(usecwd=True), override=override)
101
117
 
102
- def __contains__(self, env_var: str) -> bool:
103
- """
104
- Checks if an env var is set as a class attribute.
105
- """
106
- return env_var in self.__dict__
107
-
108
118
  def _set_env_vars(self) -> None:
109
119
  """
110
- Sets env vars as class attributes.
111
- - EDSL_RUN_MODE is not set my this method, but by _set_run_mode
120
+ Sets env vars as Config class attributes.
112
121
  - If an env var is not set and has a default value in the CONFIG_MAP, sets it to the default value.
113
122
  """
114
123
  # for each env var in the CONFIG_MAP
115
124
  for env_var, config in CONFIG_MAP.items():
116
- # EDSL_RUN_MODE is already set by _set_run_mode
125
+ # we've set it already in _set_run_mode
117
126
  if env_var == "EDSL_RUN_MODE":
118
127
  continue
119
128
  value = os.getenv(env_var)
120
129
  default_value = config.get("default")
121
- # if an env var exists, set it as a class attribute
130
+ # if the env var is set, set it as a CONFIG attribute
122
131
  if value:
123
132
  setattr(self, env_var, value)
124
133
  # otherwise, if EDSL_RUN_MODE == "production" set it to its default value
edsl/coop/coop.py CHANGED
@@ -53,39 +53,25 @@ class Coop:
53
53
  method: str,
54
54
  payload: Optional[dict[str, Any]] = None,
55
55
  params: Optional[dict[str, Any]] = None,
56
- timeout: Optional[float] = 5,
57
56
  ) -> requests.Response:
58
57
  """
59
58
  Send a request to the server and return the response.
60
59
  """
61
60
  url = f"{self.url}/{uri}"
62
- method = method.upper()
63
- if payload is None:
64
- timeout = 20
65
- elif (
66
- method.upper() == "POST"
67
- and "json_string" in payload
68
- and payload.get("json_string") is not None
69
- ):
70
- timeout = max(20, (len(payload.get("json_string", "")) // (1024 * 1024)))
71
61
  try:
62
+ method = method.upper()
72
63
  if method in ["GET", "DELETE"]:
73
64
  response = requests.request(
74
- method, url, params=params, headers=self.headers, timeout=timeout
65
+ method, url, params=params, headers=self.headers
75
66
  )
76
67
  elif method in ["POST", "PATCH"]:
77
68
  response = requests.request(
78
- method,
79
- url,
80
- params=params,
81
- json=payload,
82
- headers=self.headers,
83
- timeout=timeout,
69
+ method, url, params=params, json=payload, headers=self.headers
84
70
  )
85
71
  else:
86
72
  raise Exception(f"Invalid {method=}.")
87
73
  except requests.ConnectionError:
88
- raise requests.ConnectionError(f"Could not connect to the server at {url}.")
74
+ raise requests.ConnectionError("Could not connect to the server.")
89
75
 
90
76
  return response
91
77
 
@@ -95,7 +81,6 @@ class Coop:
95
81
  """
96
82
  if response.status_code >= 400:
97
83
  message = response.json().get("detail")
98
- # print(response.text)
99
84
  if "Authorization" in message:
100
85
  print(message)
101
86
  message = "Please provide an Expected Parrot API key."
@@ -125,18 +110,10 @@ class Coop:
125
110
  def edsl_settings(self) -> dict:
126
111
  """
127
112
  Retrieve and return the EDSL settings stored on Coop.
128
- If no response is received within 5 seconds, return an empty dict.
129
113
  """
130
- from requests.exceptions import Timeout
131
-
132
- try:
133
- response = self._send_server_request(
134
- uri="api/v0/edsl-settings", method="GET", timeout=5
135
- )
136
- self._resolve_server_response(response)
137
- return response.json()
138
- except Timeout:
139
- return {}
114
+ response = self._send_server_request(uri="api/v0/edsl-settings", method="GET")
115
+ self._resolve_server_response(response)
116
+ return response.json()
140
117
 
141
118
  ################
142
119
  # Objects
@@ -648,26 +625,6 @@ class Coop:
648
625
 
649
626
  return response_json
650
627
 
651
- def fetch_prices(self) -> dict:
652
- from edsl.coop.PriceFetcher import PriceFetcher
653
-
654
- from edsl.config import CONFIG
655
-
656
- if bool(CONFIG.get("EDSL_FETCH_TOKEN_PRICES")):
657
- price_fetcher = PriceFetcher()
658
- return price_fetcher.fetch_prices()
659
- else:
660
- return {}
661
-
662
-
663
- if __name__ == "__main__":
664
- sheet_data = fetch_sheet_data()
665
- if sheet_data:
666
- print(f"Successfully fetched {len(sheet_data)} rows of data.")
667
- print("First row:", sheet_data[0])
668
- else:
669
- print("Failed to fetch sheet data.")
670
-
671
628
 
672
629
  def main():
673
630
  """
edsl/data/Cache.py CHANGED
@@ -6,7 +6,6 @@ from __future__ import annotations
6
6
  import json
7
7
  import os
8
8
  import warnings
9
- import copy
10
9
  from typing import Optional, Union
11
10
  from edsl.Base import Base
12
11
  from edsl.data.CacheEntry import CacheEntry
@@ -89,24 +88,11 @@ class Cache(Base):
89
88
  # raise NotImplementedError("This method is not implemented yet.")
90
89
 
91
90
  def keys(self):
92
- """
93
- >>> from edsl import Cache
94
- >>> Cache.example().keys()
95
- ['5513286eb6967abc0511211f0402587d']
96
- """
97
91
  return list(self.data.keys())
98
92
 
99
93
  def values(self):
100
- """
101
- >>> from edsl import Cache
102
- >>> Cache.example().values()
103
- [CacheEntry(...)]
104
- """
105
94
  return list(self.data.values())
106
95
 
107
- def items(self):
108
- return zip(self.keys(), self.values())
109
-
110
96
  def new_entries_cache(self) -> Cache:
111
97
  """Return a new Cache object with the new entries."""
112
98
  return Cache(data={**self.new_entries, **self.fetched_data})
@@ -174,7 +160,7 @@ class Cache(Base):
174
160
  parameters: str,
175
161
  system_prompt: str,
176
162
  user_prompt: str,
177
- response: dict,
163
+ response: str,
178
164
  iteration: int,
179
165
  ) -> str:
180
166
  """
@@ -188,15 +174,6 @@ class Cache(Base):
188
174
  * The key-value pair is added to `self.new_entries`
189
175
  * If `immediate_write` is True , the key-value pair is added to `self.data`
190
176
  * If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
191
-
192
- >>> from edsl import Cache, Model, Question
193
- >>> m = Model("test")
194
- >>> c = Cache()
195
- >>> len(c)
196
- 0
197
- >>> results = Question.example("free_text").by(m).run(cache = c)
198
- >>> len(c)
199
- 1
200
177
  """
201
178
 
202
179
  entry = CacheEntry(
@@ -349,17 +326,6 @@ class Cache(Base):
349
326
  for key, value in self.data.items():
350
327
  f.write(json.dumps({key: value.to_dict()}) + "\n")
351
328
 
352
- def to_scenario_list(self):
353
- from edsl import ScenarioList, Scenario
354
-
355
- scenarios = []
356
- for key, value in self.data.items():
357
- new_d = value.to_dict()
358
- new_d["cache_key"] = key
359
- s = Scenario(new_d)
360
- scenarios.append(s)
361
- return ScenarioList(scenarios)
362
-
363
329
  ####################
364
330
  # REMOTE
365
331
  ####################
@@ -1,73 +1,38 @@
1
- from typing import NamedTuple, Dict, List, Optional, Any
2
- from dataclasses import dataclass, fields
3
- import reprlib
4
-
5
-
6
- class ModelInputs(NamedTuple):
7
- "This is what was send by the agent to the model"
8
- user_prompt: str
9
- system_prompt: str
10
- encoded_image: Optional[str] = None
11
-
12
-
13
- class EDSLOutput(NamedTuple):
14
- "This is the edsl dictionary that is returned by the model"
15
- answer: Any
16
- generated_tokens: str
17
- comment: Optional[str] = None
18
-
19
-
20
- class ModelResponse(NamedTuple):
21
- "This is the metadata that is returned by the model and includes info about the cache"
22
- response: dict
23
- cache_used: bool
24
- cache_key: str
25
- cached_response: Optional[Dict[str, Any]] = None
26
- cost: Optional[float] = None
27
-
28
-
29
- class AgentResponseDict(NamedTuple):
30
- edsl_dict: EDSLOutput
31
- model_inputs: ModelInputs
32
- model_outputs: ModelResponse
33
-
34
-
35
- class EDSLResultObjectInput(NamedTuple):
36
- generated_tokens: str
37
- question_name: str
38
- prompts: dict
39
- cached_response: str
40
- raw_model_response: str
41
- cache_used: bool
42
- cache_key: str
43
- answer: Any
44
- comment: str
45
- validated: bool = False
46
- exception_occurred: Exception = None
47
- cost: Optional[float] = None
48
-
49
-
50
- @dataclass
51
- class ImageInfo:
52
- file_path: str
53
- file_name: str
54
- image_format: str
55
- file_size: int
56
- encoded_image: str
57
-
58
- def __repr__(self):
59
- reprlib_instance = reprlib.Repr()
60
- reprlib_instance.maxstring = 30 # Limit the string length for the encoded image
61
-
62
- # Get all fields except encoded_image
63
- field_reprs = [
64
- f"{f.name}={getattr(self, f.name)!r}"
65
- for f in fields(self)
66
- if f.name != "encoded_image"
67
- ]
68
-
69
- # Add the reprlib-restricted encoded_image field
70
- field_reprs.append(f"encoded_image={reprlib_instance.repr(self.encoded_image)}")
71
-
72
- # Join everything to create the repr
73
- return f"{self.__class__.__name__}({', '.join(field_reprs)})"
1
+ """This module contains the data transfer models for the application."""
2
+
3
+ from collections import UserDict
4
+
5
+
6
+ class AgentResponseDict(UserDict):
7
+ """A dictionary to store the response of the agent to a question."""
8
+
9
+ def __init__(
10
+ self,
11
+ *,
12
+ question_name,
13
+ answer,
14
+ prompts,
15
+ usage=None,
16
+ comment=None,
17
+ cached_response=None,
18
+ raw_model_response=None,
19
+ simple_model_raw_response=None,
20
+ cache_used=None,
21
+ cache_key=None,
22
+ ):
23
+ """Initialize the AgentResponseDict object."""
24
+ usage = usage or {"prompt_tokens": 0, "completion_tokens": 0}
25
+ super().__init__(
26
+ {
27
+ "answer": answer,
28
+ "comment": comment,
29
+ "question_name": question_name,
30
+ "prompts": prompts,
31
+ "usage": usage,
32
+ "cached_response": cached_response,
33
+ "raw_model_response": raw_model_response,
34
+ "simple_model_raw_response": simple_model_raw_response,
35
+ "cache_used": cache_used,
36
+ "cache_key": cache_key,
37
+ }
38
+ )
edsl/enums.py CHANGED
@@ -62,8 +62,6 @@ class InferenceServiceType(EnumWithChecks):
62
62
  GROQ = "groq"
63
63
  AZURE = "azure"
64
64
  OLLAMA = "ollama"
65
- MISTRAL = "mistral"
66
- TOGETHER = "together"
67
65
 
68
66
 
69
67
  service_to_api_keyname = {
@@ -76,8 +74,6 @@ service_to_api_keyname = {
76
74
  InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
77
75
  InferenceServiceType.GROQ.value: "GROQ_API_KEY",
78
76
  InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
79
- InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
80
- InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
81
77
  }
82
78
 
83
79
 
@@ -1,32 +1,8 @@
1
1
  from textwrap import dedent
2
- from typing import Optional
3
2
 
4
3
 
5
4
  class LanguageModelExceptions(Exception):
6
- explanation = (
7
- "This is the base class for all exceptions in the LanguageModel class."
8
- )
9
-
10
- def __init__(self, message):
11
- super().__init__(message)
12
- self.message = message
13
-
14
-
15
- class LanguageModelNoResponseError(LanguageModelExceptions):
16
- explanation = (
17
- """This happens when the LLM API cannot be reached and/or does not respond."""
18
- )
19
-
20
- def __init__(self, message):
21
- super().__init__(message)
22
-
23
-
24
- class LanguageModelBadResponseError(LanguageModelExceptions):
25
- explanation = """This happens when the LLM API can be reached and responds, does not return a usable answer."""
26
-
27
- def __init__(self, message, response_json: Optional[dict] = None):
28
- super().__init__(message)
29
- self.response_json = response_json
5
+ pass
30
6
 
31
7
 
32
8
  class LanguageModelNotFound(LanguageModelExceptions):
@@ -1,66 +1,5 @@
1
- from typing import Any, SupportsIndex
2
- from jinja2 import Template
3
- import json
4
-
5
-
6
1
  class QuestionErrors(Exception):
7
- """
8
- Base exception class for question-related errors.
9
- """
10
-
11
- def __init__(self, message="An error occurred with the question"):
12
- self.message = message
13
- super().__init__(self.message)
14
-
15
-
16
- class QuestionAnswerValidationError(QuestionErrors):
17
- documentation = "https://docs.expectedparrot.com/en/latest/exceptions.html"
18
-
19
- explanation = """This when the answer coming from the Language Model does not conform to the expectation for that question type.
20
- For example, if the question is a multiple choice question, the answer should be drawn from the list of options provided.
21
- """
22
-
23
- def __init__(self, message="Invalid answer.", data=None, model=None):
24
- self.message = message
25
- self.data = data
26
- self.model = model
27
- super().__init__(self.message)
28
-
29
- def __str__(self):
30
- return f"""{repr(self)}
31
- Data being validated: {self.data}
32
- Pydnantic Model: {self.model}.
33
- Reported error: {self.message}."""
34
-
35
- def to_html_dict(self):
36
- return {
37
- "error_type": ("Name of the exception", "p", "/p", self.__class__.__name__),
38
- "explaination": ("Explanation", "p", "/p", self.explanation),
39
- "edsl answer": (
40
- "What model returned",
41
- "pre",
42
- "/pre",
43
- json.dumps(self.data, indent=2),
44
- ),
45
- "validating_model": (
46
- "Pydantic model for answers",
47
- "pre",
48
- "/pre",
49
- json.dumps(self.model.model_json_schema(), indent=2),
50
- ),
51
- "error_message": (
52
- "Error message Pydantic returned",
53
- "p",
54
- "/p",
55
- self.message,
56
- ),
57
- "documentation_url": (
58
- "URL to EDSL docs",
59
- f"a href='{self.documentation}'",
60
- "/a",
61
- self.documentation,
62
- ),
63
- }
2
+ pass
64
3
 
65
4
 
66
5
  class QuestionCreationValidationError(QuestionErrors):
@@ -71,6 +10,10 @@ class QuestionResponseValidationError(QuestionErrors):
71
10
  pass
72
11
 
73
12
 
13
+ class QuestionAnswerValidationError(QuestionErrors):
14
+ pass
15
+
16
+
74
17
  class QuestionAttributeMissing(QuestionErrors):
75
18
  pass
76
19
 
@@ -2,10 +2,6 @@ class ResultsErrors(Exception):
2
2
  pass
3
3
 
4
4
 
5
- class ResultsDeserializationError(ResultsErrors):
6
- pass
7
-
8
-
9
5
  class ResultsBadMutationstringError(ResultsErrors):
10
6
  pass
11
7
 
@@ -11,11 +11,6 @@ class AnthropicService(InferenceServiceABC):
11
11
 
12
12
  _inference_service_ = "anthropic"
13
13
  _env_key_name_ = "ANTHROPIC_API_KEY"
14
- key_sequence = ["content", 0, "text"] # ["content"][0]["text"]
15
- usage_sequence = ["usage"]
16
- input_token_name = "input_tokens"
17
- output_token_name = "output_tokens"
18
- model_exclude_list = []
19
14
 
20
15
  @classmethod
21
16
  def available(cls):
@@ -39,11 +34,6 @@ class AnthropicService(InferenceServiceABC):
39
34
  Child class of LanguageModel for interacting with OpenAI models
40
35
  """
41
36
 
42
- key_sequence = cls.key_sequence
43
- usage_sequence = cls.usage_sequence
44
- input_token_name = cls.input_token_name
45
- output_token_name = cls.output_token_name
46
-
47
37
  _inference_service_ = cls._inference_service_
48
38
  _model_ = model_name
49
39
  _parameters_ = {
@@ -56,9 +46,6 @@ class AnthropicService(InferenceServiceABC):
56
46
  "top_logprobs": 3,
57
47
  }
58
48
 
59
- _tpm = cls.get_tpm(cls)
60
- _rpm = cls.get_rpm(cls)
61
-
62
49
  async def async_execute_model_call(
63
50
  self, user_prompt: str, system_prompt: str = ""
64
51
  ) -> dict[str, Any]:
@@ -79,6 +66,17 @@ class AnthropicService(InferenceServiceABC):
79
66
  )
80
67
  return response.model_dump()
81
68
 
69
+ @staticmethod
70
+ def parse_response(raw_response: dict[str, Any]) -> str:
71
+ """Parses the API response and returns the response text."""
72
+ response = raw_response["content"][0]["text"]
73
+ pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
74
+ match = re.match(pattern, response, re.DOTALL)
75
+ if match:
76
+ return match.group(1)
77
+ else:
78
+ return response
79
+
82
80
  LLM.__name__ = model_class_name
83
81
 
84
82
  return LLM