edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. edsl/Base.py +332 -385
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +49 -57
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +867 -1079
  7. edsl/agents/AgentList.py +413 -551
  8. edsl/agents/Invigilator.py +233 -285
  9. edsl/agents/InvigilatorBase.py +270 -254
  10. edsl/agents/PromptConstructor.py +354 -252
  11. edsl/agents/__init__.py +3 -2
  12. edsl/agents/descriptors.py +99 -99
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +279 -279
  26. edsl/config.py +157 -177
  27. edsl/conversation/Conversation.py +290 -290
  28. edsl/conversation/car_buying.py +58 -59
  29. edsl/conversation/chips.py +95 -95
  30. edsl/conversation/mug_negotiation.py +81 -81
  31. edsl/conversation/next_speaker_utilities.py +93 -93
  32. edsl/coop/PriceFetcher.py +54 -54
  33. edsl/coop/__init__.py +2 -2
  34. edsl/coop/coop.py +1028 -1090
  35. edsl/coop/utils.py +131 -131
  36. edsl/data/Cache.py +555 -562
  37. edsl/data/CacheEntry.py +233 -230
  38. edsl/data/CacheHandler.py +149 -170
  39. edsl/data/RemoteCacheSync.py +78 -78
  40. edsl/data/SQLiteDict.py +292 -292
  41. edsl/data/__init__.py +4 -5
  42. edsl/data/orm.py +10 -10
  43. edsl/data_transfer_models.py +73 -74
  44. edsl/enums.py +175 -195
  45. edsl/exceptions/BaseException.py +21 -21
  46. edsl/exceptions/__init__.py +54 -54
  47. edsl/exceptions/agents.py +42 -54
  48. edsl/exceptions/cache.py +5 -5
  49. edsl/exceptions/configuration.py +16 -16
  50. edsl/exceptions/coop.py +10 -10
  51. edsl/exceptions/data.py +14 -14
  52. edsl/exceptions/general.py +34 -34
  53. edsl/exceptions/jobs.py +33 -33
  54. edsl/exceptions/language_models.py +63 -63
  55. edsl/exceptions/prompts.py +15 -15
  56. edsl/exceptions/questions.py +91 -109
  57. edsl/exceptions/results.py +29 -29
  58. edsl/exceptions/scenarios.py +22 -29
  59. edsl/exceptions/surveys.py +37 -37
  60. edsl/inference_services/AnthropicService.py +87 -84
  61. edsl/inference_services/AwsBedrock.py +120 -118
  62. edsl/inference_services/AzureAI.py +217 -215
  63. edsl/inference_services/DeepInfraService.py +18 -18
  64. edsl/inference_services/GoogleService.py +148 -139
  65. edsl/inference_services/GroqService.py +20 -20
  66. edsl/inference_services/InferenceServiceABC.py +147 -80
  67. edsl/inference_services/InferenceServicesCollection.py +97 -122
  68. edsl/inference_services/MistralAIService.py +123 -120
  69. edsl/inference_services/OllamaService.py +18 -18
  70. edsl/inference_services/OpenAIService.py +224 -221
  71. edsl/inference_services/PerplexityService.py +163 -160
  72. edsl/inference_services/TestService.py +89 -92
  73. edsl/inference_services/TogetherAIService.py +170 -170
  74. edsl/inference_services/models_available_cache.py +118 -118
  75. edsl/inference_services/rate_limits_cache.py +25 -25
  76. edsl/inference_services/registry.py +41 -41
  77. edsl/inference_services/write_available.py +10 -10
  78. edsl/jobs/Answers.py +56 -43
  79. edsl/jobs/Jobs.py +898 -757
  80. edsl/jobs/JobsChecks.py +147 -172
  81. edsl/jobs/JobsPrompts.py +268 -270
  82. edsl/jobs/JobsRemoteInferenceHandler.py +239 -287
  83. edsl/jobs/__init__.py +1 -1
  84. edsl/jobs/buckets/BucketCollection.py +63 -104
  85. edsl/jobs/buckets/ModelBuckets.py +65 -65
  86. edsl/jobs/buckets/TokenBucket.py +251 -283
  87. edsl/jobs/interviews/Interview.py +661 -358
  88. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  89. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  90. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  91. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  92. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  93. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  94. edsl/jobs/interviews/ReportErrors.py +66 -66
  95. edsl/jobs/interviews/interview_status_enum.py +9 -9
  96. edsl/jobs/runners/JobsRunnerAsyncio.py +466 -421
  97. edsl/jobs/runners/JobsRunnerStatus.py +330 -330
  98. edsl/jobs/tasks/QuestionTaskCreator.py +242 -244
  99. edsl/jobs/tasks/TaskCreators.py +64 -64
  100. edsl/jobs/tasks/TaskHistory.py +450 -449
  101. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  102. edsl/jobs/tasks/task_status_enum.py +163 -161
  103. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  104. edsl/jobs/tokens/TokenUsage.py +34 -34
  105. edsl/language_models/KeyLookup.py +30 -0
  106. edsl/language_models/LanguageModel.py +668 -571
  107. edsl/language_models/ModelList.py +155 -153
  108. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  109. edsl/language_models/__init__.py +3 -2
  110. edsl/language_models/fake_openai_call.py +15 -15
  111. edsl/language_models/fake_openai_service.py +61 -61
  112. edsl/language_models/registry.py +190 -180
  113. edsl/language_models/repair.py +156 -156
  114. edsl/language_models/unused/ReplicateBase.py +83 -0
  115. edsl/language_models/utilities.py +64 -65
  116. edsl/notebooks/Notebook.py +258 -263
  117. edsl/notebooks/__init__.py +1 -1
  118. edsl/prompts/Prompt.py +362 -352
  119. edsl/prompts/__init__.py +2 -2
  120. edsl/questions/AnswerValidatorMixin.py +289 -334
  121. edsl/questions/QuestionBase.py +664 -509
  122. edsl/questions/QuestionBaseGenMixin.py +161 -165
  123. edsl/questions/QuestionBasePromptsMixin.py +217 -221
  124. edsl/questions/QuestionBudget.py +227 -227
  125. edsl/questions/QuestionCheckBox.py +359 -359
  126. edsl/questions/QuestionExtract.py +182 -182
  127. edsl/questions/QuestionFreeText.py +114 -113
  128. edsl/questions/QuestionFunctional.py +166 -166
  129. edsl/questions/QuestionList.py +231 -229
  130. edsl/questions/QuestionMultipleChoice.py +286 -330
  131. edsl/questions/QuestionNumerical.py +153 -151
  132. edsl/questions/QuestionRank.py +324 -314
  133. edsl/questions/Quick.py +41 -41
  134. edsl/questions/RegisterQuestionsMeta.py +71 -71
  135. edsl/questions/ResponseValidatorABC.py +174 -200
  136. edsl/questions/SimpleAskMixin.py +73 -74
  137. edsl/questions/__init__.py +26 -27
  138. edsl/questions/compose_questions.py +98 -98
  139. edsl/questions/decorators.py +21 -21
  140. edsl/questions/derived/QuestionLikertFive.py +76 -76
  141. edsl/questions/derived/QuestionLinearScale.py +87 -90
  142. edsl/questions/derived/QuestionTopK.py +93 -93
  143. edsl/questions/derived/QuestionYesNo.py +82 -82
  144. edsl/questions/descriptors.py +413 -427
  145. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  146. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  147. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  148. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  149. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  150. edsl/questions/prompt_templates/question_list.jinja +17 -17
  151. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  152. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  153. edsl/questions/question_registry.py +177 -177
  154. edsl/questions/settings.py +12 -12
  155. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  156. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  157. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  158. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  159. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  160. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  161. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  162. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  163. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  164. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  165. edsl/questions/templates/list/question_presentation.jinja +5 -5
  166. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  167. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  168. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  169. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  170. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  171. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  172. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  173. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  174. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  175. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  176. edsl/results/CSSParameterizer.py +108 -108
  177. edsl/results/Dataset.py +424 -587
  178. edsl/results/DatasetExportMixin.py +731 -653
  179. edsl/results/DatasetTree.py +275 -295
  180. edsl/results/Result.py +465 -451
  181. edsl/results/Results.py +1165 -1172
  182. edsl/results/ResultsDBMixin.py +238 -0
  183. edsl/results/ResultsExportMixin.py +43 -45
  184. edsl/results/ResultsFetchMixin.py +33 -33
  185. edsl/results/ResultsGGMixin.py +121 -121
  186. edsl/results/ResultsToolsMixin.py +98 -98
  187. edsl/results/Selector.py +135 -145
  188. edsl/results/TableDisplay.py +198 -125
  189. edsl/results/__init__.py +2 -2
  190. edsl/results/table_display.css +77 -77
  191. edsl/results/tree_explore.py +115 -115
  192. edsl/scenarios/FileStore.py +632 -511
  193. edsl/scenarios/Scenario.py +601 -498
  194. edsl/scenarios/ScenarioHtmlMixin.py +64 -65
  195. edsl/scenarios/ScenarioJoin.py +127 -131
  196. edsl/scenarios/ScenarioList.py +1287 -1430
  197. edsl/scenarios/ScenarioListExportMixin.py +52 -45
  198. edsl/scenarios/ScenarioListPdfMixin.py +261 -239
  199. edsl/scenarios/__init__.py +4 -3
  200. edsl/shared.py +1 -1
  201. edsl/study/ObjectEntry.py +173 -173
  202. edsl/study/ProofOfWork.py +113 -113
  203. edsl/study/SnapShot.py +80 -80
  204. edsl/study/Study.py +528 -521
  205. edsl/study/__init__.py +4 -4
  206. edsl/surveys/DAG.py +148 -148
  207. edsl/surveys/Memory.py +31 -31
  208. edsl/surveys/MemoryPlan.py +244 -244
  209. edsl/surveys/Rule.py +326 -327
  210. edsl/surveys/RuleCollection.py +387 -385
  211. edsl/surveys/Survey.py +1801 -1229
  212. edsl/surveys/SurveyCSS.py +261 -273
  213. edsl/surveys/SurveyExportMixin.py +259 -259
  214. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +179 -181
  215. edsl/surveys/SurveyQualtricsImport.py +284 -284
  216. edsl/surveys/__init__.py +3 -5
  217. edsl/surveys/base.py +53 -53
  218. edsl/surveys/descriptors.py +56 -60
  219. edsl/surveys/instructions/ChangeInstruction.py +49 -48
  220. edsl/surveys/instructions/Instruction.py +65 -56
  221. edsl/surveys/instructions/InstructionCollection.py +77 -82
  222. edsl/templates/error_reporting/base.html +23 -23
  223. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  224. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  225. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  226. edsl/templates/error_reporting/interview_details.html +115 -115
  227. edsl/templates/error_reporting/interviews.html +19 -19
  228. edsl/templates/error_reporting/overview.html +4 -4
  229. edsl/templates/error_reporting/performance_plot.html +1 -1
  230. edsl/templates/error_reporting/report.css +73 -73
  231. edsl/templates/error_reporting/report.html +117 -117
  232. edsl/templates/error_reporting/report.js +25 -25
  233. edsl/tools/__init__.py +1 -1
  234. edsl/tools/clusters.py +192 -192
  235. edsl/tools/embeddings.py +27 -27
  236. edsl/tools/embeddings_plotting.py +118 -118
  237. edsl/tools/plotting.py +112 -112
  238. edsl/tools/summarize.py +18 -18
  239. edsl/utilities/SystemInfo.py +28 -28
  240. edsl/utilities/__init__.py +22 -22
  241. edsl/utilities/ast_utilities.py +25 -25
  242. edsl/utilities/data/Registry.py +6 -6
  243. edsl/utilities/data/__init__.py +1 -1
  244. edsl/utilities/data/scooter_results.json +1 -1
  245. edsl/utilities/decorators.py +77 -77
  246. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  247. edsl/utilities/interface.py +627 -627
  248. edsl/utilities/naming_utilities.py +263 -263
  249. edsl/utilities/repair_functions.py +28 -28
  250. edsl/utilities/restricted_python.py +70 -70
  251. edsl/utilities/utilities.py +424 -436
  252. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev3.dist-info}/LICENSE +21 -21
  253. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev3.dist-info}/METADATA +10 -12
  254. edsl-0.1.39.dev3.dist-info/RECORD +277 -0
  255. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  256. edsl/agents/QuestionOptionProcessor.py +0 -172
  257. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  258. edsl/coop/CoopFunctionsMixin.py +0 -15
  259. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  260. edsl/exceptions/inference_services.py +0 -5
  261. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  262. edsl/inference_services/AvailableModelFetcher.py +0 -209
  263. edsl/inference_services/ServiceAvailability.py +0 -135
  264. edsl/inference_services/data_structures.py +0 -62
  265. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -188
  266. edsl/jobs/FetchInvigilator.py +0 -40
  267. edsl/jobs/InterviewTaskManager.py +0 -98
  268. edsl/jobs/InterviewsConstructor.py +0 -48
  269. edsl/jobs/JobsComponentConstructor.py +0 -189
  270. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  271. edsl/jobs/RequestTokenEstimator.py +0 -30
  272. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  273. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  274. edsl/jobs/decorators.py +0 -35
  275. edsl/jobs/jobs_status_enums.py +0 -9
  276. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  277. edsl/language_models/ComputeCost.py +0 -63
  278. edsl/language_models/PriceManager.py +0 -127
  279. edsl/language_models/RawResponseHandler.py +0 -106
  280. edsl/language_models/ServiceDataSources.py +0 -0
  281. edsl/language_models/key_management/KeyLookup.py +0 -63
  282. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  283. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  284. edsl/language_models/key_management/__init__.py +0 -0
  285. edsl/language_models/key_management/models.py +0 -131
  286. edsl/notebooks/NotebookToLaTeX.py +0 -142
  287. edsl/questions/ExceptionExplainer.py +0 -77
  288. edsl/questions/HTMLQuestion.py +0 -103
  289. edsl/questions/LoopProcessor.py +0 -149
  290. edsl/questions/QuestionMatrix.py +0 -265
  291. edsl/questions/ResponseValidatorFactory.py +0 -28
  292. edsl/questions/templates/matrix/__init__.py +0 -1
  293. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  294. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  295. edsl/results/MarkdownToDocx.py +0 -122
  296. edsl/results/MarkdownToPDF.py +0 -111
  297. edsl/results/TextEditor.py +0 -50
  298. edsl/results/smart_objects.py +0 -96
  299. edsl/results/table_data_class.py +0 -12
  300. edsl/results/table_renderers.py +0 -118
  301. edsl/scenarios/ConstructDownloadLink.py +0 -109
  302. edsl/scenarios/DirectoryScanner.py +0 -96
  303. edsl/scenarios/DocumentChunker.py +0 -102
  304. edsl/scenarios/DocxScenario.py +0 -16
  305. edsl/scenarios/PdfExtractor.py +0 -40
  306. edsl/scenarios/ScenarioSelector.py +0 -156
  307. edsl/scenarios/file_methods.py +0 -85
  308. edsl/scenarios/handlers/__init__.py +0 -13
  309. edsl/scenarios/handlers/csv.py +0 -38
  310. edsl/scenarios/handlers/docx.py +0 -76
  311. edsl/scenarios/handlers/html.py +0 -37
  312. edsl/scenarios/handlers/json.py +0 -111
  313. edsl/scenarios/handlers/latex.py +0 -5
  314. edsl/scenarios/handlers/md.py +0 -51
  315. edsl/scenarios/handlers/pdf.py +0 -68
  316. edsl/scenarios/handlers/png.py +0 -39
  317. edsl/scenarios/handlers/pptx.py +0 -105
  318. edsl/scenarios/handlers/py.py +0 -294
  319. edsl/scenarios/handlers/sql.py +0 -313
  320. edsl/scenarios/handlers/sqlite.py +0 -149
  321. edsl/scenarios/handlers/txt.py +0 -33
  322. edsl/surveys/ConstructDAG.py +0 -92
  323. edsl/surveys/EditSurvey.py +0 -221
  324. edsl/surveys/InstructionHandler.py +0 -100
  325. edsl/surveys/MemoryManagement.py +0 -72
  326. edsl/surveys/RuleManager.py +0 -172
  327. edsl/surveys/Simulator.py +0 -75
  328. edsl/surveys/SurveyToApp.py +0 -141
  329. edsl/utilities/PrettyList.py +0 -56
  330. edsl/utilities/is_notebook.py +0 -18
  331. edsl/utilities/is_valid_variable_name.py +0 -11
  332. edsl/utilities/remove_edsl_version.py +0 -24
  333. edsl-0.1.39.dev2.dist-info/RECORD +0 -352
  334. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev3.dist-info}/WHEEL +0 -0
