edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +332 -332
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +49 -49
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +867 -867
- edsl/agents/AgentList.py +413 -413
- edsl/agents/Invigilator.py +233 -233
- edsl/agents/InvigilatorBase.py +270 -265
- edsl/agents/PromptConstructor.py +354 -354
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +99 -99
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +279 -279
- edsl/config.py +157 -157
- edsl/conversation/Conversation.py +290 -290
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/chips.py +95 -95
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +1028 -1028
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +555 -555
- edsl/data/CacheEntry.py +233 -233
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +78 -78
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +175 -175
- edsl/exceptions/BaseException.py +21 -21
- edsl/exceptions/__init__.py +54 -54
- edsl/exceptions/agents.py +42 -42
- edsl/exceptions/cache.py +5 -5
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +29 -29
- edsl/exceptions/scenarios.py +22 -22
- edsl/exceptions/surveys.py +37 -37
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +120 -120
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +148 -148
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +97 -97
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/PerplexityService.py +163 -163
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +41 -41
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +898 -898
- edsl/jobs/JobsChecks.py +147 -147
- edsl/jobs/JobsPrompts.py +268 -268
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -239
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +251 -251
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +466 -466
- edsl/jobs/runners/JobsRunnerStatus.py +330 -330
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +450 -450
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/KeyLookup.py +30 -30
- edsl/language_models/LanguageModel.py +668 -668
- edsl/language_models/ModelList.py +155 -155
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +3 -3
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +190 -190
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +258 -258
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +362 -362
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +664 -664
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +217 -217
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +182 -182
- edsl/questions/QuestionFreeText.py +114 -114
- edsl/questions/QuestionFunctional.py +166 -166
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +93 -93
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +413 -413
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +177 -177
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/CSSParameterizer.py +108 -108
- edsl/results/Dataset.py +424 -424
- edsl/results/DatasetExportMixin.py +731 -731
- edsl/results/DatasetTree.py +275 -275
- edsl/results/Result.py +465 -465
- edsl/results/Results.py +1165 -1165
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +135 -135
- edsl/results/TableDisplay.py +198 -198
- edsl/results/__init__.py +2 -2
- edsl/results/table_display.css +77 -77
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +632 -632
- edsl/scenarios/Scenario.py +601 -601
- edsl/scenarios/ScenarioHtmlMixin.py +64 -64
- edsl/scenarios/ScenarioJoin.py +127 -127
- edsl/scenarios/ScenarioList.py +1287 -1287
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +326 -326
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1801 -1801
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +179 -179
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +49 -49
- edsl/surveys/instructions/Instruction.py +65 -65
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +19 -19
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/naming_utilities.py +263 -263
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +424 -424
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev3.dist-info}/LICENSE +21 -21
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev3.dist-info}/METADATA +1 -1
- edsl-0.1.39.dev3.dist-info/RECORD +277 -0
- edsl-0.1.39.dev1.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev3.dist-info}/WHEEL +0 -0
@@ -1,64 +1,64 @@
|
|
1
|
-
import requests
|
2
|
-
from typing import Optional
|
3
|
-
from requests.adapters import HTTPAdapter
|
4
|
-
from requests.packages.urllib3.util.retry import Retry
|
5
|
-
|
6
|
-
|
7
|
-
class ScenarioHtmlMixin:
    """Mixin adding construction of a scenario from a web page.

    The host class must be constructible from a dict (``cls({...})``).
    """

    @classmethod
    def from_html(cls, url: str, field_name: Optional[str] = None) -> "Scenario":
        """Create a scenario from the HTML content found at a URL.

        :param url: The URL of the page to fetch.
        :param field_name: The key under which the extracted text is stored;
            ``"text"`` is used when it is not given (or is empty).
        :returns: An instance of ``cls`` built from a dict with the keys
            ``"url"``, ``"html"`` and ``field_name``.
        :raises ValueError: If the page could not be fetched.
        """
        html = cls.fetch_html(url)
        if html is None:
            # fetch_html reports the underlying error and returns None; fail with a
            # clear message instead of handing None to the HTML parser.
            raise ValueError(f"Could not fetch HTML content from {url!r}")
        text = cls.extract_text(html)
        if not field_name:
            field_name = "text"
        return cls({"url": url, "html": html, field_name: text})

    @staticmethod
    def fetch_html(url: str) -> Optional[str]:
        """Download ``url`` and return the response body as text.

        The request sends a browser-like User-Agent, times out after 10 seconds,
        and is retried via urllib3 ``Retry(total=5, backoff_factor=0.1)``,
        including on HTTP 500/502/503/504 responses.

        :param url: The URL to fetch.
        :returns: The response text, or ``None`` if the request failed (the
            error is printed).
        """
        # Define the user-agent to mimic a browser
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        retries = Retry(
            total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]
        )

        # The session is used as a context manager so its connection pool is
        # always closed, whether the request succeeds or fails.
        with requests.Session() as session:
            session.mount("http://", HTTPAdapter(max_retries=retries))
            session.mount("https://", HTTPAdapter(max_retries=retries))
            try:
                response = session.get(url, headers=headers, timeout=10)
                response.raise_for_status()  # Raise an exception for HTTP errors
                return response.text
            except requests.exceptions.RequestException as e:
                print(f"An error occurred: {e}")
                return None

    @staticmethod
    def extract_text(html: str) -> str:
        """Return the text content of an HTML document, parsed with BeautifulSoup."""
        # Imported at call time, so importing this module does not require bs4.
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(html, "html.parser")
        return soup.get_text()
