edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. edsl/Base.py +116 -197
  2. edsl/__init__.py +7 -15
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +147 -351
  5. edsl/agents/AgentList.py +73 -211
  6. edsl/agents/Invigilator.py +50 -101
  7. edsl/agents/InvigilatorBase.py +70 -62
  8. edsl/agents/PromptConstructor.py +225 -143
  9. edsl/agents/__init__.py +1 -0
  10. edsl/agents/prompt_helpers.py +3 -3
  11. edsl/auto/AutoStudy.py +5 -18
  12. edsl/auto/StageBase.py +40 -53
  13. edsl/auto/StageQuestions.py +1 -2
  14. edsl/auto/utilities.py +6 -0
  15. edsl/config.py +2 -22
  16. edsl/conversation/car_buying.py +1 -2
  17. edsl/coop/PriceFetcher.py +1 -1
  18. edsl/coop/coop.py +47 -125
  19. edsl/coop/utils.py +14 -14
  20. edsl/data/Cache.py +27 -45
  21. edsl/data/CacheEntry.py +15 -12
  22. edsl/data/CacheHandler.py +12 -31
  23. edsl/data/RemoteCacheSync.py +46 -154
  24. edsl/data/__init__.py +3 -4
  25. edsl/data_transfer_models.py +1 -2
  26. edsl/enums.py +0 -27
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +0 -12
  29. edsl/exceptions/questions.py +6 -24
  30. edsl/exceptions/scenarios.py +0 -7
  31. edsl/inference_services/AnthropicService.py +19 -38
  32. edsl/inference_services/AwsBedrock.py +2 -0
  33. edsl/inference_services/AzureAI.py +2 -0
  34. edsl/inference_services/GoogleService.py +12 -7
  35. edsl/inference_services/InferenceServiceABC.py +85 -18
  36. edsl/inference_services/InferenceServicesCollection.py +79 -120
  37. edsl/inference_services/MistralAIService.py +3 -0
  38. edsl/inference_services/OpenAIService.py +35 -47
  39. edsl/inference_services/PerplexityService.py +3 -0
  40. edsl/inference_services/TestService.py +10 -11
  41. edsl/inference_services/TogetherAIService.py +3 -5
  42. edsl/jobs/Answers.py +14 -1
  43. edsl/jobs/Jobs.py +431 -356
  44. edsl/jobs/JobsChecks.py +10 -35
  45. edsl/jobs/JobsPrompts.py +4 -6
  46. edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
  47. edsl/jobs/buckets/BucketCollection.py +3 -44
  48. edsl/jobs/buckets/TokenBucket.py +21 -53
  49. edsl/jobs/interviews/Interview.py +408 -143
  50. edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
  51. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  52. edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
  53. edsl/jobs/tasks/TaskHistory.py +18 -38
  54. edsl/jobs/tasks/task_status_enum.py +2 -0
  55. edsl/language_models/KeyLookup.py +30 -0
  56. edsl/language_models/LanguageModel.py +236 -194
  57. edsl/language_models/ModelList.py +19 -28
  58. edsl/language_models/__init__.py +2 -1
  59. edsl/language_models/registry.py +190 -0
  60. edsl/language_models/repair.py +2 -2
  61. edsl/language_models/unused/ReplicateBase.py +83 -0
  62. edsl/language_models/utilities.py +4 -5
  63. edsl/notebooks/Notebook.py +14 -19
  64. edsl/prompts/Prompt.py +39 -29
  65. edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
  66. edsl/questions/QuestionBase.py +214 -68
  67. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
  68. edsl/questions/QuestionBasePromptsMixin.py +3 -7
  69. edsl/questions/QuestionBudget.py +1 -1
  70. edsl/questions/QuestionCheckBox.py +3 -3
  71. edsl/questions/QuestionExtract.py +7 -5
  72. edsl/questions/QuestionFreeText.py +3 -2
  73. edsl/questions/QuestionList.py +18 -10
  74. edsl/questions/QuestionMultipleChoice.py +23 -67
  75. edsl/questions/QuestionNumerical.py +4 -2
  76. edsl/questions/QuestionRank.py +17 -7
  77. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
  78. edsl/questions/SimpleAskMixin.py +3 -4
  79. edsl/questions/__init__.py +1 -2
  80. edsl/questions/derived/QuestionLinearScale.py +3 -6
  81. edsl/questions/derived/QuestionTopK.py +1 -1
  82. edsl/questions/descriptors.py +3 -17
  83. edsl/questions/question_registry.py +1 -1
  84. edsl/results/CSSParameterizer.py +1 -1
  85. edsl/results/Dataset.py +7 -170
  86. edsl/results/DatasetExportMixin.py +305 -168
  87. edsl/results/DatasetTree.py +8 -28
  88. edsl/results/Result.py +206 -298
  89. edsl/results/Results.py +131 -149
  90. edsl/results/ResultsDBMixin.py +238 -0
  91. edsl/results/ResultsExportMixin.py +0 -2
  92. edsl/results/{results_selector.py → Selector.py} +13 -23
  93. edsl/results/TableDisplay.py +171 -98
  94. edsl/results/__init__.py +1 -1
  95. edsl/scenarios/FileStore.py +239 -150
  96. edsl/scenarios/Scenario.py +193 -90
  97. edsl/scenarios/ScenarioHtmlMixin.py +3 -4
  98. edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
  99. edsl/scenarios/ScenarioList.py +244 -415
  100. edsl/scenarios/ScenarioListExportMixin.py +7 -0
  101. edsl/scenarios/ScenarioListPdfMixin.py +37 -15
  102. edsl/scenarios/__init__.py +2 -1
  103. edsl/study/ObjectEntry.py +1 -1
  104. edsl/study/SnapShot.py +1 -1
  105. edsl/study/Study.py +12 -5
  106. edsl/surveys/Rule.py +4 -5
  107. edsl/surveys/RuleCollection.py +27 -25
  108. edsl/surveys/Survey.py +791 -270
  109. edsl/surveys/SurveyCSS.py +8 -20
  110. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
  111. edsl/surveys/__init__.py +2 -4
  112. edsl/surveys/descriptors.py +2 -6
  113. edsl/surveys/instructions/ChangeInstruction.py +2 -1
  114. edsl/surveys/instructions/Instruction.py +13 -4
  115. edsl/surveys/instructions/InstructionCollection.py +6 -11
  116. edsl/templates/error_reporting/interview_details.html +1 -1
  117. edsl/templates/error_reporting/report.html +1 -1
  118. edsl/tools/plotting.py +1 -1
  119. edsl/utilities/utilities.py +23 -35
  120. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
  121. edsl-0.1.39.dev1.dist-info/RECORD +277 -0
  122. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
  123. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  124. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  125. edsl/agents/question_option_processor.py +0 -172
  126. edsl/coop/CoopFunctionsMixin.py +0 -15
  127. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  128. edsl/exceptions/inference_services.py +0 -5
  129. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  130. edsl/inference_services/AvailableModelFetcher.py +0 -215
  131. edsl/inference_services/ServiceAvailability.py +0 -135
  132. edsl/inference_services/data_structures.py +0 -134
  133. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
  134. edsl/jobs/FetchInvigilator.py +0 -47
  135. edsl/jobs/InterviewTaskManager.py +0 -98
  136. edsl/jobs/InterviewsConstructor.py +0 -50
  137. edsl/jobs/JobsComponentConstructor.py +0 -189
  138. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  139. edsl/jobs/RequestTokenEstimator.py +0 -30
  140. edsl/jobs/async_interview_runner.py +0 -138
  141. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  142. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  143. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  144. edsl/jobs/data_structures.py +0 -120
  145. edsl/jobs/decorators.py +0 -35
  146. edsl/jobs/jobs_status_enums.py +0 -9
  147. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  148. edsl/jobs/results_exceptions_handler.py +0 -98
  149. edsl/language_models/ComputeCost.py +0 -63
  150. edsl/language_models/PriceManager.py +0 -127
  151. edsl/language_models/RawResponseHandler.py +0 -106
  152. edsl/language_models/ServiceDataSources.py +0 -0
  153. edsl/language_models/key_management/KeyLookup.py +0 -63
  154. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  155. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  156. edsl/language_models/key_management/__init__.py +0 -0
  157. edsl/language_models/key_management/models.py +0 -131
  158. edsl/language_models/model.py +0 -256
  159. edsl/notebooks/NotebookToLaTeX.py +0 -142
  160. edsl/questions/ExceptionExplainer.py +0 -77
  161. edsl/questions/HTMLQuestion.py +0 -103
  162. edsl/questions/QuestionMatrix.py +0 -265
  163. edsl/questions/data_structures.py +0 -20
  164. edsl/questions/loop_processor.py +0 -149
  165. edsl/questions/response_validator_factory.py +0 -34
  166. edsl/questions/templates/matrix/__init__.py +0 -1
  167. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  168. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  169. edsl/results/MarkdownToDocx.py +0 -122
  170. edsl/results/MarkdownToPDF.py +0 -111
  171. edsl/results/TextEditor.py +0 -50
  172. edsl/results/file_exports.py +0 -252
  173. edsl/results/smart_objects.py +0 -96
  174. edsl/results/table_data_class.py +0 -12
  175. edsl/results/table_renderers.py +0 -118
  176. edsl/scenarios/ConstructDownloadLink.py +0 -109
  177. edsl/scenarios/DocumentChunker.py +0 -102
  178. edsl/scenarios/DocxScenario.py +0 -16
  179. edsl/scenarios/PdfExtractor.py +0 -40
  180. edsl/scenarios/directory_scanner.py +0 -96
  181. edsl/scenarios/file_methods.py +0 -85
  182. edsl/scenarios/handlers/__init__.py +0 -13
  183. edsl/scenarios/handlers/csv.py +0 -49
  184. edsl/scenarios/handlers/docx.py +0 -76
  185. edsl/scenarios/handlers/html.py +0 -37
  186. edsl/scenarios/handlers/json.py +0 -111
  187. edsl/scenarios/handlers/latex.py +0 -5
  188. edsl/scenarios/handlers/md.py +0 -51
  189. edsl/scenarios/handlers/pdf.py +0 -68
  190. edsl/scenarios/handlers/png.py +0 -39
  191. edsl/scenarios/handlers/pptx.py +0 -105
  192. edsl/scenarios/handlers/py.py +0 -294
  193. edsl/scenarios/handlers/sql.py +0 -313
  194. edsl/scenarios/handlers/sqlite.py +0 -149
  195. edsl/scenarios/handlers/txt.py +0 -33
  196. edsl/scenarios/scenario_selector.py +0 -156
  197. edsl/surveys/ConstructDAG.py +0 -92
  198. edsl/surveys/EditSurvey.py +0 -221
  199. edsl/surveys/InstructionHandler.py +0 -100
  200. edsl/surveys/MemoryManagement.py +0 -72
  201. edsl/surveys/RuleManager.py +0 -172
  202. edsl/surveys/Simulator.py +0 -75
  203. edsl/surveys/SurveyToApp.py +0 -141
  204. edsl/utilities/PrettyList.py +0 -56
  205. edsl/utilities/is_notebook.py +0 -18
  206. edsl/utilities/is_valid_variable_name.py +0 -11
  207. edsl/utilities/remove_edsl_version.py +0 -24
  208. edsl-0.1.39.dist-info/RECORD +0 -358
  209. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  210. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  211. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  212. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