edsl/data/CacheHandler.py CHANGED
@@ -1,170 +1,149 @@
1
- from __future__ import annotations
2
- import ast
3
- import json
4
- import os
5
- import shutil
6
- from typing import TYPE_CHECKING
7
-
8
- if TYPE_CHECKING:
9
- from edsl.data.Cache import Cache
10
- from edsl.data.CacheEntry import CacheEntry
11
-
12
-
13
def set_session_cache(cache: "Cache") -> None:
    """
    Install ``cache`` as the session cache by storing it on CONFIG.
    """
    from edsl.config import CONFIG as _config

    setattr(_config, "EDSL_SESSION_CACHE", cache)
20
-
21
-
22
def unset_session_cache() -> None:
    """
    Remove the session cache from CONFIG; do nothing when none is set.
    """
    from edsl.config import CONFIG as _config

    if not hasattr(_config, "EDSL_SESSION_CACHE"):
        return
    delattr(_config, "EDSL_SESSION_CACHE")
30
-
31
-
32
class CacheHandler:
    """
    This CacheHandler figures out what caches are available and does migrations, as needed.

    Constructing a handler does three things:
    - ensures the directory of the configured database exists;
    - builds a ``Cache``: empty and in-memory in test mode, otherwise the session
      cache stored on CONFIG, otherwise one backed by ``SQLiteDict``;
    - merges in entries migrated from an old-style ``edsl_cache.db`` file.

    edsl imports are deferred to call time throughout this module (presumably to
    avoid import cycles with ``edsl.data.Cache``).
    """

    @property
    def CACHE_PATH(self):
        """Database location, read from ``CONFIG.get("EDSL_DATABASE_PATH")`` on each access."""
        from edsl.config import CONFIG

        return CONFIG.get("EDSL_DATABASE_PATH")

    def __init__(self, test: bool = False):
        """
        Build the cache and migrate any old-style cache into it.

        :param test: when True, use an empty in-memory cache (``Cache(data={})``).
        """
        self.test = test
        self.create_cache_directory()
        self.cache = self.gen_cache()
        old_data = self.from_old_sqlite_cache()
        self.cache.add_from_dict(old_data)

    def create_cache_directory(self, notify=False) -> None:
        """
        Create the cache directory if one is required and it does not exist.

        :param notify: when True, print the directory that was created.
        """
        path = self.CACHE_PATH.replace("sqlite:///", "")
        dir_path = os.path.dirname(path)
        if dir_path and not os.path.exists(dir_path):
            # exist_ok tolerates another process creating it after the check above.
            os.makedirs(dir_path, exist_ok=True)
            if notify:
                print(f"Created cache directory: {dir_path}")

    def gen_cache(self) -> "Cache":
        """
        Generate a Cache object.

        Priority: an empty in-memory cache in test mode, then the session cache
        stored on CONFIG, then a cache persisted through ``SQLiteDict`` at
        ``CACHE_PATH``.
        """
        from edsl.data.Cache import Cache

        if self.test:
            return Cache(data={})

        from edsl.config import CONFIG

        if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
            return CONFIG.EDSL_SESSION_CACHE

        from edsl.data.SQLiteDict import SQLiteDict

        return Cache(data=SQLiteDict(self.CACHE_PATH))

    def from_old_sqlite_cache(
        self, path: str = "edsl_cache.db"
    ) -> "dict[str, CacheEntry]":
        """
        Convert an old-style cache to the new format.

        Reads every row of the ``responses`` table in ``path`` and converts each
        row to a ``CacheEntry``. The file is then copied to ``<path>.bak`` and
        the original is deleted.
        - NB: Not worth converting to sqlalchemy - this is a one-time operation.

        :param path: old cache file, checked relative to the current directory.
        :returns: mapping of entry key to ``CacheEntry``; empty when there is no
            old cache or it could not be read.
        """
        import sqlite3

        old_data = {}
        if not os.path.exists(os.path.join(os.getcwd(), path)):
            return old_data
        table_name = "responses"
        try:
            conn = sqlite3.connect(path)
            try:
                cur = conn.cursor()
                cur.execute(f"PRAGMA table_info({table_name})")
                schema = {column[1]: column[2] for column in cur.fetchall()}
                rows = cur.execute(f"SELECT * FROM {table_name}").fetchall()
            finally:
                # A sqlite3 connection used as a context manager only commits or
                # rolls back; close it explicitly, and before the file is removed
                # below (an open file cannot be deleted on Windows).
                conn.close()
        except sqlite3.OperationalError:
            print("Found an old Cache but could not convert it to new format.")
            return old_data

        for row in rows:
            entry = self._parse_old_cache_entry(row, schema)
            old_data[entry.key] = entry
        print(
            f"Found old cache at {path} with {len(old_data)} entries.\n"
            f"We will convert this to the new cache format.\n"
            f"The old cache is backed up to {path}.bak"
        )
        shutil.copy(path, f"{path}.bak")
        os.remove(path)
        return old_data

    def _parse_old_cache_entry(self, row: tuple, schema) -> "CacheEntry":
        """
        Parse one row of the old ``responses`` table into a ``CacheEntry``.

        Row values are matched to ``schema`` keys by position; this relies on
        ``SELECT *`` and ``PRAGMA table_info`` listing columns in the same order.
        The ``id`` column is dropped and ``prompt`` is renamed to ``user_prompt``.
        ``parameters`` is parsed from its stored Python-literal text.
        """
        entry_dict = {k: row[i] for i, k in enumerate(schema.keys())}
        _ = entry_dict.pop("id")
        entry_dict["user_prompt"] = entry_dict.pop("prompt")
        entry_dict["parameters"] = ast.literal_eval(entry_dict["parameters"])
        from edsl.data.CacheEntry import CacheEntry

        return CacheEntry(**entry_dict)

    def get_cache(self) -> "Cache":
        """Return the cache built during construction."""
        return self.cache

    ###############
    # NOT IN USE
    ###############
    @staticmethod
    def from_sqlite(uri="new_edsl_cache.db") -> "dict[str, CacheEntry]":
        """
        Read in a new-style sqlite cache and return a dictionary of CacheEntry objects.

        A static method, so it works both as ``CacheHandler.from_sqlite(uri)``
        and when called on an instance.
        """
        import sqlite3

        conn = sqlite3.connect(uri)
        try:
            rows = conn.execute("SELECT key, value FROM data").fetchall()
        finally:
            conn.close()
        if not rows:
            return {}
        from edsl.data.CacheEntry import CacheEntry

        newdata = {}
        for _, value in rows:
            entry = CacheEntry.from_dict(json.loads(value))
            newdata[entry.key] = entry
        return newdata

    @staticmethod
    def from_jsonl(filename="edsl_cache.jsonl") -> "dict[str, CacheEntry]":
        """
        Read in a jsonl file and return a dictionary of CacheEntry objects.

        Each non-blank line is a JSON object. Its first key is the cache key and
        its first value is a serialized ``CacheEntry``. The file is opened in
        ``a+`` mode, so a missing file is created empty.
        """
        with open(filename, "a+") as f:
            f.seek(0)
            records = [json.loads(line) for line in f if line.strip()]
        if not records:
            return {}
        from edsl.data.CacheEntry import CacheEntry

        newdata = {}
        for d in records:
            key = list(d.keys())[0]
            value = list(d.values())[0]
            newdata[key] = CacheEntry.from_dict(value)
        return newdata