|
53
|
-
|
54
|
-
|
55
|
-
if __name__ == "__main__":
    # Quick manual check: download a sample page and report the outcome.
    target = "https://example.com"
    page = ScenarioHtmlMixin.fetch_html(target)
    print(
        "Successfully fetched the HTML content."
        if page
        else "Failed to fetch the HTML content."
    )

    print(page)
|
1
|
+
import requests
|
2
|
+
from typing import Optional
|
3
|
+
from requests.adapters import HTTPAdapter
|
4
|
+
from requests.packages.urllib3.util.retry import Retry
|
5
|
+
|
6
|
+
|
7
|
+
class ScenarioHtmlMixin:
    """Mixin adding construction of a scenario from a web page.

    The host class must be constructible from a dict (``cls({...})``).
    """

    @classmethod
    def from_html(cls, url: str, field_name: Optional[str] = None) -> "Scenario":
        """Create a scenario from the HTML content found at a URL.

        :param url: The URL of the page to fetch.
        :param field_name: The key under which the extracted text is stored;
            ``"text"`` is used when it is not given (or is empty).
        :returns: An instance of ``cls`` built from a dict with the keys
            ``"url"``, ``"html"`` and ``field_name``.
        :raises ValueError: If the page could not be fetched.
        """
        html = cls.fetch_html(url)
        if html is None:
            # fetch_html reports the underlying error and returns None; fail with a
            # clear message instead of handing None to the HTML parser.
            raise ValueError(f"Could not fetch HTML content from {url!r}")
        text = cls.extract_text(html)
        if not field_name:
            field_name = "text"
        return cls({"url": url, "html": html, field_name: text})

    @staticmethod
    def fetch_html(url: str) -> Optional[str]:
        """Download ``url`` and return the response body as text.

        The request sends a browser-like User-Agent, times out after 10 seconds,
        and is retried via urllib3 ``Retry(total=5, backoff_factor=0.1)``,
        including on HTTP 500/502/503/504 responses.

        :param url: The URL to fetch.
        :returns: The response text, or ``None`` if the request failed (the
            error is printed).
        """
        # Define the user-agent to mimic a browser
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        retries = Retry(
            total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]
        )

        # The session is used as a context manager so its connection pool is
        # always closed, whether the request succeeds or fails.
        with requests.Session() as session:
            session.mount("http://", HTTPAdapter(max_retries=retries))
            session.mount("https://", HTTPAdapter(max_retries=retries))
            try:
                response = session.get(url, headers=headers, timeout=10)
                response.raise_for_status()  # Raise an exception for HTTP errors
                return response.text
            except requests.exceptions.RequestException as e:
                print(f"An error occurred: {e}")
                return None

    @staticmethod
    def extract_text(html: str) -> str:
        """Return the text content of an HTML document, parsed with BeautifulSoup."""
        # Imported at call time, so importing this module does not require bs4.
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(html, "html.parser")
        return soup.get_text()
