edsl 0.1.39.dev3__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +413 -332
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +57 -49
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +1071 -867
- edsl/agents/AgentList.py +551 -413
- edsl/agents/Invigilator.py +284 -233
- edsl/agents/InvigilatorBase.py +257 -270
- edsl/agents/PromptConstructor.py +272 -354
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +2 -3
- edsl/agents/descriptors.py +99 -99
- edsl/agents/prompt_helpers.py +129 -129
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -117
- edsl/auto/StageBase.py +243 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +74 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +218 -224
- edsl/base/Base.py +279 -279
- edsl/config.py +177 -157
- edsl/conversation/Conversation.py +290 -290
- edsl/conversation/car_buying.py +59 -58
- edsl/conversation/chips.py +95 -95
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +1106 -1028
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +573 -555
- edsl/data/CacheEntry.py +230 -233
- edsl/data/CacheHandler.py +168 -149
- edsl/data/RemoteCacheSync.py +186 -78
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +5 -4
- edsl/data/hack.py +10 -0
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +74 -73
- edsl/enums.py +202 -175
- edsl/exceptions/BaseException.py +21 -21
- edsl/exceptions/__init__.py +54 -54
- edsl/exceptions/agents.py +54 -42
- edsl/exceptions/cache.py +5 -5
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +109 -91
- edsl/exceptions/results.py +29 -29
- edsl/exceptions/scenarios.py +29 -22
- edsl/exceptions/surveys.py +37 -37
- edsl/inference_services/AnthropicService.py +106 -87
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -120
- edsl/inference_services/AzureAI.py +215 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +143 -148
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +80 -147
- edsl/inference_services/InferenceServicesCollection.py +138 -97
- edsl/inference_services/MistralAIService.py +120 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +236 -224
- edsl/inference_services/PerplexityService.py +160 -163
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -89
- edsl/inference_services/TogetherAIService.py +172 -170
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +41 -41
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +43 -56
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +823 -898
- edsl/jobs/JobsChecks.py +172 -147
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -268
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -239
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +283 -251
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +396 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -466
- edsl/jobs/runners/JobsRunnerStatus.py +297 -330
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +470 -450
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +161 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +626 -668
- edsl/language_models/ModelList.py +164 -155
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +2 -3
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +156 -156
- edsl/language_models/utilities.py +65 -64
- edsl/notebooks/Notebook.py +263 -258
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +352 -362
- edsl/prompts/__init__.py +2 -2
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -664
- edsl/questions/QuestionBasePromptsMixin.py +221 -217
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +180 -182
- edsl/questions/QuestionFreeText.py +113 -114
- edsl/questions/QuestionFunctional.py +166 -166
- edsl/questions/QuestionList.py +223 -231
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +330 -286
- edsl/questions/QuestionNumerical.py +151 -153
- edsl/questions/QuestionRank.py +314 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/SimpleAskMixin.py +74 -73
- edsl/questions/__init__.py +27 -26
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +90 -87
- edsl/questions/derived/QuestionTopK.py +93 -93
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +427 -413
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
- edsl/questions/question_registry.py +177 -177
- edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/CSSParameterizer.py +108 -108
- edsl/results/Dataset.py +587 -424
- edsl/results/DatasetExportMixin.py +594 -731
- edsl/results/DatasetTree.py +295 -275
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +557 -465
- edsl/results/Results.py +1183 -1165
- edsl/results/ResultsExportMixin.py +45 -43
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/TableDisplay.py +125 -198
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +2 -2
- edsl/results/file_exports.py +252 -0
- edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
- edsl/results/{Selector.py → results_selector.py} +145 -135
- edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +77 -77
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +511 -632
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +498 -601
- edsl/scenarios/ScenarioHtmlMixin.py +65 -64
- edsl/scenarios/ScenarioList.py +1458 -1287
- edsl/scenarios/ScenarioListExportMixin.py +45 -52
- edsl/scenarios/ScenarioListPdfMixin.py +239 -261
- edsl/scenarios/__init__.py +3 -4
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +38 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +131 -127
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +521 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +327 -326
- edsl/surveys/RuleCollection.py +385 -387
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1280 -1801
- edsl/surveys/SurveyCSS.py +273 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -179
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +5 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +60 -56
- edsl/surveys/instructions/ChangeInstruction.py +48 -49
- edsl/surveys/instructions/Instruction.py +56 -65
- edsl/surveys/instructions/InstructionCollection.py +82 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +19 -19
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/test_h +1 -0
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/gcp_bucket/example.py +50 -0
- edsl/utilities/interface.py +627 -627
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -263
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +436 -424
- {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +21 -21
- {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +13 -11
- edsl-0.1.39.dev4.dist-info/RECORD +361 -0
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/registry.py +0 -190
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.39.dev3.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/tools/clusters.py
CHANGED
@@ -1,192 +1,192 @@
|
|
1
|
-
import json
|
2
|
-
import numpy as np
|
3
|
-
from sklearn.cluster import KMeans
|
4
|
-
from sklearn.manifold import TSNE
|
5
|
-
from sklearn.decomposition import PCA
|
6
|
-
from IPython.display import display_html
|
7
|
-
|
8
|
-
|
9
|
-
def compute_tsne(embeddings, cluster_labels, text_labels, perplexity=30.0):
    """
    Compute t-SNE on embedding vectors.

    Parameters:
    embeddings (np.ndarray): The embedding vectors.
    cluster_labels (np.ndarray): Cluster labels for each embedding.
    text_labels (list): Text labels for each embedding.
    perplexity (float): t-SNE perplexity (30 is scikit-learn's default).
        For small inputs it is lowered to ``n_samples - 1``, because
        scikit-learn requires the perplexity to be less than the number
        of samples.

    Returns:
    list: List of dictionaries with x, y coordinates, cluster labels, and text labels.

    Raises:
    ValueError: If the three inputs do not have the same length.
    """
    n_samples = len(embeddings)
    # Check before the (expensive) fit: a shorter label list would otherwise
    # fail with IndexError part-way through, and a longer one would be
    # silently truncated.
    if len(cluster_labels) != n_samples or len(text_labels) != n_samples:
        raise ValueError(
            "embeddings, cluster_labels and text_labels must have the same length"
        )
    if n_samples > 1:
        # Without this, any input with <= 30 samples is rejected by TSNE.
        perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    coords = tsne.fit_transform(embeddings)
    return [
        {
            "x": float(x),
            "y": float(y),
            "cluster_label": str(cluster),
            "text_label": text,
        }
        for (x, y), cluster, text in zip(coords, cluster_labels, text_labels)
    ]
|
33
|
-
|
34
|
-
|
35
|
-
def compute_pca(embeddings, cluster_labels, text_labels):
    """
    Compute PCA on embedding vectors.

    Parameters:
    embeddings (np.ndarray): The embedding vectors.
    cluster_labels (np.ndarray): Cluster labels for each embedding.
    text_labels (list): Text labels for each embedding.

    Returns:
    list: List of dictionaries with x, y coordinates, cluster labels, and text labels.

    Raises:
    ValueError: If the three inputs do not have the same length.
    """
    n_samples = len(embeddings)
    # Check up front: a shorter label list would otherwise fail with
    # IndexError part-way through, and a longer one would be silently truncated.
    if len(cluster_labels) != n_samples or len(text_labels) != n_samples:
        raise ValueError(
            "embeddings, cluster_labels and text_labels must have the same length"
        )
    pca = PCA(n_components=2)
    coords = pca.fit_transform(embeddings)
    return [
        {
            "x": float(x),
            "y": float(y),
            "cluster_label": str(cluster),
            "text_label": text,
        }
        for (x, y), cluster, text in zip(coords, cluster_labels, text_labels)
    ]
|
59
|
-
|
60
|
-
|
61
|
-
def _plot_data_json(data):
    """Serialize plot records to JSON that is safe to embed in an HTML <script> element."""
    # json.dumps leaves '<', '>' and '&' untouched, so a text label such as
    # "</script>" would end the script element early and inject markup.
    # In JSON output these characters can only occur inside strings, where
    # the \u escapes below are valid and decode to the same characters.
    return (
        json.dumps(data)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )


def _build_plot_html(data_json):
    """Return a standalone HTML page drawing the embedded JSON records as a D3 scatter plot."""
    return f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>t-SNE/PCA Plot with D3.js</title>
    <script src="https://d3js.org/d3.v6.min.js"></script>
    <style>
        .tooltip {{
            position: absolute;
            text-align: center;
            width: auto;
            height: auto;
            padding: 2px;
            font: 12px sans-serif;
            background: lightsteelblue;
            border: 0px;
            border-radius: 8px;
            pointer-events: none;
        }}
        .dot {{
            stroke: #000;
            stroke-width: 0.5;
        }}
    </style>
</head>
<body>
    <svg width="600" height="600"></svg>

    <script>
        // Embedded data
        const data = {data_json};

        const svg = d3.select("svg"),
            width = +svg.attr("width"),
            height = +svg.attr("height"),
            // Margin larger than the point radius so the outermost points
            // are not clipped by the SVG edge.
            margin = 10;

        // Set up scales
        const x = d3.scaleLinear()
            .domain(d3.extent(data, d => d.x))
            .range([margin, width - margin]);

        const y = d3.scaleLinear()
            .domain(d3.extent(data, d => d.y))
            .range([height - margin, margin]);

        // Set up color scale
        const color = d3.scaleOrdinal(d3.schemeCategory10);

        // Create tooltip
        const tooltip = d3.select("body").append("div")
            .attr("class", "tooltip")
            .style("opacity", 0);

        // Create circles for each point
        svg.selectAll("circle")
            .data(data)
            .enter().append("circle")
            .attr("cx", d => x(d.x))
            .attr("cy", d => y(d.y))
            .attr("r", 5)
            .attr("class", "dot")
            .style("fill", d => color(d.cluster_label))
            .on("mouseover", function(event, d) {{
                tooltip.transition()
                    .duration(200)
                    .style("opacity", .9);
                // .text() rather than .html(): labels are shown verbatim and
                // cannot inject markup or scripts into the page.
                tooltip.text(d.text_label)
                    .style("left", (event.pageX + 5) + "px")
                    .style("top", (event.pageY - 28) + "px");
            }})
            .on("mouseout", function(event) {{
                tooltip.transition()
                    .duration(500)
                    .style("opacity", 0);
            }});
    </script>
</body>
</html>
"""


def plot(embeddings, text_labels, n_clusters=5, method="tsne"):
    """
    Perform k-means clustering and plot results in a Jupyter notebook using D3.js.

    The embeddings are clustered with k-means and projected to 2-D with the
    chosen method. The points are drawn as a D3 scatter plot, coloured by
    cluster, with the text label shown on hover. The page is written to
    ``tsne_pca_plot.html`` in the current working directory (overwriting any
    existing file) and displayed through an ``<iframe>``.

    Parameters:
    embeddings (np.ndarray): The embedding vectors.
    text_labels (list): Text labels for each embedding.
    n_clusters (int): The number of clusters to form.
    method (str): The dimensionality reduction method to use ('tsne' or 'pca').

    Raises:
    ValueError: If ``method`` is not 'tsne' or 'pca'.
    """
    # Validate before the k-means fit so a typo fails fast.
    if method not in ("tsne", "pca"):
        raise ValueError("Invalid method. Choose 'tsne' or 'pca'.")

    # Perform k-means clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(embeddings)

    # Compute dimensionality reduction
    reduce = compute_tsne if method == "tsne" else compute_pca
    data = reduce(embeddings, cluster_labels, text_labels)

    html_content = _build_plot_html(_plot_data_json(data))

    # Write the page next to the notebook so the iframe can load it.
    html_file = "tsne_pca_plot.html"
    with open(html_file, "w", encoding="utf-8") as file:
        file.write(html_content)

    # Display the HTML content in an iframe within a Jupyter notebook
    display_html(
        f'<iframe src="{html_file}" width="600" height="600"></iframe>', raw=True
    )
|
177
|
-
|
178
|
-
|
179
|
-
# Example usage
|
180
|
-
if __name__ == "__main__":
    # Demo on synthetic data: 100 random 50-dimensional vectors labelled
    # "Text <i>", grouped into 5 clusters.
    np.random.seed(42)  # fixed seed keeps the demo reproducible
    embedding_vectors = np.random.rand(100, 50)
    text_labels = [f"Text {i}" for i in range(100)]

    # Render the same data with t-SNE first, then with PCA.
    for projection in ("tsne", "pca"):
        plot(embedding_vectors, text_labels, n_clusters=5, method=projection)
|
1
|
+
import json
|
2
|
+
import numpy as np
|
3
|
+
from sklearn.cluster import KMeans
|
4
|
+
from sklearn.manifold import TSNE
|
5
|
+
from sklearn.decomposition import PCA
|
6
|
+
from IPython.display import display_html
|
7
|
+
|
8
|
+
|
9
|
+
def compute_tsne(embeddings, cluster_labels, text_labels, perplexity=30.0):
    """
    Compute t-SNE on embedding vectors.

    Parameters:
    embeddings (np.ndarray): The embedding vectors.
    cluster_labels (np.ndarray): Cluster labels for each embedding.
    text_labels (list): Text labels for each embedding.
    perplexity (float): t-SNE perplexity (30 is scikit-learn's default).
        For small inputs it is lowered to ``n_samples - 1``, because
        scikit-learn requires the perplexity to be less than the number
        of samples.

    Returns:
    list: List of dictionaries with x, y coordinates, cluster labels, and text labels.

    Raises:
    ValueError: If the three inputs do not have the same length.
    """
    n_samples = len(embeddings)
    # Check before the (expensive) fit: a shorter label list would otherwise
    # fail with IndexError part-way through, and a longer one would be
    # silently truncated.
    if len(cluster_labels) != n_samples or len(text_labels) != n_samples:
        raise ValueError(
            "embeddings, cluster_labels and text_labels must have the same length"
        )
    if n_samples > 1:
        # Without this, any input with <= 30 samples is rejected by TSNE.
        perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    coords = tsne.fit_transform(embeddings)
    return [
        {
            "x": float(x),
            "y": float(y),
            "cluster_label": str(cluster),
            "text_label": text,
        }
        for (x, y), cluster, text in zip(coords, cluster_labels, text_labels)
    ]
|
33
|
+
|
34
|
+
|
35
|
+
def compute_pca(embeddings, cluster_labels, text_labels):
    """
    Compute PCA on embedding vectors.

    Parameters:
    embeddings (np.ndarray): The embedding vectors.
    cluster_labels (np.ndarray): Cluster labels for each embedding.
    text_labels (list): Text labels for each embedding.

    Returns:
    list: List of dictionaries with x, y coordinates, cluster labels, and text labels.

    Raises:
    ValueError: If the three inputs do not have the same length.
    """
    n_samples = len(embeddings)
    # Check up front: a shorter label list would otherwise fail with
    # IndexError part-way through, and a longer one would be silently truncated.
    if len(cluster_labels) != n_samples or len(text_labels) != n_samples:
        raise ValueError(
            "embeddings, cluster_labels and text_labels must have the same length"
        )
    pca = PCA(n_components=2)
    coords = pca.fit_transform(embeddings)
    return [
        {
            "x": float(x),
            "y": float(y),
            "cluster_label": str(cluster),
            "text_label": text,
        }
        for (x, y), cluster, text in zip(coords, cluster_labels, text_labels)
    ]
|
59
|
+
|
60
|
+
|
61
|
+
def _plot_data_json(data):
    """Serialize plot records to JSON that is safe to embed in an HTML <script> element."""
    # json.dumps leaves '<', '>' and '&' untouched, so a text label such as
    # "</script>" would end the script element early and inject markup.
    # In JSON output these characters can only occur inside strings, where
    # the \u escapes below are valid and decode to the same characters.
    return (
        json.dumps(data)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )


def _build_plot_html(data_json):
    """Return a standalone HTML page drawing the embedded JSON records as a D3 scatter plot."""
    return f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>t-SNE/PCA Plot with D3.js</title>
    <script src="https://d3js.org/d3.v6.min.js"></script>
    <style>
        .tooltip {{
            position: absolute;
            text-align: center;
            width: auto;
            height: auto;
            padding: 2px;
            font: 12px sans-serif;
            background: lightsteelblue;
            border: 0px;
            border-radius: 8px;
            pointer-events: none;
        }}
        .dot {{
            stroke: #000;
            stroke-width: 0.5;
        }}
    </style>
</head>
<body>
    <svg width="600" height="600"></svg>

    <script>
        // Embedded data
        const data = {data_json};

        const svg = d3.select("svg"),
            width = +svg.attr("width"),
            height = +svg.attr("height"),
            // Margin larger than the point radius so the outermost points
            // are not clipped by the SVG edge.
            margin = 10;

        // Set up scales
        const x = d3.scaleLinear()
            .domain(d3.extent(data, d => d.x))
            .range([margin, width - margin]);

        const y = d3.scaleLinear()
            .domain(d3.extent(data, d => d.y))
            .range([height - margin, margin]);

        // Set up color scale
        const color = d3.scaleOrdinal(d3.schemeCategory10);

        // Create tooltip
        const tooltip = d3.select("body").append("div")
            .attr("class", "tooltip")
            .style("opacity", 0);

        // Create circles for each point
        svg.selectAll("circle")
            .data(data)
            .enter().append("circle")
            .attr("cx", d => x(d.x))
            .attr("cy", d => y(d.y))
            .attr("r", 5)
            .attr("class", "dot")
            .style("fill", d => color(d.cluster_label))
            .on("mouseover", function(event, d) {{
                tooltip.transition()
                    .duration(200)
                    .style("opacity", .9);
                // .text() rather than .html(): labels are shown verbatim and
                // cannot inject markup or scripts into the page.
                tooltip.text(d.text_label)
                    .style("left", (event.pageX + 5) + "px")
                    .style("top", (event.pageY - 28) + "px");
            }})
            .on("mouseout", function(event) {{
                tooltip.transition()
                    .duration(500)
                    .style("opacity", 0);
            }});
    </script>
</body>
</html>
"""


def plot(embeddings, text_labels, n_clusters=5, method="tsne"):
    """
    Perform k-means clustering and plot results in a Jupyter notebook using D3.js.

    The embeddings are clustered with k-means and projected to 2-D with the
    chosen method. The points are drawn as a D3 scatter plot, coloured by
    cluster, with the text label shown on hover. The page is written to
    ``tsne_pca_plot.html`` in the current working directory (overwriting any
    existing file) and displayed through an ``<iframe>``.

    Parameters:
    embeddings (np.ndarray): The embedding vectors.
    text_labels (list): Text labels for each embedding.
    n_clusters (int): The number of clusters to form.
    method (str): The dimensionality reduction method to use ('tsne' or 'pca').

    Raises:
    ValueError: If ``method`` is not 'tsne' or 'pca'.
    """
    # Validate before the k-means fit so a typo fails fast.
    if method not in ("tsne", "pca"):
        raise ValueError("Invalid method. Choose 'tsne' or 'pca'.")

    # Perform k-means clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(embeddings)

    # Compute dimensionality reduction
    reduce = compute_tsne if method == "tsne" else compute_pca
    data = reduce(embeddings, cluster_labels, text_labels)

    html_content = _build_plot_html(_plot_data_json(data))

    # Write the page next to the notebook so the iframe can load it.
    html_file = "tsne_pca_plot.html"
    with open(html_file, "w", encoding="utf-8") as file:
        file.write(html_content)

    # Display the HTML content in an iframe within a Jupyter notebook
    display_html(
        f'<iframe src="{html_file}" width="600" height="600"></iframe>', raw=True
    )
|
177
|
+
|
178
|
+
|
179
|
+
# Example usage
|
180
|
+
if __name__ == "__main__":
    # Demo on synthetic data: 100 random 50-dimensional vectors labelled
    # "Text <i>", grouped into 5 clusters.
    np.random.seed(42)  # fixed seed keeps the demo reproducible
    embedding_vectors = np.random.rand(100, 50)
    text_labels = [f"Text {i}" for i in range(100)]

    # Render the same data with t-SNE first, then with PCA.
    for projection in ("tsne", "pca"):
        plot(embedding_vectors, text_labels, n_clusters=5, method=projection)
|
edsl/tools/embeddings.py
CHANGED
@@ -1,27 +1,27 @@
|
|
1
|
-
import openai
|
2
|
-
from dotenv import load_dotenv
|
3
|
-
import os
|
4
|
-
|
5
|
-
# Pull variables from a local .env file into the process environment, so the
# API key read below can come from that file instead of the shell.
load_dotenv()

# Expose the OpenAI API key on the openai module; this is None when the
# OPENAI_API_KEY variable is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
10
|
-
|
11
|
-
from openai import Client
|
12
|
-
|
13
|
-
|
14
|
-
def get_embeddings(texts, model="text-embedding-ada-002", batch_size=2048):
    """
    Get embeddings for a list of texts using OpenAI API.

    Args:
        texts (list of str): List of strings to get embeddings for. A single
            string is treated as a one-element list.
        model (str): Embedding model name. Defaults to "text-embedding-ada-002".
        batch_size (int): Maximum number of texts sent per request. The OpenAI
            embeddings endpoint accepts at most 2048 inputs per request, so
            longer lists are split across several calls.

    Returns:
        list of list of float: List of embeddings, one per input text, in the
        order the API returns them for each batch. An empty input returns an
        empty list without contacting the API.
    """
    # list("abc") would split a bare string into characters; wrap it instead.
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        # Nothing to embed; the API does not accept an empty input.
        return []

    client = Client()
    embeddings = []
    for start in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            input=texts[start : start + batch_size], model=model
        )
        embeddings.extend(item.embedding for item in response.data)
    return embeddings
|
1
|
+
import openai
|
2
|
+
from dotenv import load_dotenv
|
3
|
+
import os
|
4
|
+
|
5
|
+
# Pull variables from a local .env file into the process environment, so the
# API key read below can come from that file instead of the shell.
load_dotenv()

# Expose the OpenAI API key on the openai module; this is None when the
# OPENAI_API_KEY variable is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
10
|
+
|
11
|
+
from openai import Client
|
12
|
+
|
13
|
+
|
14
|
+
def get_embeddings(texts, model="text-embedding-ada-002", batch_size=2048):
    """
    Get embeddings for a list of texts using OpenAI API.

    Args:
        texts (list of str): List of strings to get embeddings for. A single
            string is treated as a one-element list.
        model (str): Embedding model name. Defaults to "text-embedding-ada-002".
        batch_size (int): Maximum number of texts sent per request. The OpenAI
            embeddings endpoint accepts at most 2048 inputs per request, so
            longer lists are split across several calls.

    Returns:
        list of list of float: List of embeddings, one per input text, in the
        order the API returns them for each batch. An empty input returns an
        empty list without contacting the API.
    """
    # list("abc") would split a bare string into characters; wrap it instead.
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        # Nothing to embed; the API does not accept an empty input.
        return []

    client = Client()
    embeddings = []
    for start in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            input=texts[start : start + batch_size], model=model
        )
        embeddings.extend(item.embedding for item in response.data)
    return embeddings
|