163
- return newdata
164
-
165
-
166
if __name__ == "__main__":
    # Run this module's doctests when it is executed directly.
    from doctest import testmod

    testmod()
1
+ from __future__ import annotations
2
+ import ast
3
+ import json
4
+ import os
5
+ import shutil
6
+ import sqlite3
7
+ from edsl.config import CONFIG
8
+ from edsl.data.Cache import Cache
9
+ from edsl.data.CacheEntry import CacheEntry
10
+ from edsl.data.SQLiteDict import SQLiteDict
11
+
12
+ from edsl.config import CONFIG
13
+
14
+
15
def set_session_cache(cache: "Cache") -> None:
    """
    Install ``cache`` as the session cache by storing it on CONFIG.
    """
    setattr(CONFIG, "EDSL_SESSION_CACHE", cache)
20
+
21
+
22
def unset_session_cache() -> None:
    """
    Remove the session cache from CONFIG; do nothing when none is set.
    """
    if not hasattr(CONFIG, "EDSL_SESSION_CACHE"):
        return
    delattr(CONFIG, "EDSL_SESSION_CACHE")
28
+
29
+
30
class CacheHandler:
    """
    This CacheHandler figures out what caches are available and does migrations, as needed.

    Constructing a handler does three things:
    - ensures the directory of the configured database exists;
    - builds a ``Cache``: empty and in-memory in test mode, otherwise the session
      cache stored on CONFIG, otherwise one backed by ``SQLiteDict``;
    - merges in entries migrated from an old-style ``edsl_cache.db`` file.
    """

    class _DatabasePath:
        """Non-data descriptor returning ``CONFIG.get("EDSL_DATABASE_PATH")``."""

        def __get__(self, instance, owner=None):
            return CONFIG.get("EDSL_DATABASE_PATH")

    # Resolved on every access rather than once when the class is defined, so
    # the value reflects CONFIG at the time of use. Unlike a @property, it can
    # still be read on the class itself and overridden per instance or subclass.
    CACHE_PATH = _DatabasePath()

    def __init__(self, test: bool = False):
        """
        Build the cache and migrate any old-style cache into it.

        :param test: when True, use an empty in-memory cache (``Cache(data={})``).
        """
        self.test = test
        self.create_cache_directory()
        self.cache = self.gen_cache()
        old_data = self.from_old_sqlite_cache()
        self.cache.add_from_dict(old_data)

    def create_cache_directory(self, notify=False) -> None:
        """
        Create the cache directory if one is required and it does not exist.

        :param notify: when True, print the directory that was created.
        """
        path = self.CACHE_PATH.replace("sqlite:///", "")
        dir_path = os.path.dirname(path)
        if dir_path and not os.path.exists(dir_path):
            # exist_ok tolerates another process creating it after the check above.
            os.makedirs(dir_path, exist_ok=True)
            if notify:
                print(f"Created cache directory: {dir_path}")

    def gen_cache(self) -> "Cache":
        """
        Generate a Cache object.

        Priority: an empty in-memory cache in test mode, then the session cache
        stored on CONFIG, then a cache persisted through ``SQLiteDict`` at
        ``CACHE_PATH``.
        """
        if self.test:
            return Cache(data={})
        if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
            return CONFIG.EDSL_SESSION_CACHE
        return Cache(data=SQLiteDict(self.CACHE_PATH))

    def from_old_sqlite_cache(
        self, path: str = "edsl_cache.db"
    ) -> "dict[str, CacheEntry]":
        """
        Convert an old-style cache to the new format.

        Reads every row of the ``responses`` table in ``path`` and converts each
        row to a ``CacheEntry``. The file is then copied to ``<path>.bak`` and
        the original is deleted.
        - NB: Not worth converting to sqlalchemy - this is a one-time operation.

        :param path: old cache file, checked relative to the current directory.
        :returns: mapping of entry key to ``CacheEntry``; empty when there is no
            old cache or it could not be read.
        """
        old_data = {}
        if not os.path.exists(os.path.join(os.getcwd(), path)):
            return old_data
        table_name = "responses"
        try:
            conn = sqlite3.connect(path)
            try:
                cur = conn.cursor()
                cur.execute(f"PRAGMA table_info({table_name})")
                schema = {column[1]: column[2] for column in cur.fetchall()}
                rows = cur.execute(f"SELECT * FROM {table_name}").fetchall()
            finally:
                # A sqlite3 connection used as a context manager only commits or
                # rolls back; close it explicitly, and before the file is removed
                # below (an open file cannot be deleted on Windows).
                conn.close()
        except sqlite3.OperationalError:
            print("Found an old Cache but could not convert it to new format.")
            return old_data

        for row in rows:
            entry = self._parse_old_cache_entry(row, schema)
            old_data[entry.key] = entry
        print(
            f"Found old cache at {path} with {len(old_data)} entries.\n"
            f"We will convert this to the new cache format.\n"
            f"The old cache is backed up to {path}.bak"
        )
        shutil.copy(path, f"{path}.bak")
        os.remove(path)
        return old_data

    def _parse_old_cache_entry(self, row: tuple, schema) -> "CacheEntry":
        """
        Parse one row of the old ``responses`` table into a ``CacheEntry``.

        Row values are matched to ``schema`` keys by position; this relies on
        ``SELECT *`` and ``PRAGMA table_info`` listing columns in the same order.
        The ``id`` column is dropped and ``prompt`` is renamed to ``user_prompt``.
        ``parameters`` is parsed from its stored Python-literal text.
        """
        entry_dict = {k: row[i] for i, k in enumerate(schema.keys())}
        _ = entry_dict.pop("id")
        entry_dict["user_prompt"] = entry_dict.pop("prompt")
        entry_dict["parameters"] = ast.literal_eval(entry_dict["parameters"])
        return CacheEntry(**entry_dict)

    def get_cache(self) -> "Cache":
        """Return the cache built during construction."""
        return self.cache

    ###############
    # NOT IN USE
    ###############
    @staticmethod
    def from_sqlite(uri="new_edsl_cache.db") -> "dict[str, CacheEntry]":
        """
        Read in a new-style sqlite cache and return a dictionary of CacheEntry objects.

        A static method, so it works both as ``CacheHandler.from_sqlite(uri)``
        and when called on an instance.
        """
        conn = sqlite3.connect(uri)
        try:
            rows = conn.execute("SELECT key, value FROM data").fetchall()
        finally:
            conn.close()
        newdata = {}
        for _, value in rows:
            entry = CacheEntry.from_dict(json.loads(value))
            newdata[entry.key] = entry
        return newdata

    @staticmethod
    def from_jsonl(filename="edsl_cache.jsonl") -> "dict[str, CacheEntry]":
        """
        Read in a jsonl file and return a dictionary of CacheEntry objects.

        Each non-blank line is a JSON object. Its first key is the cache key and
        its first value is a serialized ``CacheEntry``. The file is opened in
        ``a+`` mode, so a missing file is created empty.
        """
        newdata = {}
        with open(filename, "a+") as f:
            f.seek(0)
            for line in f:
                if not line.strip():
                    continue  # tolerate blank or trailing lines
                d = json.loads(line)
                key = list(d.keys())[0]
                value = list(d.values())[0]
                newdata[key] = CacheEntry.from_dict(value)
        return newdata