|
53
|
+
|
54
|
+
|
55
|
+
if __name__ == "__main__":
    # Quick manual check: download a sample page and report the outcome.
    target = "https://example.com"
    page = ScenarioHtmlMixin.fetch_html(target)
    print(
        "Successfully fetched the HTML content."
        if page
        else "Failed to fetch the HTML content."
    )

    print(page)
|
edsl/scenarios/ScenarioJoin.py
CHANGED
@@ -1,127 +1,127 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
from typing import Union, TYPE_CHECKING
|
3
|
-
|
4
|
-
# if TYPE_CHECKING:
|
5
|
-
from edsl.scenarios.ScenarioList import ScenarioList
|
6
|
-
from edsl.scenarios.Scenario import Scenario
|
7
|
-
|
8
|
-
|
9
|
-
class ScenarioJoin:
    """Handles join operations between two ScenarioLists.

    This class encapsulates all join-related logic, making it easier to maintain
    and extend with other join types (inner, right, full) in the future.
    """

    def __init__(self, left: "ScenarioList", right: "ScenarioList"):
        """Initialize join operation with two ScenarioLists.

        Args:
            left: The left ScenarioList
            right: The right ScenarioList
        """
        self.left = left
        self.right = right

    def left_join(self, by: Union[str, list[str]]) -> ScenarioList:
        """Perform a left join between the two ScenarioLists.

        Every left scenario yields exactly one output scenario. Each output
        scenario carries every key seen in either list; keys it has no value for
        are set to None. Keys are ordered by first appearance: left keys first,
        then keys that only occur on the right.

        Args:
            by: String or list of strings representing the key(s) to join on. Cannot be empty.

        Returns:
            A new ScenarioList containing the joined scenarios

        Raises:
            ValueError: If by is empty, or if any join key is missing from any
                scenario in either ScenarioList (an empty list has no keys).
        """
        self._validate_join_keys(by)
        by_keys = [by] if isinstance(by, str) else by

        other_dict = self._create_lookup_dict(self.right, by_keys)
        all_keys = self._get_all_keys()

        return ScenarioList(
            self._create_joined_scenarios(by_keys, other_dict, all_keys)
        )

    def _validate_join_keys(self, by: Union[str, list[str]]) -> None:
        """Validate that every join key is present in every scenario of both lists."""
        if not by:
            raise ValueError(
                "Join keys cannot be empty. Please specify at least one key to join on."
            )

        by_keys = [by] if isinstance(by, str) else by
        # Check every scenario, not only the first one: a later scenario lacking a
        # join key would otherwise surface as a bare KeyError in _get_key_tuple.
        left_keys = self._common_keys(self.left)
        right_keys = self._common_keys(self.right)

        missing_left = set(by_keys) - left_keys
        missing_right = set(by_keys) - right_keys
        if missing_left or missing_right:
            missing = missing_left | missing_right
            raise ValueError(f"Join key(s) {missing} not found in both ScenarioLists")

    @staticmethod
    def _common_keys(scenarios: ScenarioList) -> set:
        """Return the keys shared by all scenarios (an empty set for an empty list)."""
        key_sets = [set(scenario.keys()) for scenario in scenarios]
        return set.intersection(*key_sets) if key_sets else set()

    @staticmethod
    def _get_key_tuple(scenario: Scenario, keys: list[str]) -> tuple:
        """Create a tuple of values for the join keys."""
        return tuple(scenario[k] for k in keys)

    def _create_lookup_dict(self, scenarios: ScenarioList, by_keys: list[str]) -> dict:
        """Create a lookup dictionary for the right scenarios.

        NOTE: if several right scenarios share the same join-key values, the last
        one wins (later entries overwrite earlier ones in the dict).
        """
        return {
            self._get_key_tuple(scenario, by_keys): scenario for scenario in scenarios
        }

    def _get_all_keys(self) -> list:
        """Get all unique keys from both ScenarioLists, in first-seen order.

        A dict serves as an ordered set. Iterating a plain set of strings gives an
        order that can change between interpreter runs (str hashing is
        randomized), which made the key order of joined scenarios nondeterministic.
        """
        ordered: dict = {}
        for scenario in self.left:
            ordered.update(dict.fromkeys(scenario.keys()))
        for scenario in self.right:
            ordered.update(dict.fromkeys(scenario.keys()))
        return list(ordered)

    def _create_joined_scenarios(
        self, by_keys: list[str], other_dict: dict, all_keys: list
    ) -> list[Scenario]:
        """Create the joined scenarios, one per left scenario."""
        new_scenarios = []

        for scenario in self.left:
            # Start with every known key set to None, then overlay the left values;
            # the dict keeps the ordering of all_keys.
            new_scenario = dict.fromkeys(all_keys)
            new_scenario.update(scenario)

            key_tuple = self._get_key_tuple(scenario, by_keys)
            matching_scenario = other_dict.get(key_tuple)
            if matching_scenario is not None:
                self._handle_matching_scenario(
                    new_scenario, scenario, matching_scenario, by_keys
                )

            new_scenarios.append(Scenario(new_scenario))

        return new_scenarios

    def _handle_matching_scenario(
        self,
        new_scenario: dict,
        left_scenario: Scenario,
        right_scenario: Scenario,
        by_keys: list[str],
    ) -> None:
        """Merge a matching right scenario into new_scenario (in place).

        On conflicting non-join keys the left value is kept and a warning is
        printed; keys only present on the right are copied over.
        """
        left_key_set = set(left_scenario.keys())
        right_key_set = set(right_scenario.keys())

        for key in left_key_set & right_key_set:
            if key not in by_keys and left_scenario[key] != right_scenario[key]:
                join_conditions = [f"{k}='{left_scenario[k]}'" for k in by_keys]
                print(
                    f"Warning: Conflicting values for key '{key}' where "
                    f"{' AND '.join(join_conditions)}. "
                    f"Keeping left value: {left_scenario[key]} "
                    f"(discarding: {right_scenario[key]})"
                )

        # Only update with non-overlapping keys from matching scenario
        new_keys = right_key_set - left_key_set
        new_scenario.update({k: right_scenario[k] for k in new_keys})
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Union, TYPE_CHECKING
|
3
|
+
|
4
|
+
# if TYPE_CHECKING:
|
5
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
6
|
+
from edsl.scenarios.Scenario import Scenario
|
7
|
+
|
8
|
+
|
9
|
+
class ScenarioJoin:
    """Handles join operations between two ScenarioLists.

    This class encapsulates all join-related logic, making it easier to maintain
    and extend with other join types (inner, right, full) in the future.
    """

    def __init__(self, left: "ScenarioList", right: "ScenarioList"):
        """Initialize join operation with two ScenarioLists.

        Args:
            left: The left ScenarioList
            right: The right ScenarioList
        """
        self.left = left
        self.right = right

    def left_join(self, by: Union[str, list[str]]) -> ScenarioList:
        """Perform a left join between the two ScenarioLists.

        Every left scenario yields exactly one output scenario. Each output
        scenario carries every key seen in either list; keys it has no value for
        are set to None. Keys are ordered by first appearance: left keys first,
        then keys that only occur on the right.

        Args:
            by: String or list of strings representing the key(s) to join on. Cannot be empty.

        Returns:
            A new ScenarioList containing the joined scenarios

        Raises:
            ValueError: If by is empty, or if any join key is missing from any
                scenario in either ScenarioList (an empty list has no keys).
        """
        self._validate_join_keys(by)
        by_keys = [by] if isinstance(by, str) else by

        other_dict = self._create_lookup_dict(self.right, by_keys)
        all_keys = self._get_all_keys()

        return ScenarioList(
            self._create_joined_scenarios(by_keys, other_dict, all_keys)
        )

    def _validate_join_keys(self, by: Union[str, list[str]]) -> None:
        """Validate that every join key is present in every scenario of both lists."""
        if not by:
            raise ValueError(
                "Join keys cannot be empty. Please specify at least one key to join on."
            )

        by_keys = [by] if isinstance(by, str) else by
        # Check every scenario, not only the first one: a later scenario lacking a
        # join key would otherwise surface as a bare KeyError in _get_key_tuple.
        left_keys = self._common_keys(self.left)
        right_keys = self._common_keys(self.right)

        missing_left = set(by_keys) - left_keys
        missing_right = set(by_keys) - right_keys
        if missing_left or missing_right:
            missing = missing_left | missing_right
            raise ValueError(f"Join key(s) {missing} not found in both ScenarioLists")

    @staticmethod
    def _common_keys(scenarios: ScenarioList) -> set:
        """Return the keys shared by all scenarios (an empty set for an empty list)."""
        key_sets = [set(scenario.keys()) for scenario in scenarios]
        return set.intersection(*key_sets) if key_sets else set()

    @staticmethod
    def _get_key_tuple(scenario: Scenario, keys: list[str]) -> tuple:
        """Create a tuple of values for the join keys."""
        return tuple(scenario[k] for k in keys)

    def _create_lookup_dict(self, scenarios: ScenarioList, by_keys: list[str]) -> dict:
        """Create a lookup dictionary for the right scenarios.

        NOTE: if several right scenarios share the same join-key values, the last
        one wins (later entries overwrite earlier ones in the dict).
        """
        return {
            self._get_key_tuple(scenario, by_keys): scenario for scenario in scenarios
        }

    def _get_all_keys(self) -> list:
        """Get all unique keys from both ScenarioLists, in first-seen order.

        A dict serves as an ordered set. Iterating a plain set of strings gives an
        order that can change between interpreter runs (str hashing is
        randomized), which made the key order of joined scenarios nondeterministic.
        """
        ordered: dict = {}
        for scenario in self.left:
            ordered.update(dict.fromkeys(scenario.keys()))
        for scenario in self.right:
            ordered.update(dict.fromkeys(scenario.keys()))
        return list(ordered)

    def _create_joined_scenarios(
        self, by_keys: list[str], other_dict: dict, all_keys: list
    ) -> list[Scenario]:
        """Create the joined scenarios, one per left scenario."""
        new_scenarios = []

        for scenario in self.left:
            # Start with every known key set to None, then overlay the left values;
            # the dict keeps the ordering of all_keys.
            new_scenario = dict.fromkeys(all_keys)
            new_scenario.update(scenario)

            key_tuple = self._get_key_tuple(scenario, by_keys)
            matching_scenario = other_dict.get(key_tuple)
            if matching_scenario is not None:
                self._handle_matching_scenario(
                    new_scenario, scenario, matching_scenario, by_keys
                )

            new_scenarios.append(Scenario(new_scenario))

        return new_scenarios

    def _handle_matching_scenario(
        self,
        new_scenario: dict,
        left_scenario: Scenario,
        right_scenario: Scenario,
        by_keys: list[str],
    ) -> None:
        """Merge a matching right scenario into new_scenario (in place).

        On conflicting non-join keys the left value is kept and a warning is
        printed; keys only present on the right are copied over.
        """
        left_key_set = set(left_scenario.keys())
        right_key_set = set(right_scenario.keys())

        for key in left_key_set & right_key_set:
            if key not in by_keys and left_scenario[key] != right_scenario[key]:
                join_conditions = [f"{k}='{left_scenario[k]}'" for k in by_keys]
                print(
                    f"Warning: Conflicting values for key '{key}' where "
                    f"{' AND '.join(join_conditions)}. "
                    f"Keeping left value: {left_scenario[key]} "
                    f"(discarding: {right_scenario[key]})"
                )

        # Only update with non-overlapping keys from matching scenario
        new_keys = right_key_set - left_key_set
        new_scenario.update({k: right_scenario[k] for k in new_keys})
|