edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. edsl/Base.py +116 -197
  2. edsl/__init__.py +7 -15
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +147 -351
  5. edsl/agents/AgentList.py +73 -211
  6. edsl/agents/Invigilator.py +50 -101
  7. edsl/agents/InvigilatorBase.py +70 -62
  8. edsl/agents/PromptConstructor.py +225 -143
  9. edsl/agents/__init__.py +1 -0
  10. edsl/agents/prompt_helpers.py +3 -3
  11. edsl/auto/AutoStudy.py +5 -18
  12. edsl/auto/StageBase.py +40 -53
  13. edsl/auto/StageQuestions.py +1 -2
  14. edsl/auto/utilities.py +6 -0
  15. edsl/config.py +2 -22
  16. edsl/conversation/car_buying.py +1 -2
  17. edsl/coop/PriceFetcher.py +1 -1
  18. edsl/coop/coop.py +47 -125
  19. edsl/coop/utils.py +14 -14
  20. edsl/data/Cache.py +27 -45
  21. edsl/data/CacheEntry.py +15 -12
  22. edsl/data/CacheHandler.py +12 -31
  23. edsl/data/RemoteCacheSync.py +46 -154
  24. edsl/data/__init__.py +3 -4
  25. edsl/data_transfer_models.py +1 -2
  26. edsl/enums.py +0 -27
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +0 -12
  29. edsl/exceptions/questions.py +6 -24
  30. edsl/exceptions/scenarios.py +0 -7
  31. edsl/inference_services/AnthropicService.py +19 -38
  32. edsl/inference_services/AwsBedrock.py +2 -0
  33. edsl/inference_services/AzureAI.py +2 -0
  34. edsl/inference_services/GoogleService.py +12 -7
  35. edsl/inference_services/InferenceServiceABC.py +85 -18
  36. edsl/inference_services/InferenceServicesCollection.py +79 -120
  37. edsl/inference_services/MistralAIService.py +3 -0
  38. edsl/inference_services/OpenAIService.py +35 -47
  39. edsl/inference_services/PerplexityService.py +3 -0
  40. edsl/inference_services/TestService.py +10 -11
  41. edsl/inference_services/TogetherAIService.py +3 -5
  42. edsl/jobs/Answers.py +14 -1
  43. edsl/jobs/Jobs.py +431 -356
  44. edsl/jobs/JobsChecks.py +10 -35
  45. edsl/jobs/JobsPrompts.py +4 -6
  46. edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
  47. edsl/jobs/buckets/BucketCollection.py +3 -44
  48. edsl/jobs/buckets/TokenBucket.py +21 -53
  49. edsl/jobs/interviews/Interview.py +408 -143
  50. edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
  51. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  52. edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
  53. edsl/jobs/tasks/TaskHistory.py +18 -38
  54. edsl/jobs/tasks/task_status_enum.py +2 -0
  55. edsl/language_models/KeyLookup.py +30 -0
  56. edsl/language_models/LanguageModel.py +236 -194
  57. edsl/language_models/ModelList.py +19 -28
  58. edsl/language_models/__init__.py +2 -1
  59. edsl/language_models/registry.py +190 -0
  60. edsl/language_models/repair.py +2 -2
  61. edsl/language_models/unused/ReplicateBase.py +83 -0
  62. edsl/language_models/utilities.py +4 -5
  63. edsl/notebooks/Notebook.py +14 -19
  64. edsl/prompts/Prompt.py +39 -29
  65. edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
  66. edsl/questions/QuestionBase.py +214 -68
  67. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
  68. edsl/questions/QuestionBasePromptsMixin.py +3 -7
  69. edsl/questions/QuestionBudget.py +1 -1
  70. edsl/questions/QuestionCheckBox.py +3 -3
  71. edsl/questions/QuestionExtract.py +7 -5
  72. edsl/questions/QuestionFreeText.py +3 -2
  73. edsl/questions/QuestionList.py +18 -10
  74. edsl/questions/QuestionMultipleChoice.py +23 -67
  75. edsl/questions/QuestionNumerical.py +4 -2
  76. edsl/questions/QuestionRank.py +17 -7
  77. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
  78. edsl/questions/SimpleAskMixin.py +3 -4
  79. edsl/questions/__init__.py +1 -2
  80. edsl/questions/derived/QuestionLinearScale.py +3 -6
  81. edsl/questions/derived/QuestionTopK.py +1 -1
  82. edsl/questions/descriptors.py +3 -17
  83. edsl/questions/question_registry.py +1 -1
  84. edsl/results/CSSParameterizer.py +1 -1
  85. edsl/results/Dataset.py +7 -170
  86. edsl/results/DatasetExportMixin.py +305 -168
  87. edsl/results/DatasetTree.py +8 -28
  88. edsl/results/Result.py +206 -298
  89. edsl/results/Results.py +131 -149
  90. edsl/results/ResultsDBMixin.py +238 -0
  91. edsl/results/ResultsExportMixin.py +0 -2
  92. edsl/results/{results_selector.py → Selector.py} +13 -23
  93. edsl/results/TableDisplay.py +171 -98
  94. edsl/results/__init__.py +1 -1
  95. edsl/scenarios/FileStore.py +239 -150
  96. edsl/scenarios/Scenario.py +193 -90
  97. edsl/scenarios/ScenarioHtmlMixin.py +3 -4
  98. edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
  99. edsl/scenarios/ScenarioList.py +244 -415
  100. edsl/scenarios/ScenarioListExportMixin.py +7 -0
  101. edsl/scenarios/ScenarioListPdfMixin.py +37 -15
  102. edsl/scenarios/__init__.py +2 -1
  103. edsl/study/ObjectEntry.py +1 -1
  104. edsl/study/SnapShot.py +1 -1
  105. edsl/study/Study.py +12 -5
  106. edsl/surveys/Rule.py +4 -5
  107. edsl/surveys/RuleCollection.py +27 -25
  108. edsl/surveys/Survey.py +791 -270
  109. edsl/surveys/SurveyCSS.py +8 -20
  110. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
  111. edsl/surveys/__init__.py +2 -4
  112. edsl/surveys/descriptors.py +2 -6
  113. edsl/surveys/instructions/ChangeInstruction.py +2 -1
  114. edsl/surveys/instructions/Instruction.py +13 -4
  115. edsl/surveys/instructions/InstructionCollection.py +6 -11
  116. edsl/templates/error_reporting/interview_details.html +1 -1
  117. edsl/templates/error_reporting/report.html +1 -1
  118. edsl/tools/plotting.py +1 -1
  119. edsl/utilities/utilities.py +23 -35
  120. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
  121. edsl-0.1.39.dev1.dist-info/RECORD +277 -0
  122. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
  123. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  124. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  125. edsl/agents/question_option_processor.py +0 -172
  126. edsl/coop/CoopFunctionsMixin.py +0 -15
  127. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  128. edsl/exceptions/inference_services.py +0 -5
  129. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  130. edsl/inference_services/AvailableModelFetcher.py +0 -215
  131. edsl/inference_services/ServiceAvailability.py +0 -135
  132. edsl/inference_services/data_structures.py +0 -134
  133. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
  134. edsl/jobs/FetchInvigilator.py +0 -47
  135. edsl/jobs/InterviewTaskManager.py +0 -98
  136. edsl/jobs/InterviewsConstructor.py +0 -50
  137. edsl/jobs/JobsComponentConstructor.py +0 -189
  138. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  139. edsl/jobs/RequestTokenEstimator.py +0 -30
  140. edsl/jobs/async_interview_runner.py +0 -138
  141. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  142. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  143. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  144. edsl/jobs/data_structures.py +0 -120
  145. edsl/jobs/decorators.py +0 -35
  146. edsl/jobs/jobs_status_enums.py +0 -9
  147. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  148. edsl/jobs/results_exceptions_handler.py +0 -98
  149. edsl/language_models/ComputeCost.py +0 -63
  150. edsl/language_models/PriceManager.py +0 -127
  151. edsl/language_models/RawResponseHandler.py +0 -106
  152. edsl/language_models/ServiceDataSources.py +0 -0
  153. edsl/language_models/key_management/KeyLookup.py +0 -63
  154. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  155. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  156. edsl/language_models/key_management/__init__.py +0 -0
  157. edsl/language_models/key_management/models.py +0 -131
  158. edsl/language_models/model.py +0 -256
  159. edsl/notebooks/NotebookToLaTeX.py +0 -142
  160. edsl/questions/ExceptionExplainer.py +0 -77
  161. edsl/questions/HTMLQuestion.py +0 -103
  162. edsl/questions/QuestionMatrix.py +0 -265
  163. edsl/questions/data_structures.py +0 -20
  164. edsl/questions/loop_processor.py +0 -149
  165. edsl/questions/response_validator_factory.py +0 -34
  166. edsl/questions/templates/matrix/__init__.py +0 -1
  167. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  168. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  169. edsl/results/MarkdownToDocx.py +0 -122
  170. edsl/results/MarkdownToPDF.py +0 -111
  171. edsl/results/TextEditor.py +0 -50
  172. edsl/results/file_exports.py +0 -252
  173. edsl/results/smart_objects.py +0 -96
  174. edsl/results/table_data_class.py +0 -12
  175. edsl/results/table_renderers.py +0 -118
  176. edsl/scenarios/ConstructDownloadLink.py +0 -109
  177. edsl/scenarios/DocumentChunker.py +0 -102
  178. edsl/scenarios/DocxScenario.py +0 -16
  179. edsl/scenarios/PdfExtractor.py +0 -40
  180. edsl/scenarios/directory_scanner.py +0 -96
  181. edsl/scenarios/file_methods.py +0 -85
  182. edsl/scenarios/handlers/__init__.py +0 -13
  183. edsl/scenarios/handlers/csv.py +0 -49
  184. edsl/scenarios/handlers/docx.py +0 -76
  185. edsl/scenarios/handlers/html.py +0 -37
  186. edsl/scenarios/handlers/json.py +0 -111
  187. edsl/scenarios/handlers/latex.py +0 -5
  188. edsl/scenarios/handlers/md.py +0 -51
  189. edsl/scenarios/handlers/pdf.py +0 -68
  190. edsl/scenarios/handlers/png.py +0 -39
  191. edsl/scenarios/handlers/pptx.py +0 -105
  192. edsl/scenarios/handlers/py.py +0 -294
  193. edsl/scenarios/handlers/sql.py +0 -313
  194. edsl/scenarios/handlers/sqlite.py +0 -149
  195. edsl/scenarios/handlers/txt.py +0 -33
  196. edsl/scenarios/scenario_selector.py +0 -156
  197. edsl/surveys/ConstructDAG.py +0 -92
  198. edsl/surveys/EditSurvey.py +0 -221
  199. edsl/surveys/InstructionHandler.py +0 -100
  200. edsl/surveys/MemoryManagement.py +0 -72
  201. edsl/surveys/RuleManager.py +0 -172
  202. edsl/surveys/Simulator.py +0 -75
  203. edsl/surveys/SurveyToApp.py +0 -141
  204. edsl/utilities/PrettyList.py +0 -56
  205. edsl/utilities/is_notebook.py +0 -18
  206. edsl/utilities/is_valid_variable_name.py +0 -11
  207. edsl/utilities/remove_edsl_version.py +0 -24
  208. edsl-0.1.39.dist-info/RECORD +0 -358
  209. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  210. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  211. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  212. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
edsl/jobs/async_interview_runner.py
@@ -1,138 +0,0 @@
- from collections.abc import AsyncGenerator
- from typing import List, TypeVar, Generator, Tuple, TYPE_CHECKING
- from dataclasses import dataclass
- import asyncio
- from contextlib import asynccontextmanager
- from edsl.data_transfer_models import EDSLResultObjectInput
-
- from edsl.results.Result import Result
- from edsl.jobs.interviews.Interview import Interview
-
- if TYPE_CHECKING:
-     from edsl.jobs.Jobs import Jobs
-
-
- @dataclass
- class InterviewResult:
-     result: Result
-     interview: Interview
-     order: int
-
-
- from edsl.jobs.data_structures import RunConfig
-
-
- class AsyncInterviewRunner:
-     MAX_CONCURRENT = 5
-
-     def __init__(self, jobs: "Jobs", run_config: RunConfig):
-         self.jobs = jobs
-         self.run_config = run_config
-         self._initialized = asyncio.Event()
-
-     def _expand_interviews(self) -> Generator["Interview", None, None]:
-         """Populates self.total_interviews with n copies of each interview.
-
-         It also has to set the cache for each interview.
-
-         :param n: how many times to run each interview.
-         """
-         for interview in self.jobs.generate_interviews():
-             for iteration in range(self.run_config.parameters.n):
-                 if iteration > 0:
-                     yield interview.duplicate(
-                         iteration=iteration, cache=self.run_config.environment.cache
-                     )
-                 else:
-                     interview.cache = self.run_config.environment.cache
-                     yield interview
-
-     async def _conduct_interview(
-         self, interview: "Interview"
-     ) -> Tuple["Result", "Interview"]:
-         """Conducts an interview and returns the result object, along with the associated interview.
-
-         We return the interview because it is not populated with exceptions, if any.
-
-         :param interview: the interview to conduct
-         :return: the result of the interview
-
-         'extracted_answers' is a dictionary of the answers to the questions in the interview.
-         This is not the same as the generated_tokens---it can include substantial cleaning and processing / validation.
-         """
-         # the model buckets are used to track usage rates
-         # model_buckets = self.bucket_collection[interview.model]
-         # model_buckets = self.run_config.environment.bucket_collection[interview.model]
-
-         # get the results of the interview e.g., {'how_are_you':"Good" 'how_are_you_generated_tokens': "Good"}
-         extracted_answers: dict[str, str]
-         model_response_objects: List[EDSLResultObjectInput]
-
-         extracted_answers, model_response_objects = (
-             await interview.async_conduct_interview(self.run_config)
-         )
-         result = Result.from_interview(
-             interview=interview,
-             extracted_answers=extracted_answers,
-             model_response_objects=model_response_objects,
-         )
-         return result, interview
-
-     async def run(
-         self,
-     ) -> AsyncGenerator[tuple[Result, Interview], None]:
-         """Creates and processes tasks asynchronously, yielding results as they complete.
-
-         Uses TaskGroup for structured concurrency and automated cleanup.
-         Results are yielded as they become available while maintaining controlled concurrency.
-         """
-         interviews = list(self._expand_interviews())
-         self._initialized.set()
-
-         async def _process_single_interview(
-             interview: Interview, idx: int
-         ) -> InterviewResult:
-             try:
-                 result, interview = await self._conduct_interview(interview)
-                 self.run_config.environment.jobs_runner_status.add_completed_interview(
-                     result
-                 )
-                 result.order = idx
-                 return InterviewResult(result, interview, idx)
-             except Exception as e:
-                 # breakpoint()
-                 if self.run_config.parameters.stop_on_exception:
-                     raise
-                 # logger.error(f"Task failed with error: {e}")
-                 return None
-
-         # Process interviews in chunks
-         for i in range(0, len(interviews), self.MAX_CONCURRENT):
-             chunk = interviews[i : i + self.MAX_CONCURRENT]
-             tasks = [
-                 asyncio.create_task(_process_single_interview(interview, idx))
-                 for idx, interview in enumerate(chunk, start=i)
-             ]
-
-             try:
-                 # Wait for all tasks in the chunk to complete
-                 results = await asyncio.gather(
-                     *tasks,
-                     return_exceptions=not self.run_config.parameters.stop_on_exception
-                 )
-
-                 # Process successful results
-                 for result in (r for r in results if r is not None):
-                     yield result.result, result.interview
-
-             except Exception as e:
-                 if self.run_config.parameters.stop_on_exception:
-                     raise
-                 # logger.error(f"Chunk processing failed with error: {e}")
-                 continue
-
-             finally:
-                 # Clean up any remaining tasks
-                 for task in tasks:
-                     if not task.done():
-                         task.cancel()
edsl/jobs/buckets/TokenBucketAPI.py
@@ -1,211 +0,0 @@
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- from typing import Union, Dict
- from typing import Union, List, Any, Optional
- from threading import RLock
- from edsl.jobs.buckets.TokenBucket import TokenBucket  # Original implementation
-
-
- def safe_float_for_json(value: float) -> Union[float, str]:
-     """Convert float('inf') to 'infinity' for JSON serialization.
-
-     Args:
-         value: The float value to convert
-
-     Returns:
-         Either the original float or the string 'infinity' if the value is infinite
-     """
-     if value == float("inf"):
-         return "infinity"
-     return value
-
-
- app = FastAPI()
-
- # In-memory storage for TokenBucket instances
- buckets: Dict[str, TokenBucket] = {}
-
-
- class TokenBucketCreate(BaseModel):
-     bucket_name: str
-     bucket_type: str
-     capacity: Union[int, float]
-     refill_rate: Union[int, float]
-
-
- @app.get("/buckets")
- async def list_buckets(
-     bucket_type: Optional[str] = None,
-     bucket_name: Optional[str] = None,
-     include_logs: bool = False,
- ):
-     """List all buckets and their current status.
-
-     Args:
-         bucket_type: Optional filter by bucket type
-         bucket_name: Optional filter by bucket name
-         include_logs: Whether to include the full logs in the response
-     """
-     result = {}
-
-     for bucket_id, bucket in buckets.items():
-         # Apply filters if specified
-         if bucket_type and bucket.bucket_type != bucket_type:
-             continue
-         if bucket_name and bucket.bucket_name != bucket_name:
-             continue
-
-         # Get basic bucket info
-         bucket_info = {
-             "bucket_name": bucket.bucket_name,
-             "bucket_type": bucket.bucket_type,
-             "tokens": bucket.tokens,
-             "capacity": bucket.capacity,
-             "refill_rate": bucket.refill_rate,
-             "turbo_mode": bucket.turbo_mode,
-             "num_requests": bucket.num_requests,
-             "num_released": bucket.num_released,
-             "tokens_returned": bucket.tokens_returned,
-         }
-         for k, v in bucket_info.items():
-             if isinstance(v, float):
-                 bucket_info[k] = safe_float_for_json(v)
-
-         # Only include logs if requested
-         if include_logs:
-             bucket_info["log"] = bucket.log
-
-         result[bucket_id] = bucket_info
-
-     return result
-
-
- @app.post("/bucket/{bucket_id}/add_tokens")
- async def add_tokens(bucket_id: str, amount: float):
-     """Add tokens to an existing bucket."""
-     if bucket_id not in buckets:
-         raise HTTPException(status_code=404, detail="Bucket not found")
-
-     if not isinstance(amount, (int, float)) or amount != amount:  # Check for NaN
-         raise HTTPException(status_code=400, detail="Invalid amount specified")
-
-     if amount == float("inf") or amount == float("-inf"):
-         raise HTTPException(status_code=400, detail="Amount cannot be infinite")
-
-     bucket = buckets[bucket_id]
-     bucket.add_tokens(amount)
-
-     # Ensure we return a JSON-serializable float
-     current_tokens = float(bucket.tokens)
-     if not -1e308 <= current_tokens <= 1e308:  # Check if within JSON float bounds
-         current_tokens = 0.0  # or some other reasonable default
-
-     return {"status": "success", "current_tokens": safe_float_for_json(current_tokens)}
-
-
- # @app.post("/bucket")
- # async def create_bucket(bucket: TokenBucketCreate):
- #     bucket_id = f"{bucket.bucket_name}_{bucket.bucket_type}"
- #     if bucket_id in buckets:
- #         raise HTTPException(status_code=400, detail="Bucket already exists")
-
- #     # Create an actual TokenBucket instance
- #     buckets[bucket_id] = TokenBucket(
- #         bucket_name=bucket.bucket_name,
- #         bucket_type=bucket.bucket_type,
- #         capacity=bucket.capacity,
- #         refill_rate=bucket.refill_rate,
- #     )
- #     return {"status": "created"}
-
-
- @app.post("/bucket")
- async def create_bucket(bucket: TokenBucketCreate):
-     if (
-         not isinstance(bucket.capacity, (int, float))
-         or bucket.capacity != bucket.capacity
-     ):  # Check for NaN
-         raise HTTPException(status_code=400, detail="Invalid capacity value")
-     if (
-         not isinstance(bucket.refill_rate, (int, float))
-         or bucket.refill_rate != bucket.refill_rate
-     ):  # Check for NaN
-         raise HTTPException(status_code=400, detail="Invalid refill rate value")
-     if bucket.capacity == float("inf") or bucket.refill_rate == float("inf"):
-         raise HTTPException(status_code=400, detail="Values cannot be infinite")
-     bucket_id = f"{bucket.bucket_name}_{bucket.bucket_type}"
-     if bucket_id in buckets:
-         # Instead of error, return success with "existing" status
-         return {
-             "status": "existing",
-             "bucket": {
-                 "capacity": safe_float_for_json(buckets[bucket_id].capacity),
-                 "refill_rate": safe_float_for_json(buckets[bucket_id].refill_rate),
-             },
-         }
-
-     # Create a new bucket
-     buckets[bucket_id] = TokenBucket(
-         bucket_name=bucket.bucket_name,
-         bucket_type=bucket.bucket_type,
-         capacity=bucket.capacity,
-         refill_rate=bucket.refill_rate,
-     )
-     return {"status": "created"}
-
-
- @app.post("/bucket/{bucket_id}/get_tokens")
- async def get_tokens(bucket_id: str, amount: float, cheat_bucket_capacity: bool = True):
-     if bucket_id not in buckets:
-         raise HTTPException(status_code=404, detail="Bucket not found")
-
-     bucket = buckets[bucket_id]
-     await bucket.get_tokens(amount, cheat_bucket_capacity)
-     return {"status": "success"}
-
-
- @app.post("/bucket/{bucket_id}/turbo_mode/{state}")
- async def set_turbo_mode(bucket_id: str, state: bool):
-     if bucket_id not in buckets:
-         raise HTTPException(status_code=404, detail="Bucket not found")
-
-     bucket = buckets[bucket_id]
-     if state:
-         bucket.turbo_mode_on()
-     else:
-         bucket.turbo_mode_off()
-     return {"status": "success"}
-
-
- @app.get("/bucket/{bucket_id}/status")
- async def get_bucket_status(bucket_id: str):
-     if bucket_id not in buckets:
-         raise HTTPException(status_code=404, detail="Bucket not found")
-
-     bucket = buckets[bucket_id]
-     status = {
-         "tokens": bucket.tokens,
-         "capacity": bucket.capacity,
-         "refill_rate": bucket.refill_rate,
-         "turbo_mode": bucket.turbo_mode,
-         "num_requests": bucket.num_requests,
-         "num_released": bucket.num_released,
-         "tokens_returned": bucket.tokens_returned,
-         "log": bucket.log,
-     }
-     for k, v in status.items():
-         if isinstance(v, float):
-             status[k] = safe_float_for_json(v)
-
-     for index, entry in enumerate(status["log"]):
-         ts, value = entry
-         status["log"][index] = (ts, safe_float_for_json(value))
-
-     # print(status)
-     return status
-
-
- if __name__ == "__main__":
-     import uvicorn
-
-     uvicorn.run(app, host="0.0.0.0", port=8001)
edsl/jobs/buckets/TokenBucketClient.py
@@ -1,191 +0,0 @@
- from typing import Union, Optional
- import asyncio
- import time
- import aiohttp
-
-
- class TokenBucketClient:
-     """REST API client version of TokenBucket that maintains the same interface
-     by delegating to a server running the original TokenBucket implementation."""
-
-     def __init__(
-         self,
-         *,
-         bucket_name: str,
-         bucket_type: str,
-         capacity: Union[int, float],
-         refill_rate: Union[int, float],
-         api_base_url: str = "http://localhost:8000",
-     ):
-         self.bucket_name = bucket_name
-         self.bucket_type = bucket_type
-         self.capacity = capacity
-         self.refill_rate = refill_rate
-         self.api_base_url = api_base_url
-         self.bucket_id = f"{bucket_name}_{bucket_type}"
-
-         # Initialize the bucket on the server
-         asyncio.run(self._create_bucket())
-
-         # Cache some values locally
-         self.creation_time = time.monotonic()
-         self.turbo_mode = False
-
-     async def _create_bucket(self):
-         async with aiohttp.ClientSession() as session:
-             payload = {
-                 "bucket_name": self.bucket_name,
-                 "bucket_type": self.bucket_type,
-                 "capacity": self.capacity,
-                 "refill_rate": self.refill_rate,
-             }
-             async with session.post(
-                 f"{self.api_base_url}/bucket",
-                 json=payload,
-             ) as response:
-                 if response.status != 200:
-                     raise ValueError(f"Unexpected error: {await response.text()}")
-
-                 result = await response.json()
-                 if result["status"] == "existing":
-                     # Update our local values to match the existing bucket
-                     self.capacity = float(result["bucket"]["capacity"])
-                     self.refill_rate = float(result["bucket"]["refill_rate"])
-
-     def turbo_mode_on(self):
-         """Set the refill rate to infinity."""
-         asyncio.run(self._set_turbo_mode(True))
-         self.turbo_mode = True
-
-     def turbo_mode_off(self):
-         """Restore the refill rate to its original value."""
-         asyncio.run(self._set_turbo_mode(False))
-         self.turbo_mode = False
-
-     async def add_tokens(self, amount: Union[int, float]):
-         """Add tokens to the bucket."""
-         async with aiohttp.ClientSession() as session:
-             async with session.post(
-                 f"{self.api_base_url}/bucket/{self.bucket_id}/add_tokens",
-                 params={"amount": amount},
-             ) as response:
-                 if response.status != 200:
-                     raise ValueError(f"Failed to add tokens: {await response.text()}")
-
-     async def _set_turbo_mode(self, state: bool):
-         async with aiohttp.ClientSession() as session:
-             async with session.post(
-                 f"{self.api_base_url}/bucket/{self.bucket_id}/turbo_mode/{str(state).lower()}"
-             ) as response:
-                 if response.status != 200:
-                     raise ValueError(
-                         f"Failed to set turbo mode: {await response.text()}"
-                     )
-
-     async def get_tokens(
-         self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
-     ) -> None:
-         async with aiohttp.ClientSession() as session:
-             async with session.post(
-                 f"{self.api_base_url}/bucket/{self.bucket_id}/get_tokens",
-                 params={
-                     "amount": amount,
-                     "cheat_bucket_capacity": int(cheat_bucket_capacity),
-                 },
-             ) as response:
-                 if response.status != 200:
-                     raise ValueError(f"Failed to get tokens: {await response.text()}")
-
-     def get_throughput(self, time_window: Optional[float] = None) -> float:
-         status = asyncio.run(self._get_status())
-         now = time.monotonic()
-
-         if time_window is None:
-             start_time = self.creation_time
-         else:
-             start_time = now - time_window
-
-         if start_time < self.creation_time:
-             start_time = self.creation_time
-
-         elapsed_time = now - start_time
-
-         if elapsed_time == 0:
-             return status["num_released"] / 0.001
-
-         return (status["num_released"] / elapsed_time) * 60
-
-     async def _get_status(self) -> dict:
-         async with aiohttp.ClientSession() as session:
-             async with session.get(
-                 f"{self.api_base_url}/bucket/{self.bucket_id}/status"
-             ) as response:
-                 if response.status != 200:
-                     raise ValueError(
-                         f"Failed to get bucket status: {await response.text()}"
-                     )
-                 return await response.json()
-
-     def __add__(self, other) -> "TokenBucketClient":
-         """Combine two token buckets."""
-         return TokenBucketClient(
-             bucket_name=self.bucket_name,
-             bucket_type=self.bucket_type,
-             capacity=min(self.capacity, other.capacity),
-             refill_rate=min(self.refill_rate, other.refill_rate),
-             api_base_url=self.api_base_url,
-         )
-
-     @property
-     def tokens(self) -> float:
-         """Get the number of tokens remaining in the bucket."""
-         status = asyncio.run(self._get_status())
-         return float(status["tokens"])
-
-     def wait_time(self, requested_tokens: Union[float, int]) -> float:
-         """Calculate the time to wait for the requested number of tokens."""
-         # self.refill()  # Update the current token count
-         if self.tokens >= float(requested_tokens):
-             return 0.0
-         try:
-             return (requested_tokens - self.tokens) / self.refill_rate
-         except Exception as e:
-             raise ValueError(f"Error calculating wait time: {e}")
-
-     # def wait_time(self, num_tokens: Union[int, float]) -> float:
-     #     return 0  # TODO - Need to implement this on the server side
-
-     def visualize(self):
-         """Visualize the token bucket over time."""
-         status = asyncio.run(self._get_status())
-         times, tokens = zip(*status["log"])
-         start_time = times[0]
-         times = [t - start_time for t in times]
-
-         from matplotlib import pyplot as plt
-
-         plt.figure(figsize=(10, 6))
-         plt.plot(times, tokens, label="Tokens Available")
-         plt.xlabel("Time (seconds)", fontsize=12)
-         plt.ylabel("Number of Tokens", fontsize=12)
-         details = f"{self.bucket_name} ({self.bucket_type}) Bucket Usage Over Time\nCapacity: {self.capacity:.1f}, Refill Rate: {self.refill_rate:.1f}/second"
-         plt.title(details, fontsize=14)
-         plt.legend()
-         plt.grid(True)
-         plt.tight_layout()
-         plt.show()
-
-
- if __name__ == "__main__":
-     import doctest
-
-     doctest.testmod()
-     # bucket = TokenBucketClient(
-     #     bucket_name="test", bucket_type="test", capacity=100, refill_rate=10
-     # )
-     # asyncio.run(bucket.get_tokens(50))
-     # time.sleep(1)  # Wait for 1 second
-     # asyncio.run(bucket.get_tokens(30))
-     # throughput = bucket.get_throughput(1)
-     # print(throughput)
-     # bucket.visualize()
edsl/jobs/check_survey_scenario_compatibility.py
@@ -1,85 +0,0 @@
- import warnings
- from typing import TYPE_CHECKING
-
- if TYPE_CHECKING:
-     from edsl.surveys.Survey import Survey
-     from edsl.scenarios.ScenarioList import ScenarioList
-
-
- class CheckSurveyScenarioCompatibility:
-
-     def __init__(self, survey: "Survey", scenarios: "ScenarioList"):
-         self.survey = survey
-         self.scenarios = scenarios
-
-     def check(self, strict: bool = False, warn: bool = False) -> None:
-         """Check if the parameters in the survey and scenarios are consistent.
-
-         >>> from edsl.jobs.Jobs import Jobs
-         >>> from edsl.questions.QuestionFreeText import QuestionFreeText
-         >>> from edsl.surveys.Survey import Survey
-         >>> from edsl.scenarios.Scenario import Scenario
-         >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
-         >>> j = Jobs(survey = Survey(questions=[q]))
-         >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
-         >>> with warnings.catch_warnings(record=True) as w:
-         ...     cs.check(warn = True)
-         ...     assert len(w) == 1
-         ...     assert issubclass(w[-1].category, UserWarning)
-         ...     assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
-
-         >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
-         >>> s = Scenario({'plop': "A", 'poo': "B"})
-         >>> j = Jobs(survey = Survey(questions=[q])).by(s)
-         >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
-         >>> cs.check(strict = True)
-         Traceback (most recent call last):
-         ...
-         ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
-
-         >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
-         >>> s = Scenario({'ugly_question': "B"})
-         >>> from edsl.scenarios.ScenarioList import ScenarioList
-         >>> cs = CheckSurveyScenarioCompatibility(Survey(questions=[q]), ScenarioList([s]))
-         >>> cs.check()
-         Traceback (most recent call last):
-         ...
-         ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
-         """
-         survey_parameters: set = self.survey.parameters
-         scenario_parameters: set = self.scenarios.parameters
-
-         msg0, msg1, msg2 = None, None, None
-
-         # look for key issues
-         if intersection := set(self.scenarios.parameters) & set(
-             self.survey.question_names
-         ):
-             msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
-
-             raise ValueError(msg0)
-
-         if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
-             msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
-         if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
-             msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"
-
-         if msg1 or msg2:
-             message = "\n".join(filter(None, [msg1, msg2]))
-             if strict:
-                 raise ValueError(message)
-             else:
-                 if warn:
-                     warnings.warn(message)
-
-         if self.scenarios.has_jinja_braces:
-             warnings.warn(
-                 "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
-             )
-             self.scenarios = self.scenarios._convert_jinja_braces()
-
-
- if __name__ == "__main__":
-     import doctest
-
-     doctest.testmod()
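
Note on the two token-bucket modules removed above: TokenBucketAPI.py serves the original TokenBucket over FastAPI and binds port 8001, while TokenBucketClient.py is its HTTP client and defaults to port 8000, so the base URL must be passed explicitly. The following is a minimal usage sketch against the 0.1.39 wheel, assuming the server has been started separately (for example, uvicorn edsl.jobs.buckets.TokenBucketAPI:app --port 8001); the bucket name and type shown are illustrative, not taken from this diff.

import asyncio
from edsl.jobs.buckets.TokenBucketClient import TokenBucketClient

# Hypothetical client usage; the constructor registers the bucket with the server
# via POST /bucket, or adopts the capacity/refill values of an existing bucket
# with the same name and type.
bucket = TokenBucketClient(
    bucket_name="example-model",  # illustrative name
    bucket_type="tpm",            # illustrative type
    capacity=100_000,
    refill_rate=1_000,
    api_base_url="http://localhost:8001",  # port used by TokenBucketAPI.py
)

# Request tokens from the server-side bucket before issuing a model call.
asyncio.run(bucket.get_tokens(5_000))

print(bucket.tokens)            # remaining tokens, read from /bucket/{id}/status
print(bucket.get_throughput())  # tokens released per minute since bucket creation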