146
+
147
+
148
if __name__ == "__main__":
    # Running the module directly builds a handler. Its constructor creates the
    # cache directory and migrates any old-style edsl_cache.db in the cwd.
    CacheHandler()
@@ -1,78 +1,78 @@
1
class RemoteCacheSync:
    """
    Context manager that keeps a local cache in sync with the remote cache.

    When ``remote_cache`` is true:
    - entering pulls the entries the local cache is missing;
    - exiting pushes the entries the server is missing, together with entries
      added to the local cache inside the ``with`` block.

    Exceptions raised inside the block are not suppressed (``__exit__`` returns
    False), but the push still runs.
    """

    def __init__(
        self, coop, cache, output_func, remote_cache=True, remote_cache_description=""
    ):
        """
        :param coop: client providing ``remote_cache_get_diff`` and
            ``remote_cache_create_many``.
        :param cache: local cache providing ``keys()``, ``values()``, ``data``
            and ``add_from_dict``.
        :param output_func: callable that receives progress messages.
        :param remote_cache: when False, no remote calls are made at all.
        :param remote_cache_description: description attached to uploaded entries.
        """
        self.coop = coop
        self.cache = cache
        self._output = output_func
        self.remote_cache = remote_cache
        self.old_entry_keys = []
        # Filled in by _sync_to_remote with the entries created inside the block.
        self.new_cache_entries = []
        self.remote_cache_description = remote_cache_description

    def __enter__(self):
        if self.remote_cache:
            self._sync_from_remote()
        # Snapshot taken after the pull, so pulled entries are not treated as new.
        self.old_entry_keys = list(self.cache.keys())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.remote_cache:
            self._sync_to_remote()
        return False  # Propagate exceptions

    def _sync_from_remote(self):
        """Add entries the server has but the local cache lacks."""
        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
        client_missing_cacheentries = cache_difference.get(
            "client_missing_cacheentries", []
        )
        missing_entry_count = len(client_missing_cacheentries)

        if missing_entry_count > 0:
            self._output(
                f"Updating local cache with {missing_entry_count:,} new "
                f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
            )
            self.cache.add_from_dict(
                {entry.key: entry for entry in client_missing_cacheentries}
            )
            self._output("Local cache updated!")
        else:
            self._output("No new entries to add to local cache.")

    def _sync_to_remote(self):
        """Upload entries the server lacks, each at most once, and report counts."""
        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
        server_missing_cacheentry_keys = cache_difference.get(
            "server_missing_cacheentry_keys", []
        )
        # Keyed by cache key: the diff is computed from all local keys, so an
        # entry created inside the block that the server lacks is reported here
        # too. Without deduplication it would be uploaded and counted twice.
        to_upload = {}
        for key in server_missing_cacheentry_keys:
            entry = self.cache.data.get(key)
            if entry is not None:
                to_upload[entry.key] = entry

        old_keys = set(self.old_entry_keys)  # O(1) membership tests
        self.new_cache_entries = [
            entry for entry in self.cache.values() if entry.key not in old_keys
        ]
        for entry in self.new_cache_entries:
            to_upload.setdefault(entry.key, entry)

        server_missing_cacheentries = list(to_upload.values())
        new_entry_count = len(server_missing_cacheentries)

        if new_entry_count > 0:
            self._output(
                f"Updating remote cache with {new_entry_count:,} new "
                f"{'entry' if new_entry_count == 1 else 'entries'}..."
            )
            self.coop.remote_cache_create_many(
                server_missing_cacheentries,
                visibility="private",
                description=self.remote_cache_description,
            )
            self._output("Remote cache updated!")
        else:
            self._output("No new entries to add to remote cache.")

        self._output(
            f"There are {len(self.cache.keys()):,} entries in the local cache."
        )
