edsl 0.1.39.dev3__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (344) hide show
  1. edsl/Base.py +413 -332
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +57 -49
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +1071 -867
  7. edsl/agents/AgentList.py +551 -413
  8. edsl/agents/Invigilator.py +284 -233
  9. edsl/agents/InvigilatorBase.py +257 -270
  10. edsl/agents/PromptConstructor.py +272 -354
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/__init__.py +2 -3
  14. edsl/agents/descriptors.py +99 -99
  15. edsl/agents/prompt_helpers.py +129 -129
  16. edsl/agents/question_option_processor.py +172 -0
  17. edsl/auto/AutoStudy.py +130 -117
  18. edsl/auto/StageBase.py +243 -230
  19. edsl/auto/StageGenerateSurvey.py +178 -178
  20. edsl/auto/StageLabelQuestions.py +125 -125
  21. edsl/auto/StagePersona.py +61 -61
  22. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  23. edsl/auto/StagePersonaDimensionValues.py +74 -74
  24. edsl/auto/StagePersonaDimensions.py +69 -69
  25. edsl/auto/StageQuestions.py +74 -73
  26. edsl/auto/SurveyCreatorPipeline.py +21 -21
  27. edsl/auto/utilities.py +218 -224
  28. edsl/base/Base.py +279 -279
  29. edsl/config.py +177 -157
  30. edsl/conversation/Conversation.py +290 -290
  31. edsl/conversation/car_buying.py +59 -58
  32. edsl/conversation/chips.py +95 -95
  33. edsl/conversation/mug_negotiation.py +81 -81
  34. edsl/conversation/next_speaker_utilities.py +93 -93
  35. edsl/coop/CoopFunctionsMixin.py +15 -0
  36. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  37. edsl/coop/PriceFetcher.py +54 -54
  38. edsl/coop/__init__.py +2 -2
  39. edsl/coop/coop.py +1106 -1028
  40. edsl/coop/utils.py +131 -131
  41. edsl/data/Cache.py +573 -555
  42. edsl/data/CacheEntry.py +230 -233
  43. edsl/data/CacheHandler.py +168 -149
  44. edsl/data/RemoteCacheSync.py +186 -78
  45. edsl/data/SQLiteDict.py +292 -292
  46. edsl/data/__init__.py +5 -4
  47. edsl/data/hack.py +10 -0
  48. edsl/data/orm.py +10 -10
  49. edsl/data_transfer_models.py +74 -73
  50. edsl/enums.py +202 -175
  51. edsl/exceptions/BaseException.py +21 -21
  52. edsl/exceptions/__init__.py +54 -54
  53. edsl/exceptions/agents.py +54 -42
  54. edsl/exceptions/cache.py +5 -5
  55. edsl/exceptions/configuration.py +16 -16
  56. edsl/exceptions/coop.py +10 -10
  57. edsl/exceptions/data.py +14 -14
  58. edsl/exceptions/general.py +34 -34
  59. edsl/exceptions/inference_services.py +5 -0
  60. edsl/exceptions/jobs.py +33 -33
  61. edsl/exceptions/language_models.py +63 -63
  62. edsl/exceptions/prompts.py +15 -15
  63. edsl/exceptions/questions.py +109 -91
  64. edsl/exceptions/results.py +29 -29
  65. edsl/exceptions/scenarios.py +29 -22
  66. edsl/exceptions/surveys.py +37 -37
  67. edsl/inference_services/AnthropicService.py +106 -87
  68. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  69. edsl/inference_services/AvailableModelFetcher.py +215 -0
  70. edsl/inference_services/AwsBedrock.py +118 -120
  71. edsl/inference_services/AzureAI.py +215 -217
  72. edsl/inference_services/DeepInfraService.py +18 -18
  73. edsl/inference_services/GoogleService.py +143 -148
  74. edsl/inference_services/GroqService.py +20 -20
  75. edsl/inference_services/InferenceServiceABC.py +80 -147
  76. edsl/inference_services/InferenceServicesCollection.py +138 -97
  77. edsl/inference_services/MistralAIService.py +120 -123
  78. edsl/inference_services/OllamaService.py +18 -18
  79. edsl/inference_services/OpenAIService.py +236 -224
  80. edsl/inference_services/PerplexityService.py +160 -163
  81. edsl/inference_services/ServiceAvailability.py +135 -0
  82. edsl/inference_services/TestService.py +90 -89
  83. edsl/inference_services/TogetherAIService.py +172 -170
  84. edsl/inference_services/data_structures.py +134 -0
  85. edsl/inference_services/models_available_cache.py +118 -118
  86. edsl/inference_services/rate_limits_cache.py +25 -25
  87. edsl/inference_services/registry.py +41 -41
  88. edsl/inference_services/write_available.py +10 -10
  89. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  90. edsl/jobs/Answers.py +43 -56
  91. edsl/jobs/FetchInvigilator.py +47 -0
  92. edsl/jobs/InterviewTaskManager.py +98 -0
  93. edsl/jobs/InterviewsConstructor.py +50 -0
  94. edsl/jobs/Jobs.py +823 -898
  95. edsl/jobs/JobsChecks.py +172 -147
  96. edsl/jobs/JobsComponentConstructor.py +189 -0
  97. edsl/jobs/JobsPrompts.py +270 -268
  98. edsl/jobs/JobsRemoteInferenceHandler.py +311 -239
  99. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  100. edsl/jobs/RequestTokenEstimator.py +30 -0
  101. edsl/jobs/__init__.py +1 -1
  102. edsl/jobs/async_interview_runner.py +138 -0
  103. edsl/jobs/buckets/BucketCollection.py +104 -63
  104. edsl/jobs/buckets/ModelBuckets.py +65 -65
  105. edsl/jobs/buckets/TokenBucket.py +283 -251
  106. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  107. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  108. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  109. edsl/jobs/data_structures.py +120 -0
  110. edsl/jobs/decorators.py +35 -0
  111. edsl/jobs/interviews/Interview.py +396 -661
  112. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  113. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  114. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  115. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  116. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  117. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  118. edsl/jobs/interviews/ReportErrors.py +66 -66
  119. edsl/jobs/interviews/interview_status_enum.py +9 -9
  120. edsl/jobs/jobs_status_enums.py +9 -0
  121. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  122. edsl/jobs/results_exceptions_handler.py +98 -0
  123. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -466
  124. edsl/jobs/runners/JobsRunnerStatus.py +297 -330
  125. edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
  126. edsl/jobs/tasks/TaskCreators.py +64 -64
  127. edsl/jobs/tasks/TaskHistory.py +470 -450
  128. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  129. edsl/jobs/tasks/task_status_enum.py +161 -163
  130. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  131. edsl/jobs/tokens/TokenUsage.py +34 -34
  132. edsl/language_models/ComputeCost.py +63 -0
  133. edsl/language_models/LanguageModel.py +626 -668
  134. edsl/language_models/ModelList.py +164 -155
  135. edsl/language_models/PriceManager.py +127 -0
  136. edsl/language_models/RawResponseHandler.py +106 -0
  137. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  138. edsl/language_models/ServiceDataSources.py +0 -0
  139. edsl/language_models/__init__.py +2 -3
  140. edsl/language_models/fake_openai_call.py +15 -15
  141. edsl/language_models/fake_openai_service.py +61 -61
  142. edsl/language_models/key_management/KeyLookup.py +63 -0
  143. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  144. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  145. edsl/language_models/key_management/__init__.py +0 -0
  146. edsl/language_models/key_management/models.py +131 -0
  147. edsl/language_models/model.py +256 -0
  148. edsl/language_models/repair.py +156 -156
  149. edsl/language_models/utilities.py +65 -64
  150. edsl/notebooks/Notebook.py +263 -258
  151. edsl/notebooks/NotebookToLaTeX.py +142 -0
  152. edsl/notebooks/__init__.py +1 -1
  153. edsl/prompts/Prompt.py +352 -362
  154. edsl/prompts/__init__.py +2 -2
  155. edsl/questions/ExceptionExplainer.py +77 -0
  156. edsl/questions/HTMLQuestion.py +103 -0
  157. edsl/questions/QuestionBase.py +518 -664
  158. edsl/questions/QuestionBasePromptsMixin.py +221 -217
  159. edsl/questions/QuestionBudget.py +227 -227
  160. edsl/questions/QuestionCheckBox.py +359 -359
  161. edsl/questions/QuestionExtract.py +180 -182
  162. edsl/questions/QuestionFreeText.py +113 -114
  163. edsl/questions/QuestionFunctional.py +166 -166
  164. edsl/questions/QuestionList.py +223 -231
  165. edsl/questions/QuestionMatrix.py +265 -0
  166. edsl/questions/QuestionMultipleChoice.py +330 -286
  167. edsl/questions/QuestionNumerical.py +151 -153
  168. edsl/questions/QuestionRank.py +314 -324
  169. edsl/questions/Quick.py +41 -41
  170. edsl/questions/SimpleAskMixin.py +74 -73
  171. edsl/questions/__init__.py +27 -26
  172. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
  173. edsl/questions/compose_questions.py +98 -98
  174. edsl/questions/data_structures.py +20 -0
  175. edsl/questions/decorators.py +21 -21
  176. edsl/questions/derived/QuestionLikertFive.py +76 -76
  177. edsl/questions/derived/QuestionLinearScale.py +90 -87
  178. edsl/questions/derived/QuestionTopK.py +93 -93
  179. edsl/questions/derived/QuestionYesNo.py +82 -82
  180. edsl/questions/descriptors.py +427 -413
  181. edsl/questions/loop_processor.py +149 -0
  182. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  183. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  184. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  185. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  186. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  187. edsl/questions/prompt_templates/question_list.jinja +17 -17
  188. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  189. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  190. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
  191. edsl/questions/question_registry.py +177 -177
  192. edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
  193. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
  194. edsl/questions/response_validator_factory.py +34 -0
  195. edsl/questions/settings.py +12 -12
  196. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  197. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  198. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  199. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  200. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  201. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  202. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  203. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  204. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  205. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  206. edsl/questions/templates/list/question_presentation.jinja +5 -5
  207. edsl/questions/templates/matrix/__init__.py +1 -0
  208. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  209. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  210. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  211. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  212. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  213. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  214. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  215. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  216. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  217. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  218. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  219. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  220. edsl/results/CSSParameterizer.py +108 -108
  221. edsl/results/Dataset.py +587 -424
  222. edsl/results/DatasetExportMixin.py +594 -731
  223. edsl/results/DatasetTree.py +295 -275
  224. edsl/results/MarkdownToDocx.py +122 -0
  225. edsl/results/MarkdownToPDF.py +111 -0
  226. edsl/results/Result.py +557 -465
  227. edsl/results/Results.py +1183 -1165
  228. edsl/results/ResultsExportMixin.py +45 -43
  229. edsl/results/ResultsGGMixin.py +121 -121
  230. edsl/results/TableDisplay.py +125 -198
  231. edsl/results/TextEditor.py +50 -0
  232. edsl/results/__init__.py +2 -2
  233. edsl/results/file_exports.py +252 -0
  234. edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
  235. edsl/results/{Selector.py → results_selector.py} +145 -135
  236. edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
  237. edsl/results/smart_objects.py +96 -0
  238. edsl/results/table_data_class.py +12 -0
  239. edsl/results/table_display.css +77 -77
  240. edsl/results/table_renderers.py +118 -0
  241. edsl/results/tree_explore.py +115 -115
  242. edsl/scenarios/ConstructDownloadLink.py +109 -0
  243. edsl/scenarios/DocumentChunker.py +102 -0
  244. edsl/scenarios/DocxScenario.py +16 -0
  245. edsl/scenarios/FileStore.py +511 -632
  246. edsl/scenarios/PdfExtractor.py +40 -0
  247. edsl/scenarios/Scenario.py +498 -601
  248. edsl/scenarios/ScenarioHtmlMixin.py +65 -64
  249. edsl/scenarios/ScenarioList.py +1458 -1287
  250. edsl/scenarios/ScenarioListExportMixin.py +45 -52
  251. edsl/scenarios/ScenarioListPdfMixin.py +239 -261
  252. edsl/scenarios/__init__.py +3 -4
  253. edsl/scenarios/directory_scanner.py +96 -0
  254. edsl/scenarios/file_methods.py +85 -0
  255. edsl/scenarios/handlers/__init__.py +13 -0
  256. edsl/scenarios/handlers/csv.py +38 -0
  257. edsl/scenarios/handlers/docx.py +76 -0
  258. edsl/scenarios/handlers/html.py +37 -0
  259. edsl/scenarios/handlers/json.py +111 -0
  260. edsl/scenarios/handlers/latex.py +5 -0
  261. edsl/scenarios/handlers/md.py +51 -0
  262. edsl/scenarios/handlers/pdf.py +68 -0
  263. edsl/scenarios/handlers/png.py +39 -0
  264. edsl/scenarios/handlers/pptx.py +105 -0
  265. edsl/scenarios/handlers/py.py +294 -0
  266. edsl/scenarios/handlers/sql.py +313 -0
  267. edsl/scenarios/handlers/sqlite.py +149 -0
  268. edsl/scenarios/handlers/txt.py +33 -0
  269. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +131 -127
  270. edsl/scenarios/scenario_selector.py +156 -0
  271. edsl/shared.py +1 -1
  272. edsl/study/ObjectEntry.py +173 -173
  273. edsl/study/ProofOfWork.py +113 -113
  274. edsl/study/SnapShot.py +80 -80
  275. edsl/study/Study.py +521 -528
  276. edsl/study/__init__.py +4 -4
  277. edsl/surveys/ConstructDAG.py +92 -0
  278. edsl/surveys/DAG.py +148 -148
  279. edsl/surveys/EditSurvey.py +221 -0
  280. edsl/surveys/InstructionHandler.py +100 -0
  281. edsl/surveys/Memory.py +31 -31
  282. edsl/surveys/MemoryManagement.py +72 -0
  283. edsl/surveys/MemoryPlan.py +244 -244
  284. edsl/surveys/Rule.py +327 -326
  285. edsl/surveys/RuleCollection.py +385 -387
  286. edsl/surveys/RuleManager.py +172 -0
  287. edsl/surveys/Simulator.py +75 -0
  288. edsl/surveys/Survey.py +1280 -1801
  289. edsl/surveys/SurveyCSS.py +273 -261
  290. edsl/surveys/SurveyExportMixin.py +259 -259
  291. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -179
  292. edsl/surveys/SurveyQualtricsImport.py +284 -284
  293. edsl/surveys/SurveyToApp.py +141 -0
  294. edsl/surveys/__init__.py +5 -3
  295. edsl/surveys/base.py +53 -53
  296. edsl/surveys/descriptors.py +60 -56
  297. edsl/surveys/instructions/ChangeInstruction.py +48 -49
  298. edsl/surveys/instructions/Instruction.py +56 -65
  299. edsl/surveys/instructions/InstructionCollection.py +82 -77
  300. edsl/templates/error_reporting/base.html +23 -23
  301. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  302. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  303. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  304. edsl/templates/error_reporting/interview_details.html +115 -115
  305. edsl/templates/error_reporting/interviews.html +19 -19
  306. edsl/templates/error_reporting/overview.html +4 -4
  307. edsl/templates/error_reporting/performance_plot.html +1 -1
  308. edsl/templates/error_reporting/report.css +73 -73
  309. edsl/templates/error_reporting/report.html +117 -117
  310. edsl/templates/error_reporting/report.js +25 -25
  311. edsl/test_h +1 -0
  312. edsl/tools/__init__.py +1 -1
  313. edsl/tools/clusters.py +192 -192
  314. edsl/tools/embeddings.py +27 -27
  315. edsl/tools/embeddings_plotting.py +118 -118
  316. edsl/tools/plotting.py +112 -112
  317. edsl/tools/summarize.py +18 -18
  318. edsl/utilities/PrettyList.py +56 -0
  319. edsl/utilities/SystemInfo.py +28 -28
  320. edsl/utilities/__init__.py +22 -22
  321. edsl/utilities/ast_utilities.py +25 -25
  322. edsl/utilities/data/Registry.py +6 -6
  323. edsl/utilities/data/__init__.py +1 -1
  324. edsl/utilities/data/scooter_results.json +1 -1
  325. edsl/utilities/decorators.py +77 -77
  326. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  327. edsl/utilities/gcp_bucket/example.py +50 -0
  328. edsl/utilities/interface.py +627 -627
  329. edsl/utilities/is_notebook.py +18 -0
  330. edsl/utilities/is_valid_variable_name.py +11 -0
  331. edsl/utilities/naming_utilities.py +263 -263
  332. edsl/utilities/remove_edsl_version.py +24 -0
  333. edsl/utilities/repair_functions.py +28 -28
  334. edsl/utilities/restricted_python.py +70 -70
  335. edsl/utilities/utilities.py +436 -424
  336. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +21 -21
  337. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +13 -11
  338. edsl-0.1.39.dev4.dist-info/RECORD +361 -0
  339. edsl/language_models/KeyLookup.py +0 -30
  340. edsl/language_models/registry.py +0 -190
  341. edsl/language_models/unused/ReplicateBase.py +0 -83
  342. edsl/results/ResultsDBMixin.py +0 -238
  343. edsl-0.1.39.dev3.dist-info/RECORD +0 -277
  344. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -1,118 +1,118 @@