edsl/jobs/JobsChecks.py CHANGED
@@ -1,21 +1,16 @@
1
1
  import os
2
- from edsl.exceptions.general import MissingAPIKeyError
2
+ from edsl.exceptions import MissingAPIKeyError
3
3
 
4
4
 
5
5
  class JobsChecks:
6
6
  def __init__(self, jobs):
7
- """Checks a Jobs object for missing API keys and other requirements."""
7
+ """ """
8
8
  self.jobs = jobs
9
9
 
10
10
  def check_api_keys(self) -> None:
11
- from edsl.language_models.model import Model
11
+ from edsl import Model
12
12
 
13
- if len(self.jobs.models) == 0:
14
- models = [Model()]
15
- else:
16
- models = self.jobs.models
17
-
18
- for model in models: # + [Model()]:
13
+ for model in self.jobs.models + [Model()]:
19
14
  if not model.has_valid_api_key():
20
15
  raise MissingAPIKeyError(
21
16
  model_name=str(model.model),
@@ -28,7 +23,7 @@ class JobsChecks:
28
23
  """
29
24
  missing_api_keys = set()
30
25
 
31
- from edsl.language_models.model import Model
26
+ from edsl import Model
32
27
  from edsl.enums import service_to_api_keyname
33
28
 
34
29
  for model in self.jobs.models + [Model()]:
@@ -100,33 +95,16 @@ class JobsChecks:
100
95
  return True
101
96
 
102
97
  def needs_key_process(self):
103
- """
104
- A User needs the key process when:
105
- 1. They don't have all the model keys
106
- 2. They don't have the EP API
107
- 3. They need external LLMs to run the job
108
- """
109
98
  return (
110
99
  not self.user_has_all_model_keys()
111
100
  and not self.user_has_ep_api_key()
112
101
  and self.needs_external_llms()
113
102
  )
114
103
 
115
- def status(self) -> dict:
116
- """
117
- Returns a dictionary with the status of the job checks.
118
- """
119
- return {
120
- "user_has_ep_api_key": self.user_has_ep_api_key(),
121
- "user_has_all_model_keys": self.user_has_all_model_keys(),
122
- "needs_external_llms": self.needs_external_llms(),
123
- "needs_key_process": self.needs_key_process(),
124
- }
125
-
126
104
  def key_process(self):
127
105
  import secrets
128
106
  from dotenv import load_dotenv
129
- from edsl.config import CONFIG
107
+ from edsl import CONFIG
130
108
  from edsl.coop.coop import Coop
131
109
  from edsl.utilities.utilities import write_api_key_to_env
132
110
 
@@ -141,12 +119,10 @@ class JobsChecks:
141
119
  "\nYou can either add the missing keys to your .env file, or use remote inference."
142
120
  )
143
121
  print("Remote inference allows you to run jobs on our server.")
122
+ print("\n🚀 To use remote inference, sign up at the following link:")
144
123
 
145
124
  coop = Coop()
146
- coop._display_login_url(
147
- edsl_auth_token=edsl_auth_token,
148
- link_description="\n🚀 To use remote inference, sign up at the following link:",
149
- )
125
+ coop._display_login_url(edsl_auth_token=edsl_auth_token)
150
126
 
151
127
  print(
152
128
  "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
@@ -158,9 +134,8 @@ class JobsChecks:
158
134
  print("\nTimed out waiting for login. Please try again.")
159
135
  return
160
136
 
161
- path_to_env = write_api_key_to_env(api_key)
162
- print("\n✨ API key retrieved and written to .env file at the following path:")
163
- print(f" {path_to_env}")
137
+ write_api_key_to_env(api_key)
138
+ print("✨ API key retrieved and written to .env file.\n")
164
139
 
165
140
  # Retrieve API key so we can continue running the job
166
141
  load_dotenv()
edsl/jobs/JobsPrompts.py CHANGED
@@ -11,8 +11,6 @@ if TYPE_CHECKING:
11
11
  # from edsl.scenarios.ScenarioList import ScenarioList
12
12
  # from edsl.surveys.Survey import Survey
13
13
 
14
- from edsl.jobs.FetchInvigilator import FetchInvigilator
15
-
16
14
 
17
15
  class JobsPrompts:
18
16
  def __init__(self, jobs: "Jobs"):
@@ -25,7 +23,7 @@ class JobsPrompts:
25
23
  @property
26
24
  def price_lookup(self):
27
25
  if self._price_lookup is None:
28
- from edsl.coop.coop import Coop
26
+ from edsl import Coop
29
27
 
30
28
  c = Coop()
31
29
  self._price_lookup = c.fetch_prices()
@@ -50,8 +48,8 @@ class JobsPrompts:
50
48
 
51
49
  for interview_index, interview in enumerate(interviews):
52
50
  invigilators = [
53
- FetchInvigilator(interview)(question)
54
- for question in interview.survey.questions
51
+ interview._get_invigilator(question)
52
+ for question in self.survey.questions
55
53
  ]
56
54
  for _, invigilator in enumerate(invigilators):
57
55
  prompts = invigilator.get_prompts()
@@ -186,7 +184,7 @@ class JobsPrompts:
186
184
  data = []
187
185
  for interview in interviews:
188
186
  invigilators = [
189
- FetchInvigilator(interview)(question)
187
+ interview._get_invigilator(question)
190
188
  for question in self.survey.questions
191
189
  ]
192
190
  for invigilator in invigilators:
edsl/jobs/JobsRemoteInferenceHandler.py CHANGED
@@ -1,78 +1,47 @@
1
- from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
2
-
3
- from dataclasses import dataclass
4
-
5
-
6
- Seconds = NewType("Seconds", float)
7
- JobUUID = NewType("JobUUID", str)
8
-
1
+ from typing import Optional, Union, Literal
2
+ import requests
3
+ import sys
9
4
  from edsl.exceptions.coop import CoopServerResponseError
10
5
 
11
- if TYPE_CHECKING:
12
- from edsl.results.Results import Results
13
- from edsl.jobs.Jobs import Jobs
14
- from edsl.coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
15
- from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
16
-
17
- from edsl.coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
18
-
19
- from edsl.jobs.jobs_status_enums import JobsStatus
20
- from edsl.coop.utils import VisibilityType
21
- from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
22
-
23
-
24
- class RemoteJobConstants:
25
- """Constants for remote job handling."""
26
-
27
- REMOTE_JOB_POLL_INTERVAL = 1
28
- REMOTE_JOB_VERBOSE = False
29
- DISCORD_URL = "https://discord.com/invite/mxAYkjfy9m"
30
-
31
-
32
- @dataclass
33
- class RemoteJobInfo:
34
- creation_data: RemoteInferenceCreationInfo
35
- job_uuid: JobUUID
36
- logger: JobLogger
6
+ # from edsl.enums import VisibilityType
7
+ from edsl.results import Results
37
8
 
38
9
 
39
10
  class JobsRemoteInferenceHandler:
40
- def __init__(
41
- self,
42
- jobs: "Jobs",
43
- verbose: bool = RemoteJobConstants.REMOTE_JOB_VERBOSE,
44
- poll_interval: Seconds = RemoteJobConstants.REMOTE_JOB_POLL_INTERVAL,
45
- ):
46
- """Handles the creation and running of a remote inference job."""
11
+ def __init__(self, jobs, verbose=False, poll_interval=3):
12
+ """
13
+ >>> from edsl.jobs import Jobs
14
+ >>> jh = JobsRemoteInferenceHandler(Jobs.example(), verbose=True)
15
+ >>> jh.use_remote_inference(True)
16
+ False
17
+ >>> jh._poll_remote_inference_job({'uuid':1234}, testing_simulated_response={"status": "failed"}) # doctest: +NORMALIZE_WHITESPACE
18
+ Job failed.
19
+ ...
20
+ >>> jh._poll_remote_inference_job({'uuid':1234}, testing_simulated_response={"status": "completed"}) # doctest: +NORMALIZE_WHITESPACE
21
+ Job completed and Results stored on Coop: None.
22
+ Results(...)
23
+ """
47
24
  self.jobs = jobs
48
25
  self.verbose = verbose
49
26
  self.poll_interval = poll_interval
50
27
 
51
- from edsl.config import CONFIG
52
-
53
- self.expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
54
- self.remote_inference_url = f"{self.expected_parrot_url}/home/remote-inference"
28
+ self._remote_job_creation_data = None
29
+ self._job_uuid = None
55
30
 
56
- def _create_logger(self) -> JobLogger:
57
- from edsl.utilities.is_notebook import is_notebook
58
- from edsl.jobs.JobsRemoteInferenceLogger import (
59
- JupyterJobLogger,
60
- StdOutJobLogger,
61
- )
62
- from edsl.jobs.loggers.HTMLTableJobLogger import HTMLTableJobLogger
31
+ @property
32
+ def remote_job_creation_data(self):
33
+ return self._remote_job_creation_data
63
34
 
64
- if is_notebook():
65
- return HTMLTableJobLogger(verbose=self.verbose)
66
- return StdOutJobLogger(verbose=self.verbose)
35
+ @property
36
+ def job_uuid(self):
37
+ return self._job_uuid
67
38
 
68
39
  def use_remote_inference(self, disable_remote_inference: bool) -> bool:
69
- import requests
70
-
71
40
  if disable_remote_inference:
72
41
  return False
73
42
  if not disable_remote_inference:
74
43
  try:
75
- from edsl.coop.coop import Coop
44
+ from edsl import Coop
76
45
 
77
46
  user_edsl_settings = Coop().edsl_settings
78
47
  return user_edsl_settings.get("remote_inference", False)
@@ -87,19 +56,16 @@ class JobsRemoteInferenceHandler:
87
56
  self,
88
57
  iterations: int = 1,
89
58
  remote_inference_description: Optional[str] = None,
90
- remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
91
- ) -> RemoteJobInfo:
92
-
59
+ remote_inference_results_visibility: Optional["VisibilityType"] = "unlisted",
60
+ verbose=False,
61
+ ):
62
+ """ """
93
63
  from edsl.config import CONFIG
94
64
  from edsl.coop.coop import Coop
95
-
96
- logger = self._create_logger()
65
+ from rich import print as rich_print
97
66
 
98
67
  coop = Coop()
99
- logger.update(
100
- "Remote inference activated. Sending job to server...",
101
- status=JobsStatus.QUEUED,
102
- )
68
+ print("Remote inference activated. Sending job to server...")
103
69
  remote_job_creation_data = coop.remote_inference_create(
104
70
  self.jobs,
105
71
  description=remote_inference_description,
@@ -107,172 +73,136 @@ class JobsRemoteInferenceHandler:
107
73
  iterations=iterations,
108
74
  initial_results_visibility=remote_inference_results_visibility,
109
75
  )
110
- logger.update(
111
- "Your survey is running at the Expected Parrot server...",
112
- status=JobsStatus.RUNNING,
113
- )
114
76
  job_uuid = remote_job_creation_data.get("uuid")
115
- logger.update(
116
- message=f"Job sent to server. (Job uuid={job_uuid}).",
117
- status=JobsStatus.RUNNING,
118
- )
119
- logger.add_info("job_uuid", job_uuid)
77
+ print(f"Job sent to server. (Job uuid={job_uuid}).")
120
78
 
121
- logger.update(
122
- f"Job details are available at your Coop account {self.remote_inference_url}",
123
- status=JobsStatus.RUNNING,
124
- )
125
- progress_bar_url = (
126
- f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
127
- )
128
- logger.add_info("progress_bar_url", progress_bar_url)
129
- logger.update(
130
- f"View job progress here: {progress_bar_url}", status=JobsStatus.RUNNING
131
- )
79
+ expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
80
+ progress_bar_url = f"{expected_parrot_url}/home/remote-job-progress/{job_uuid}"
132
81
 
133
- return RemoteJobInfo(
134
- creation_data=remote_job_creation_data,
135
- job_uuid=job_uuid,
136
- logger=logger,
82
+ rich_print(
83
+ f"View job progress here: [#38bdf8][link={progress_bar_url}]{progress_bar_url}[/link][/#38bdf8]"
137
84
  )
138
85
 
86
+ self._remote_job_creation_data = remote_job_creation_data
87
+ self._job_uuid = job_uuid
88
+ # return remote_job_creation_data
89
+
139
90
  @staticmethod
140
- def check_status(
141
- job_uuid: JobUUID,
142
- ) -> RemoteInferenceResponse:
91
+ def check_status(job_uuid):
143
92
  from edsl.coop.coop import Coop
144
93
 
145
94
  coop = Coop()
146
95
  return coop.remote_inference_get(job_uuid)
147
96
 
148
- def _construct_remote_job_fetcher(
149
- self, testing_simulated_response: Optional[Any] = None
150
- ) -> Callable:
151
- if testing_simulated_response is not None:
152
- return lambda job_uuid: testing_simulated_response
153
- else:
154
- from edsl.coop.coop import Coop
155
-
156
- coop = Coop()
157
- return coop.remote_inference_get
158
-
159
- def _construct_object_fetcher(
160
- self, testing_simulated_response: Optional[Any] = None
161
- ) -> Callable:
162
- "Constructs a function to fetch the results object from Coop."
163
- if testing_simulated_response is not None:
164
- return lambda results_uuid, expected_object_type: Results.example()
165
- else:
166
- from edsl.coop.coop import Coop
167
-
168
- coop = Coop()
169
- return coop.get
170
-
171
- def _handle_cancelled_job(self, job_info: RemoteJobInfo) -> None:
172
- "Handles a cancelled job by logging the cancellation and updating the job status."
173
-
174
- job_info.logger.update(
175
- message="Job cancelled by the user.", status=JobsStatus.CANCELLED
176
- )
177
- job_info.logger.update(
178
- f"See {self.expected_parrot_url}/home/remote-inference for more details.",
179
- status=JobsStatus.CANCELLED,
180
- )
181
-
182
- def _handle_failed_job(
183
- self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
184
- ) -> None:
185
- "Handles a failed job by logging the error and updating the job status."
186
- latest_error_report_url = remote_job_data.get("latest_error_report_url")
187
- if latest_error_report_url:
188
- job_info.logger.add_info("error_report_url", latest_error_report_url)
189
-
190
- job_info.logger.update("Job failed.", status=JobsStatus.FAILED)
191
- job_info.logger.update(
192
- f"See {self.expected_parrot_url}/home/remote-inference for more details.",
193
- status=JobsStatus.FAILED,
194
- )
195
- job_info.logger.update(
196
- f"Need support? Visit Discord: {RemoteJobConstants.DISCORD_URL}",
197
- status=JobsStatus.FAILED,
97
+ def poll_remote_inference_job(self):
98
+ return self._poll_remote_inference_job(
99
+ self.remote_job_creation_data, verbose=self.verbose
198
100
  )
199
101
 
200
- def _sleep_for_a_bit(self, job_info: RemoteJobInfo, status: str) -> None:
102
+ def _poll_remote_inference_job(
103
+ self,
104
+ remote_job_creation_data: dict,
105
+ verbose=False,
106
+ poll_interval: Optional[float] = None,
107
+ testing_simulated_response: Optional[dict] = None,
108
+ ) -> Union[Results, None]:
201
109
  import time
202
110
  from datetime import datetime
111
+ from edsl.config import CONFIG
112
+ from edsl.coop.coop import Coop
203
113
 
204
- time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
205
- job_info.logger.update(
206
- f"Job status: {status} - last update: {time_checked}",
207
- status=JobsStatus.RUNNING,
208
- )
209
- time.sleep(self.poll_interval)
114
+ if poll_interval is None:
115
+ poll_interval = self.poll_interval
210
116
 
211
- def _fetch_results_and_log(
212
- self,
213
- job_info: RemoteJobInfo,
214
- results_uuid: str,
215
- remote_job_data: RemoteInferenceResponse,
216
- object_fetcher: Callable,
217
- ) -> "Results":
218
- "Fetches the results object and logs the results URL."
219
- job_info.logger.add_info("results_uuid", results_uuid)
220
- results = object_fetcher(results_uuid, expected_object_type="results")
221
- results_url = remote_job_data.get("results_url")
222
- job_info.logger.update(
223
- f"Job completed and Results stored on Coop: {results_url}",
224
- status=JobsStatus.COMPLETED,
225
- )
226
- results.job_uuid = job_info.job_uuid
227
- results.results_uuid = results_uuid
228
- return results
117
+ expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
229
118
 
230
- def poll_remote_inference_job(
231
- self,
232
- job_info: RemoteJobInfo,
233
- testing_simulated_response=None,
234
- ) -> Union[None, "Results"]:
235
- """Polls a remote inference job for completion and returns the results."""
119
+ job_uuid = remote_job_creation_data.get("uuid")
120
+ coop = Coop()
236
121
 
237
- remote_job_data_fetcher = self._construct_remote_job_fetcher(
238
- testing_simulated_response
239
- )
240
- object_fetcher = self._construct_object_fetcher(testing_simulated_response)
122
+ if testing_simulated_response is not None:
123
+ remote_job_data_fetcher = lambda job_uuid: testing_simulated_response
124
+ object_fetcher = (
125
+ lambda results_uuid, expected_object_type: Results.example()
126
+ )
127
+ else:
128
+ remote_job_data_fetcher = coop.remote_inference_get
129
+ object_fetcher = coop.get
241
130
 
242
131
  job_in_queue = True
243
132
  while job_in_queue:
244
- remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
133
+ remote_job_data = remote_job_data_fetcher(job_uuid)
245
134
  status = remote_job_data.get("status")
246
-
247
135
  if status == "cancelled":
248
- self._handle_cancelled_job(job_info)
136
+ print("\r" + " " * 80 + "\r", end="")
137
+ print("Job cancelled by the user.")
138
+ print(
139
+ f"See {expected_parrot_url}/home/remote-inference for more details."
140
+ )
249
141
  return None
250
-
251
- elif status == "failed" or status == "completed":
252
- if status == "failed":
253
- self._handle_failed_job(job_info, remote_job_data)
254
-
255
- results_uuid = remote_job_data.get("results_uuid")
256
- if results_uuid:
257
- results = self._fetch_results_and_log(
258
- job_info=job_info,
259
- results_uuid=results_uuid,
260
- remote_job_data=remote_job_data,
261
- object_fetcher=object_fetcher,
142
+ elif status == "failed":
143
+ print("\r" + " " * 80 + "\r", end="")
144
+ # write to stderr
145
+ latest_error_report_url = remote_job_data.get("latest_error_report_url")
146
+ if latest_error_report_url:
147
+ print("Job failed.")
148
+ print(
149
+ f"Your job generated exceptions. Details on these exceptions can be found in the following report: {latest_error_report_url}"
150
+ )
151
+ print(
152
+ f"Need support? Post a message at the Expected Parrot Discord channel (https://discord.com/invite/mxAYkjfy9m) or send an email to info@expectedparrot.com."
262
153
  )
263
- return results
264
154
  else:
265
- return None
266
-
155
+ print("Job failed.")
156
+ print(
157
+ f"See {expected_parrot_url}/home/remote-inference for more details."
158
+ )
159
+ return None
160
+ elif status == "completed":
161
+ results_uuid = remote_job_data.get("results_uuid")
162
+ results_url = remote_job_data.get("results_url")
163
+ results = object_fetcher(results_uuid, expected_object_type="results")
164
+ print("\r" + " " * 80 + "\r", end="")
165
+ print(f"Job completed and Results stored on Coop: {results_url}.")
166
+ return results
267
167
  else:
268
- self._sleep_for_a_bit(job_info, status)
168
+ duration = poll_interval
169
+ time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
170
+ frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
171
+ start_time = time.time()
172
+ i = 0
173
+ while time.time() - start_time < duration:
174
+ print(
175
+ f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
176
+ end="",
177
+ flush=True,
178
+ )
179
+ time.sleep(0.1)
180
+ i += 1
181
+
182
+ def use_remote_inference(self, disable_remote_inference: bool) -> bool:
183
+ if disable_remote_inference:
184
+ return False
185
+ if not disable_remote_inference:
186
+ try:
187
+ from edsl import Coop
188
+
189
+ user_edsl_settings = Coop().edsl_settings
190
+ return user_edsl_settings.get("remote_inference", False)
191
+ except requests.ConnectionError:
192
+ pass
193
+ except CoopServerResponseError as e:
194
+ pass
195
+
196
+ return False
269
197
 
270
198
  async def create_and_poll_remote_job(
271
199
  self,
272
200
  iterations: int = 1,
273
201
  remote_inference_description: Optional[str] = None,
274
- remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
275
- ) -> Union["Results", None]:
202
+ remote_inference_results_visibility: Optional[
203
+ Literal["private", "public", "unlisted"]
204
+ ] = "unlisted",
205
+ ) -> Union[Results, None]:
276
206
  """
277
207
  Creates and polls a remote inference job asynchronously.
278
208
  Reuses existing synchronous methods but runs them in an async context.
@@ -287,7 +217,7 @@ class JobsRemoteInferenceHandler:
287
217
 
288
218
  # Create job using existing method
289
219
  loop = asyncio.get_event_loop()
290
- job_info = await loop.run_in_executor(
220
+ remote_job_creation_data = await loop.run_in_executor(
291
221
  None,
292
222
  partial(
293
223
  self.create_remote_inference_job,
@@ -296,12 +226,10 @@ class JobsRemoteInferenceHandler:
296
226
  remote_inference_results_visibility=remote_inference_results_visibility,
297
227
  ),
298
228
  )
299
- if job_info is None:
300
- raise ValueError("Remote job creation failed.")
301
229
 
230
+ # Poll using existing method but with async sleep
302
231
  return await loop.run_in_executor(
303
- None,
304
- partial(self.poll_remote_inference_job, job_info),
232
+ None, partial(self.poll_remote_inference_job, remote_job_creation_data)
305
233
  )
306
234
 
307
235
 
edsl/jobs/buckets/BucketCollection.py CHANGED
@@ -1,15 +1,8 @@
1
- from typing import Optional
2
1
  from collections import UserDict
3
2
  from edsl.jobs.buckets.TokenBucket import TokenBucket
4
3
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
5
4
 
6
- # from functools import wraps
7
- from threading import RLock
8
5
 
9
- from edsl.jobs.decorators import synchronized_class
10
-
11
-
12
- @synchronized_class
13
6
  class BucketCollection(UserDict):
14
7
  """A Jobs object will have a whole collection of model buckets, as multiple models could be used.
15
8
 
@@ -17,43 +10,11 @@ class BucketCollection(UserDict):
17
10
  Models themselves are hashable, so this works.
18
11
  """
19
12
 
20
- def __init__(self, infinity_buckets: bool = False):
21
- """Create a new BucketCollection.
22
- An infinity bucket is a bucket that never runs out of tokens or requests.
23
- """
13
+ def __init__(self, infinity_buckets=False):
24
14
  super().__init__()
25
15
  self.infinity_buckets = infinity_buckets
26
16
  self.models_to_services = {}
27
17
  self.services_to_buckets = {}
28
- self._lock = RLock()
29
-
30
- from edsl.config import CONFIG
31
- import os
32
-
33
- url = os.environ.get("EDSL_REMOTE_TOKEN_BUCKET_URL", None)
34
-
35
- if url == "None" or url is None:
36
- self.remote_url = None
37
- # print(f"Using remote token bucket URL: {url}")
38
- else:
39
- self.remote_url = url
40
-
41
- @classmethod
42
- def from_models(
43
- cls, models_list: list, infinity_buckets: bool = False
44
- ) -> "BucketCollection":
45
- """Create a BucketCollection from a list of models."""
46
- bucket_collection = cls(infinity_buckets=infinity_buckets)
47
- for model in models_list:
48
- bucket_collection.add_model(model)
49
- return bucket_collection
50
-
51
- def get_tokens(
52
- self, model: "LanguageModel", bucket_type: str, num_tokens: int
53
- ) -> int:
54
- """Get the number of tokens remaining in the bucket."""
55
- relevant_bucket = getattr(self[model], bucket_type)
56
- return relevant_bucket.get_tokens(num_tokens)
57
18
 
58
19
  def __repr__(self):
59
20
  return f"BucketCollection({self.data})"
@@ -65,8 +26,8 @@ class BucketCollection(UserDict):
65
26
 
66
27
  # compute the TPS and RPS from the model
67
28
  if not self.infinity_buckets:
68
- TPS = model.tpm / 60.0
69
- RPS = model.rpm / 60.0
29
+ TPS = model.TPM / 60.0
30
+ RPS = model.RPM / 60.0
70
31
  else:
71
32
  TPS = float("inf")
72
33
  RPS = float("inf")
@@ -79,14 +40,12 @@ class BucketCollection(UserDict):
79
40
  bucket_type="requests",
80
41
  capacity=RPS,
81
42
  refill_rate=RPS,
82
- remote_url=self.remote_url,
83
43
  )
84
44
  tokens_bucket = TokenBucket(
85
45
  bucket_name=service,
86
46
  bucket_type="tokens",
87
47
  capacity=TPS,
88
48
  refill_rate=TPS,
89
- remote_url=self.remote_url,
90
49
  )
91
50
  self.services_to_buckets[service] = ModelBuckets(
92
51
  requests_bucket, tokens_bucket