1
class RemoteCacheSync:
    """
    Context manager that keeps a local cache in sync with the remote cache.

    When ``remote_cache`` is true:
    - entering pulls the entries the local cache is missing;
    - exiting pushes the entries the server is missing, together with entries
      added to the local cache inside the ``with`` block.

    Exceptions raised inside the block are not suppressed (``__exit__`` returns
    False), but the push still runs.
    """

    def __init__(
        self, coop, cache, output_func, remote_cache=True, remote_cache_description=""
    ):
        """
        :param coop: client providing ``remote_cache_get_diff`` and
            ``remote_cache_create_many``.
        :param cache: local cache providing ``keys()``, ``values()``, ``data``
            and ``add_from_dict``.
        :param output_func: callable that receives progress messages.
        :param remote_cache: when False, no remote calls are made at all.
        :param remote_cache_description: description attached to uploaded entries.
        """
        self.coop = coop
        self.cache = cache
        self._output = output_func
        self.remote_cache = remote_cache
        self.old_entry_keys = []
        # Filled in by _sync_to_remote with the entries created inside the block.
        self.new_cache_entries = []
        self.remote_cache_description = remote_cache_description

    def __enter__(self):
        if self.remote_cache:
            self._sync_from_remote()
        # Snapshot taken after the pull, so pulled entries are not treated as new.
        self.old_entry_keys = list(self.cache.keys())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.remote_cache:
            self._sync_to_remote()
        return False  # Propagate exceptions

    def _sync_from_remote(self):
        """Add entries the server has but the local cache lacks."""
        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
        client_missing_cacheentries = cache_difference.get(
            "client_missing_cacheentries", []
        )
        missing_entry_count = len(client_missing_cacheentries)

        if missing_entry_count > 0:
            self._output(
                f"Updating local cache with {missing_entry_count:,} new "
                f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
            )
            self.cache.add_from_dict(
                {entry.key: entry for entry in client_missing_cacheentries}
            )
            self._output("Local cache updated!")
        else:
            self._output("No new entries to add to local cache.")

    def _sync_to_remote(self):
        """Upload entries the server lacks, each at most once, and report counts."""
        cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
        server_missing_cacheentry_keys = cache_difference.get(
            "server_missing_cacheentry_keys", []
        )
        # Keyed by cache key: the diff is computed from all local keys, so an
        # entry created inside the block that the server lacks is reported here
        # too. Without deduplication it would be uploaded and counted twice.
        to_upload = {}
        for key in server_missing_cacheentry_keys:
            entry = self.cache.data.get(key)
            if entry is not None:
                to_upload[entry.key] = entry

        old_keys = set(self.old_entry_keys)  # O(1) membership tests
        self.new_cache_entries = [
            entry for entry in self.cache.values() if entry.key not in old_keys
        ]
        for entry in self.new_cache_entries:
            to_upload.setdefault(entry.key, entry)

        server_missing_cacheentries = list(to_upload.values())
        new_entry_count = len(server_missing_cacheentries)

        if new_entry_count > 0:
            self._output(
                f"Updating remote cache with {new_entry_count:,} new "
                f"{'entry' if new_entry_count == 1 else 'entries'}..."
            )
            self.coop.remote_cache_create_many(
                server_missing_cacheentries,
                visibility="private",
                description=self.remote_cache_description,
            )
            self._output("Remote cache updated!")
        else:
            self._output("No new entries to add to remote cache.")

        self._output(
            f"There are {len(self.cache.keys()):,} entries in the local cache."
        )