1
- import json
2
- import numpy as np
3
- from sklearn.manifold import TSNE
4
- from IPython.display import display_html
5
-
6
-
7
- def compute_tsne(embeddings, labels):
8
- embeddings_np = np.array(embeddings)
9
- tsne = TSNE(n_components=2, random_state=42)
10
- tsne_results = tsne.fit_transform(embeddings_np)
11
- data = [
12
- {
13
- "x": float(tsne_results[i, 0]),
14
- "y": float(tsne_results[i, 1]),
15
- "label": labels[i],
16
- }
17
- for i in range(len(labels))
18
- ]
19
- return data
20
-
21
-
22
- def plot_tsne_in_notebook(embeddings, labels):
23
- # Compute t-SNE
24
- data = compute_tsne(embeddings, labels)
25
-
26
- # Convert data to JSON
27
- data_json = json.dumps(data)
28
-
29
- # HTML content with embedded data
30
- html_content = f"""
31
- <!DOCTYPE html>
32
- <html lang="en">
33
- <head>
34
- <meta charset="UTF-8">
35
- <title>t-SNE Plot with D3.js</title>
36
- <script src="https://d3js.org/d3.v6.min.js"></script>
37
- <style>
38
- .tooltip {{
39
- position: absolute;
40
- text-align: center;
41
- width: 80px;
42
- height: 28px;
43
- padding: 2px;
44
- font: 12px sans-serif;
45
- background: lightsteelblue;
46
- border: 0px;
47
- border-radius: 8px;
48
- pointer-events: none;
49
- }}
50
- </style>
51
- </head>
52
- <body>
53
- <svg width="600" height="600"></svg>
54
-
55
- <script>
56
- // Embedded data
57
- const data = {data_json};
58
-
59
- const svg = d3.select("svg"),
60
- width = +svg.attr("width"),
61
- height = +svg.attr("height");
62
-
63
- // Set up scales
64
- const x = d3.scaleLinear()
65
- .domain(d3.extent(data, d => d.x))
66
- .range([0, width]);
67
-
68
- const y = d3.scaleLinear()
69
- .domain(d3.extent(data, d => d.y))
70
- .range([height, 0]);
71
-
72
- // Create tooltip
73
- const tooltip = d3.select("body").append("div")
74
- .attr("class", "tooltip")
75
- .style("opacity", 0);
76
-
77
- // Create circles for each point
78
- svg.selectAll("circle")
79
- .data(data)
80
- .enter().append("circle")
81
- .attr("cx", d => x(d.x))
82
- .attr("cy", d => y(d.y))
83
- .attr("r", 5)
84
- .style("fill", "steelblue")
85
- .on("mouseover", function(event, d) {{
86
- tooltip.transition()
87
- .duration(200)
88
- .style("opacity", .9);
89
- tooltip.html(d.label)
90
- .style("left", (event.pageX + 5) + "px")
91
- .style("top", (event.pageY - 28) + "px");
92
- }})
93
- .on("mouseout", function(d) {{
94
- tooltip.transition()
95
- .duration(500)
96
- .style("opacity", 0);
97
- }});
98
- </script>
99
- </body>
100
- </html>
101
- """
102
-
103
- # Write HTML content to a temporary file
104
- html_file = "tsne_plot.html"
105
- with open(html_file, "w") as file:
106
- file.write(html_content)
107
-
108
- # Display the HTML content in an iframe within a Jupyter notebook
109
- display_html(
110
- f'<iframe src="{html_file}" width="600" height="600"></iframe>', raw=True
111
- )
112
-
113
-
114
- # Example usage
115
- if __name__ == "__main__":
116
- embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
117
- labels = ["String 1", "String 2", "String 3"]
118
- plot_tsne_in_notebook(embeddings, labels)
1
+ import json
2
+ import numpy as np
3
+ from sklearn.manifold import TSNE
4
+ from IPython.display import display_html
5
+
6
+
7
+ def compute_tsne(embeddings, labels):
8
+ embeddings_np = np.array(embeddings)
9
+ tsne = TSNE(n_components=2, random_state=42)
10
+ tsne_results = tsne.fit_transform(embeddings_np)
11
+ data = [
12
+ {
13
+ "x": float(tsne_results[i, 0]),
14
+ "y": float(tsne_results[i, 1]),
15
+ "label": labels[i],
16
+ }
17
+ for i in range(len(labels))
18
+ ]
19
+ return data
20
+
21
+
22
+ def plot_tsne_in_notebook(embeddings, labels):
23
+ # Compute t-SNE
24
+ data = compute_tsne(embeddings, labels)
25
+
26
+ # Convert data to JSON
27
+ data_json = json.dumps(data)
28
+
29
+ # HTML content with embedded data
30
+ html_content = f"""
31
+ <!DOCTYPE html>
32
+ <html lang="en">
33
+ <head>
34
+ <meta charset="UTF-8">
35
+ <title>t-SNE Plot with D3.js</title>
36
+ <script src="https://d3js.org/d3.v6.min.js"></script>
37
+ <style>
38
+ .tooltip {{
39
+ position: absolute;
40
+ text-align: center;
41
+ width: 80px;
42
+ height: 28px;
43
+ padding: 2px;
44
+ font: 12px sans-serif;
45
+ background: lightsteelblue;
46
+ border: 0px;
47
+ border-radius: 8px;
48
+ pointer-events: none;
49
+ }}
50
+ </style>
51
+ </head>
52
+ <body>
53
+ <svg width="600" height="600"></svg>
54
+
55
+ <script>
56
+ // Embedded data
57
+ const data = {data_json};
58
+
59
+ const svg = d3.select("svg"),
60
+ width = +svg.attr("width"),
61
+ height = +svg.attr("height");
62
+
63
+ // Set up scales
64
+ const x = d3.scaleLinear()
65
+ .domain(d3.extent(data, d => d.x))
66
+ .range([0, width]);
67
+
68
+ const y = d3.scaleLinear()
69
+ .domain(d3.extent(data, d => d.y))
70
+ .range([height, 0]);
71
+
72
+ // Create tooltip
73
+ const tooltip = d3.select("body").append("div")
74
+ .attr("class", "tooltip")
75
+ .style("opacity", 0);
76
+
77
+ // Create circles for each point
78
+ svg.selectAll("circle")
79
+ .data(data)
80
+ .enter().append("circle")
81
+ .attr("cx", d => x(d.x))
82
+ .attr("cy", d => y(d.y))
83
+ .attr("r", 5)
84
+ .style("fill", "steelblue")
85
+ .on("mouseover", function(event, d) {{
86
+ tooltip.transition()
87
+ .duration(200)
88
+ .style("opacity", .9);
89
+ tooltip.html(d.label)
90
+ .style("left", (event.pageX + 5) + "px")
91
+ .style("top", (event.pageY - 28) + "px");
92
+ }})
93
+ .on("mouseout", function(d) {{
94
+ tooltip.transition()
95
+ .duration(500)
96
+ .style("opacity", 0);
97
+ }});
98
+ </script>
99
+ </body>
100
+ </html>
101
+ """
102
+
103
+ # Write HTML content to a temporary file
104
+ html_file = "tsne_plot.html"
105
+ with open(html_file, "w") as file:
106
+ file.write(html_content)
107
+
108
+ # Display the HTML content in an iframe within a Jupyter notebook
109
+ display_html(
110
+ f'<iframe src="{html_file}" width="600" height="600"></iframe>', raw=True
111
+ )
112
+
113
+
114
+ # Example usage
115
+ if __name__ == "__main__":
116
+ embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
117
+ labels = ["String 1", "String 2", "String 3"]
118
+ plot_tsne_in_notebook(embeddings, labels)
edsl/tools/plotting.py CHANGED
@@ -1,112 +1,112 @@
1
- from typing import Optional
2
-
3
-
4
- def count_query(field):
5
- return f"""SELECT
6
- {field},
7
- COUNT(*) as number
8
- FROM self
9
- GROUP BY {field}
10
- """
11
-
12
-
13
- def get_options(results, field):
14
- question_type = results.survey.get_question(field).question_type
15
- if question_type in ["multiple_choice", "checkbox"]:
16
- return results.select(f"{field}_question_options").first()
17
- else:
18
- return None
19
-
20
-
21
- def interpret_image(path, analysis):
22
- from edsl import QuestionFreeText
23
- from edsl import Model
24
- from edsl import Scenario
25
-
26
- s = Scenario.from_image(path)
27
- if isinstance(analysis, str):
28
- plot_question_texts = [analysis]
29
- elif isinstance(analysis, list):
30
- plot_question_texts = analysis
31
-
32
- scenario_list = s.replicate(len(plot_question_texts))
33
- scenario_list.add_list("plot_question_text", plot_question_texts)
34
-
35
- m = Model("gpt-4o")
36
- q = QuestionFreeText(
37
- question_text="{{ plot_question_text }}", question_name="interpretation"
38
- )
39
- results = q.by(m).by(scenario_list).run()
40
- return results.select("plot_question_text", "interpretation").print(
41
- format="rich",
42
- pretty_labels={
43
- "scenario.plot_question_text": "Question to the model",
44
- "answer.interpretation": "Model answer",
45
- },
46
- )
47
-
48
-
49
- def barchart(
50
- results,
51
- field: str,
52
- fetch_options=True,
53
- xlab: Optional[str] = None,
54
- ylab: Optional[str] = None,
55
- analysis: Optional[str] = None,
56
- format: str = "png",
57
- ):
58
- labels = ""
59
- if xlab:
60
- labels += f"+ xlab('{xlab}')"
61
- if ylab:
62
- labels += f"+ ylab('{ylab}')"
63
-
64
- if fetch_options:
65
- factor_orders = {field: get_options(results, field)}
66
- else:
67
- factor_orders = None
68
-
69
- plot = results.ggplot2(
70
- f"""ggplot(data = self, aes(x = {field}, y = number)) +
71
- geom_bar(stat = "identity") +
72
- theme_bw() +
73
- theme(axis.text.x = element_text(angle = 45, hjust = 1)) {labels}""",
74
- sql=count_query(field),
75
- factor_orders=factor_orders,
76
- format=format,
77
- filename=f"barchart_{field}.{format}",
78
- )
79
- if analysis:
80
- interpret_image(f"barchart_{field}.{format}", analysis)
81
-
82
- return plot
83
-
84
-
85
- def theme_plot(results, field, context, themes=None, progress_bar=False):
86
- _, themes = results.auto_theme(
87
- field=field, context=context, themes=themes, progress_bar=progress_bar
88
- )
89
-
90
- themes_query = f"""
91
- SELECT theme, COUNT(*) AS mentions
92
- FROM (
93
- SELECT json_each.value AS theme
94
- FROM self,
95
- json_each({ field }_themes)
96
- )
97
- GROUP BY theme
98
- HAVING theme <> 'Other'
99
- ORDER BY mentions DESC
100
- """
101
- themes = results.sql(themes_query, to_list=True)
102
-
103
- (
104
- results.filter(f"{field} != ''").ggplot2(
105
- """ggplot(data = self, aes(x = theme, y = mentions)) +
106
- geom_bar(stat = "identity") +
107
- coord_flip() +
108
- theme_bw()""",
109
- sql=themes_query,
110
- factor_orders={"theme": [t[0] for t in themes]},
111
- )
112
- )
1
+ from typing import Optional
2
+
3
+
4
+ def count_query(field):
5
+ return f"""SELECT
6
+ {field},
7
+ COUNT(*) as number
8
+ FROM self
9
+ GROUP BY {field}
10
+ """
11
+
12
+
13
+ def get_options(results, field):
14
+ question_type = results.survey._get_question_by_name(field).question_type
15
+ if question_type in ["multiple_choice", "checkbox"]:
16
+ return results.select(f"{field}_question_options").first()
17
+ else:
18
+ return None
19
+
20
+
21
+ def interpret_image(path, analysis):
22
+ from edsl import QuestionFreeText
23
+ from edsl import Model
24
+ from edsl import Scenario
25
+
26
+ s = Scenario.from_image(path)
27
+ if isinstance(analysis, str):
28
+ plot_question_texts = [analysis]
29
+ elif isinstance(analysis, list):
30
+ plot_question_texts = analysis
31
+
32
+ scenario_list = s.replicate(len(plot_question_texts))
33
+ scenario_list.add_list("plot_question_text", plot_question_texts)
34
+
35
+ m = Model("gpt-4o")
36
+ q = QuestionFreeText(
37
+ question_text="{{ plot_question_text }}", question_name="interpretation"
38
+ )
39
+ results = q.by(m).by(scenario_list).run()
40
+ return results.select("plot_question_text", "interpretation").print(
41
+ format="rich",
42
+ pretty_labels={
43
+ "scenario.plot_question_text": "Question to the model",
44
+ "answer.interpretation": "Model answer",
45
+ },
46
+ )
47
+
48
+
49
+ def barchart(
50
+ results,
51
+ field: str,
52
+ fetch_options=True,
53
+ xlab: Optional[str] = None,
54
+ ylab: Optional[str] = None,
55
+ analysis: Optional[str] = None,
56
+ format: str = "png",
57
+ ):
58
+ labels = ""
59
+ if xlab:
60
+ labels += f"+ xlab('{xlab}')"
61
+ if ylab:
62
+ labels += f"+ ylab('{ylab}')"
63
+
64
+ if fetch_options:
65
+ factor_orders = {field: get_options(results, field)}
66
+ else:
67
+ factor_orders = None
68
+
69
+ plot = results.ggplot2(
70
+ f"""ggplot(data = self, aes(x = {field}, y = number)) +
71
+ geom_bar(stat = "identity") +
72
+ theme_bw() +
73
+ theme(axis.text.x = element_text(angle = 45, hjust = 1)) {labels}""",
74
+ sql=count_query(field),
75
+ factor_orders=factor_orders,
76
+ format=format,
77
+ filename=f"barchart_{field}.{format}",
78
+ )
79
+ if analysis:
80
+ interpret_image(f"barchart_{field}.{format}", analysis)
81
+
82
+ return plot
83
+
84
+
85
+ def theme_plot(results, field, context, themes=None, progress_bar=False):
86
+ _, themes = results.auto_theme(
87
+ field=field, context=context, themes=themes, progress_bar=progress_bar
88
+ )
89
+
90
+ themes_query = f"""
91
+ SELECT theme, COUNT(*) AS mentions
92
+ FROM (
93
+ SELECT json_each.value AS theme
94
+ FROM self,
95
+ json_each({ field }_themes)
96
+ )
97
+ GROUP BY theme
98
+ HAVING theme <> 'Other'
99
+ ORDER BY mentions DESC
100
+ """
101
+ themes = results.sql(themes_query, to_list=True)
102
+
103
+ (
104
+ results.filter(f"{field} != ''").ggplot2(
105
+ """ggplot(data = self, aes(x = theme, y = mentions)) +
106
+ geom_bar(stat = "identity") +
107
+ coord_flip() +
108
+ theme_bw()""",
109
+ sql=themes_query,
110
+ factor_orders={"theme": [t[0] for t in themes]},
111
+ )
112
+ )
edsl/tools/summarize.py CHANGED
@@ -1,18 +1,18 @@
1
- from edsl import QuestionList, Scenario, Model
2
-
3
-
4
- def summarize(texts, seed_phrase, n_bullets, n_words, models=None):
5
- if models is None:
6
- models = Model()
7
- s = Scenario(
8
- text=texts, seed_phrase=seed_phrase, n_bullets=n_bullets, n_words=n_words
9
- ).expand("text")
10
- QuestionList(
11
- question_text="""
12
- I have the following TEXT EXAMPLE :
13
- {{ text_example_json }}
14
- Please summarize the main point of this EXAMPLE {{seed_phrase }} into {{ n_bullets }} bullet points, where
15
- each bullet point is a {{ n_words }} word phrase.
16
- """,
17
- question_name="summarize",
18
- ).by(s).by(models).run()
1
+ from edsl import QuestionList, Scenario, Model
2
+
3
+
4
+ def summarize(texts, seed_phrase, n_bullets, n_words, models=None):
5
+ if models is None:
6
+ models = Model()
7
+ s = Scenario(
8
+ text=texts, seed_phrase=seed_phrase, n_bullets=n_bullets, n_words=n_words
9
+ ).expand("text")
10
+ QuestionList(
11
+ question_text="""
12
+ I have the following TEXT EXAMPLE :
13
+ {{ text_example_json }}
14
+ Please summarize the main point of this EXAMPLE {{seed_phrase }} into {{ n_bullets }} bullet points, where
15
+ each bullet point is a {{ n_words }} word phrase.
16
+ """,
17
+ question_name="summarize",
18
+ ).by(s).by(models).run()
@@ -0,0 +1,56 @@
1
+ from collections import UserList
2
+ from edsl.results.Dataset import Dataset
3
+
4
+
5
+ class PrettyList(UserList):
6
+ def __init__(self, data=None, columns=None):
7
+ super().__init__(data)
8
+ self.columns = columns
9
+
10
+ def _repr_html_(self):
11
+ if isinstance(self[0], list) or isinstance(self[0], tuple):
12
+ num_cols = len(self[0])
13
+ else:
14
+ num_cols = 1
15
+
16
+ if self.columns:
17
+ columns = self.columns
18
+ else:
19
+ columns = list(range(num_cols))
20
+
21
+ d = {}
22
+ for column in columns:
23
+ d[column] = []
24
+
25
+ for row in self:
26
+ for index, column in enumerate(columns):
27
+ if isinstance(row, list) or isinstance(row, tuple):
28
+ d[column].append(row[index])
29
+ else:
30
+ d[column].append(row)
31
+ # raise ValueError(d)
32
+ return Dataset([{key: entry} for key, entry in d.items()])._repr_html_()
33
+
34
+ if num_cols > 1:
35
+ return (
36
+ "<pre><table>"
37
+ + "".join(["<th>" + str(column) + "</th>" for column in columns])
38
+ + "".join(
39
+ [
40
+ "<tr>"
41
+ + "".join(["<td>" + str(x) + "</td>" for x in row])
42
+ + "</tr>"
43
+ for row in self
44
+ ]
45
+ )
46
+ + "</table></pre>"
47
+ )
48
+ else:
49
+ return (
50
+ "<pre><table>"
51
+ + "".join(["<th>" + str(index) + "</th>" for index in columns])
52
+ + "".join(
53
+ ["<tr>" + "<td>" + str(row) + "</td>" + "</tr>" for row in self]
54
+ )
55
+ + "</table></pre>"
56
+ )
@@ -1,28 +1,28 @@
1
- """Module to store system information."""
2
-
3
- from dataclasses import dataclass
4
- import getpass
5
- import platform
6
- import pkg_resources
7
-
8
-
9
- @dataclass
10
- class SystemInfo:
11
- """Dataclass to store system information."""
12
-
13
- username: str
14
- system_info: str
15
- release_info: str
16
- package_name: str
17
- package_version: str
18
-
19
- def __init__(self, package_name: str):
20
- """Initialize the dataclass with system."""
21
- self.username = getpass.getuser()
22
- self.system_info = platform.system()
23
- self.release_info = platform.release()
24
- self.package_name = package_name
25
- try:
26
- self.package_version = pkg_resources.get_distribution(package_name).version
27
- except pkg_resources.DistributionNotFound:
28
- self.package_version = "Not installed"
1
+ """Module to store system information."""
2
+
3
+ from dataclasses import dataclass
4
+ import getpass
5
+ import platform
6
+ import pkg_resources
7
+
8
+
9
+ @dataclass
10
+ class SystemInfo:
11
+ """Dataclass to store system information."""
12
+
13
+ username: str
14
+ system_info: str
15
+ release_info: str
16
+ package_name: str
17
+ package_version: str
18
+
19
+ def __init__(self, package_name: str):
20
+ """Initialize the dataclass with system."""
21
+ self.username = getpass.getuser()
22
+ self.system_info = platform.system()
23
+ self.release_info = platform.release()
24
+ self.package_name = package_name
25
+ try:
26
+ self.package_version = pkg_resources.get_distribution(package_name).version
27
+ except pkg_resources.DistributionNotFound:
28
+ self.package_version = "Not installed"
@@ -1,22 +1,22 @@
1
- # from edsl.utilities.interface import (
2
- # print_dict_as_html_table,
3
- # print_dict_with_rich,
4
- # print_list_of_dicts_as_html_table,
5
- # print_table_with_rich,
6
- # print_public_methods_with_doc,
7
- # print_list_of_dicts_as_markdown_table,
8
- # )
9
-
10
- # from edsl.utilities.utilities import (
11
- # create_valid_var_name,
12
- # dict_to_html,
13
- # hash_value,
14
- # HTMLSnippet,
15
- # is_notebook,
16
- # is_gzipped,
17
- # is_valid_variable_name,
18
- # random_string,
19
- # repair_json,
20
- # shorten_string,
21
- # time_all_functions,
22
- # )
1
+ # from edsl.utilities.interface import (
2
+ # print_dict_as_html_table,
3
+ # print_dict_with_rich,
4
+ # print_list_of_dicts_as_html_table,
5
+ # print_table_with_rich,
6
+ # print_public_methods_with_doc,
7
+ # print_list_of_dicts_as_markdown_table,
8
+ # )
9
+
10
+ # from edsl.utilities.utilities import (
11
+ # create_valid_var_name,
12
+ # dict_to_html,
13
+ # hash_value,
14
+ # HTMLSnippet,
15
+ # is_notebook,
16
+ # is_gzipped,
17
+ # is_valid_variable_name,
18
+ # random_string,
19
+ # repair_json,
20
+ # shorten_string,
21
+ # time_all_functions,
22
+ # )