edsl 0.1.58__py3-none-any.whl → 0.1.59__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +23 -4
- edsl/agents/agent_list.py +36 -6
- edsl/coop/coop.py +103 -1
- edsl/dataset/dataset.py +74 -0
- edsl/dataset/dataset_operations_mixin.py +67 -62
- edsl/inference_services/services/test_service.py +1 -1
- edsl/interviews/exception_tracking.py +66 -20
- edsl/invigilators/invigilators.py +5 -1
- edsl/invigilators/prompt_constructor.py +299 -136
- edsl/jobs/html_table_job_logger.py +18 -1
- edsl/jobs/jobs_pricing_estimation.py +6 -2
- edsl/jobs/jobs_remote_inference_logger.py +2 -0
- edsl/jobs/remote_inference.py +34 -7
- edsl/language_models/language_model.py +39 -2
- edsl/prompts/prompt.py +1 -0
- edsl/questions/question_list.py +76 -20
- edsl/results/results.py +8 -1
- edsl/scenarios/file_store.py +8 -12
- edsl/scenarios/scenario.py +50 -2
- edsl/scenarios/scenario_list.py +34 -12
- edsl/surveys/survey.py +4 -0
- edsl/tasks/task_history.py +180 -6
- edsl/utilities/wikipedia.py +194 -0
- {edsl-0.1.58.dist-info → edsl-0.1.59.dist-info}/METADATA +4 -3
- {edsl-0.1.58.dist-info → edsl-0.1.59.dist-info}/RECORD +29 -28
- {edsl-0.1.58.dist-info → edsl-0.1.59.dist-info}/LICENSE +0 -0
- {edsl-0.1.58.dist-info → edsl-0.1.59.dist-info}/WHEEL +0 -0
- {edsl-0.1.58.dist-info → edsl-0.1.59.dist-info}/entry_points.txt +0 -0
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.58"
|
1
|
+
__version__ = "0.1.59"
|
edsl/agents/agent.py
CHANGED
@@ -426,6 +426,25 @@ class Agent(Base):
|
|
426
426
|
self.traits_presentation_template = "Your traits: {{traits}}"
|
427
427
|
self.set_traits_presentation_template = False
|
428
428
|
|
429
|
+
|
430
|
+
def drop(self, field_name: str) -> Agent:
|
431
|
+
"""Drop a field from the agent.
|
432
|
+
|
433
|
+
Args:
|
434
|
+
field_name: The name of the field to drop.
|
435
|
+
"""
|
436
|
+
d = self.to_dict()
|
437
|
+
if field_name in d['traits']:
|
438
|
+
d['traits'].pop(field_name)
|
439
|
+
elif field_name in d:
|
440
|
+
d.pop(field_name)
|
441
|
+
else:
|
442
|
+
raise AgentErrors((f"Field '{field_name}' not found in agent"
|
443
|
+
f"Available fields: {d.keys()}"
|
444
|
+
f"Available traits: {d['traits'].keys()}"
|
445
|
+
))
|
446
|
+
return Agent.from_dict(d)
|
447
|
+
|
429
448
|
def duplicate(self) -> Agent:
|
430
449
|
"""Create a deep copy of this agent with all its traits and capabilities.
|
431
450
|
|
@@ -1213,7 +1232,7 @@ class Agent(Base):
|
|
1213
1232
|
"""
|
1214
1233
|
return dict_hash(self.to_dict(add_edsl_version=False))
|
1215
1234
|
|
1216
|
-
def to_dict(self, add_edsl_version=True) -> dict[str, Union[dict, bool]]:
|
1235
|
+
def to_dict(self, add_edsl_version=True, full_dict=False) -> dict[str, Union[dict, bool]]:
|
1217
1236
|
"""Serialize to a dictionary with EDSL info.
|
1218
1237
|
|
1219
1238
|
Example usage:
|
@@ -1230,11 +1249,11 @@ class Agent(Base):
|
|
1230
1249
|
d["traits"] = copy.deepcopy(dict(self._traits))
|
1231
1250
|
if self.name:
|
1232
1251
|
d["name"] = self.name
|
1233
|
-
if self.set_instructions:
|
1252
|
+
if self.set_instructions or full_dict:
|
1234
1253
|
d["instruction"] = self.instruction
|
1235
|
-
if self.set_traits_presentation_template:
|
1254
|
+
if self.set_traits_presentation_template or full_dict:
|
1236
1255
|
d["traits_presentation_template"] = self.traits_presentation_template
|
1237
|
-
if self.codebook:
|
1256
|
+
if self.codebook or full_dict:
|
1238
1257
|
d["codebook"] = self.codebook
|
1239
1258
|
if add_edsl_version:
|
1240
1259
|
from edsl import __version__
|
edsl/agents/agent_list.py
CHANGED
@@ -47,13 +47,13 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
|
|
47
47
|
with methods for filtering, transforming, and analyzing collections of agents.
|
48
48
|
|
49
49
|
|
50
|
-
>>> AgentList.example().to_scenario_list()
|
51
|
-
ScenarioList([Scenario({'age': 22, 'hair': 'brown', 'height': 5.5}), Scenario({'age': 22, 'hair': 'brown', 'height': 5.5})])
|
52
|
-
|
50
|
+
>>> AgentList.example().to_scenario_list().drop('age')
|
51
|
+
ScenarioList([Scenario({'hair': 'brown', 'height': 5.5}), Scenario({'hair': 'brown', 'height': 5.5})])
|
52
|
+
|
53
53
|
>>> AgentList.example().to_dataset()
|
54
54
|
Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}])
|
55
55
|
|
56
|
-
>>> AgentList.example().to_pandas()
|
56
|
+
>>> AgentList.example().select('age', 'hair', 'height').to_pandas()
|
57
57
|
age hair height
|
58
58
|
0 22 brown 5.5
|
59
59
|
1 22 brown 5.5
|
@@ -91,6 +91,28 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
|
|
91
91
|
if codebook is not None:
|
92
92
|
self.set_codebook(codebook)
|
93
93
|
|
94
|
+
def set_instruction(self, instruction: str) -> None:
|
95
|
+
"""Set the instruction for all agents in the list.
|
96
|
+
|
97
|
+
Args:
|
98
|
+
instruction: The instruction to set.
|
99
|
+
"""
|
100
|
+
for agent in self.data:
|
101
|
+
agent.instruction = instruction
|
102
|
+
|
103
|
+
return None
|
104
|
+
|
105
|
+
def set_traits_presentation_template(self, traits_presentation_template: str) -> None:
|
106
|
+
"""Set the traits presentation template for all agents in the list.
|
107
|
+
|
108
|
+
Args:
|
109
|
+
traits_presentation_template: The traits presentation template to set.
|
110
|
+
"""
|
111
|
+
for agent in self.data:
|
112
|
+
agent.traits_presentation_template = traits_presentation_template
|
113
|
+
|
114
|
+
return None
|
115
|
+
|
94
116
|
def shuffle(self, seed: Optional[str] = None) -> AgentList:
|
95
117
|
"""Randomly shuffle the agents in place.
|
96
118
|
|
@@ -119,6 +141,14 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
|
|
119
141
|
if seed:
|
120
142
|
random.seed(seed)
|
121
143
|
return AgentList(random.sample(self.data, n))
|
144
|
+
|
145
|
+
def drop(self, field_name: str) -> AgentList:
|
146
|
+
"""Drop a field from the AgentList.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
field_name: The name of the field to drop.
|
150
|
+
"""
|
151
|
+
return AgentList([a.drop(field_name) for a in self.data])
|
122
152
|
|
123
153
|
def duplicate(self) -> AgentList:
|
124
154
|
"""Create a deep copy of the AgentList.
|
@@ -478,7 +508,7 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
|
|
478
508
|
>>> al.to_dataset()
|
479
509
|
Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}])
|
480
510
|
>>> al.to_dataset(traits_only=False) # doctest: +NORMALIZE_WHITESPACE
|
481
|
-
Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}, {'agent_parameters': [{'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}, {'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}]}])
|
511
|
+
Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}, {'agent_parameters': [{'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None, 'traits_presentation_template': 'Your traits: {{traits}}'}, {'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None, 'traits_presentation_template': 'Your traits: {{traits}}'}]}])
|
482
512
|
"""
|
483
513
|
from ..dataset import Dataset
|
484
514
|
|
@@ -495,7 +525,7 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
|
|
495
525
|
data[trait_key].append(agent.traits.get(trait_key, None))
|
496
526
|
if not traits_only:
|
497
527
|
data["agent_parameters"].append(
|
498
|
-
{"instruction": agent.instruction, "agent_name": agent.name}
|
528
|
+
{"instruction": agent.instruction, "agent_name": agent.name, "traits_presentation_template": agent.traits_presentation_template}
|
499
529
|
)
|
500
530
|
return Dataset([{key: entry} for key, entry in data.items()])
|
501
531
|
|
edsl/coop/coop.py
CHANGED
@@ -2,6 +2,7 @@ import aiohttp
|
|
2
2
|
import base64
|
3
3
|
import json
|
4
4
|
import requests
|
5
|
+
import time
|
5
6
|
|
6
7
|
from typing import Any, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
|
7
8
|
from uuid import UUID
|
@@ -13,7 +14,9 @@ from ..caching import CacheEntry
|
|
13
14
|
|
14
15
|
if TYPE_CHECKING:
|
15
16
|
from ..jobs import Jobs
|
17
|
+
from ..scenarios import ScenarioList
|
16
18
|
from ..surveys import Survey
|
19
|
+
from ..results import Results
|
17
20
|
|
18
21
|
from .exceptions import (
|
19
22
|
CoopInvalidURLError,
|
@@ -567,6 +570,7 @@ class Coop(CoopFunctionsMixin):
|
|
567
570
|
json.dumps(
|
568
571
|
object_dict,
|
569
572
|
default=self._json_handle_none,
|
573
|
+
allow_nan=False,
|
570
574
|
)
|
571
575
|
if object_type != "scenario"
|
572
576
|
else ""
|
@@ -585,6 +589,7 @@ class Coop(CoopFunctionsMixin):
|
|
585
589
|
json_data = json.dumps(
|
586
590
|
object_dict,
|
587
591
|
default=self._json_handle_none,
|
592
|
+
allow_nan=False,
|
588
593
|
)
|
589
594
|
headers = {"Content-Type": "application/json"}
|
590
595
|
if response_json.get("upload_signed_url"):
|
@@ -928,6 +933,7 @@ class Coop(CoopFunctionsMixin):
|
|
928
933
|
json.dumps(
|
929
934
|
value.to_dict(),
|
930
935
|
default=self._json_handle_none,
|
936
|
+
allow_nan=False,
|
931
937
|
)
|
932
938
|
if value
|
933
939
|
else None
|
@@ -1385,12 +1391,108 @@ class Coop(CoopFunctionsMixin):
|
|
1385
1391
|
self._resolve_server_response(response)
|
1386
1392
|
response_json = response.json()
|
1387
1393
|
return {
|
1388
|
-
"
|
1394
|
+
"project_name": response_json.get("project_name"),
|
1389
1395
|
"uuid": response_json.get("uuid"),
|
1390
1396
|
"admin_url": f"{self.url}/home/projects/{response_json.get('uuid')}",
|
1391
1397
|
"respondent_url": f"{self.url}/respond/{response_json.get('uuid')}",
|
1392
1398
|
}
|
1393
1399
|
|
1400
|
+
def get_project(
|
1401
|
+
self,
|
1402
|
+
project_uuid: str,
|
1403
|
+
) -> dict:
|
1404
|
+
"""
|
1405
|
+
Get a project from Coop.
|
1406
|
+
"""
|
1407
|
+
response = self._send_server_request(
|
1408
|
+
uri=f"api/v0/projects/{project_uuid}",
|
1409
|
+
method="GET",
|
1410
|
+
)
|
1411
|
+
self._resolve_server_response(response)
|
1412
|
+
response_json = response.json()
|
1413
|
+
return {
|
1414
|
+
"project_name": response_json.get("project_name"),
|
1415
|
+
"project_job_uuids": response_json.get("job_uuids"),
|
1416
|
+
}
|
1417
|
+
|
1418
|
+
def get_project_human_responses(
|
1419
|
+
self,
|
1420
|
+
project_uuid: str,
|
1421
|
+
) -> Union["Results", "ScenarioList"]:
|
1422
|
+
"""
|
1423
|
+
Return a Results object with the human responses for a project.
|
1424
|
+
|
1425
|
+
If generating the Results object fails, a ScenarioList will be returned instead.
|
1426
|
+
"""
|
1427
|
+
from ..agents import Agent, AgentList
|
1428
|
+
from ..caching import Cache
|
1429
|
+
from ..language_models import Model
|
1430
|
+
from ..scenarios import Scenario, ScenarioList
|
1431
|
+
from ..surveys import Survey
|
1432
|
+
|
1433
|
+
response = self._send_server_request(
|
1434
|
+
uri=f"api/v0/projects/{project_uuid}/human-responses",
|
1435
|
+
method="GET",
|
1436
|
+
)
|
1437
|
+
self._resolve_server_response(response)
|
1438
|
+
response_json = response.json()
|
1439
|
+
human_responses = response_json.get("human_responses", [])
|
1440
|
+
|
1441
|
+
try:
|
1442
|
+
agent_list = AgentList()
|
1443
|
+
|
1444
|
+
for response in human_responses:
|
1445
|
+
response_uuid = response.get("response_uuid")
|
1446
|
+
if response_uuid is None:
|
1447
|
+
raise RuntimeError(
|
1448
|
+
"One of your responses is missing a unique identifier."
|
1449
|
+
)
|
1450
|
+
|
1451
|
+
response_dict = json.loads(response.get("response_json_string"))
|
1452
|
+
|
1453
|
+
a = Agent(name=response_uuid, instruction="")
|
1454
|
+
|
1455
|
+
def create_answer_function(response_data):
|
1456
|
+
def f(self, question, scenario):
|
1457
|
+
return response_data.get(question.question_name, None)
|
1458
|
+
|
1459
|
+
return f
|
1460
|
+
|
1461
|
+
a.add_direct_question_answering_method(
|
1462
|
+
create_answer_function(response_dict)
|
1463
|
+
)
|
1464
|
+
agent_list.append(a)
|
1465
|
+
|
1466
|
+
survey_json_string = response_json.get("survey_json_string")
|
1467
|
+
survey = Survey.from_dict(json.loads(survey_json_string))
|
1468
|
+
|
1469
|
+
model = Model("test")
|
1470
|
+
results = (
|
1471
|
+
survey.by(agent_list)
|
1472
|
+
.by(model)
|
1473
|
+
.run(
|
1474
|
+
cache=Cache(),
|
1475
|
+
disable_remote_cache=True,
|
1476
|
+
disable_remote_inference=True,
|
1477
|
+
print_exceptions=False,
|
1478
|
+
)
|
1479
|
+
)
|
1480
|
+
return results
|
1481
|
+
except Exception:
|
1482
|
+
human_response_scenarios = []
|
1483
|
+
for response in human_responses:
|
1484
|
+
response_uuid = response.get("response_uuid")
|
1485
|
+
if response_uuid is None:
|
1486
|
+
raise RuntimeError(
|
1487
|
+
"One of your responses is missing a unique identifier."
|
1488
|
+
)
|
1489
|
+
|
1490
|
+
response_dict = json.loads(response.get("response_json_string"))
|
1491
|
+
response_dict["agent_name"] = response_uuid
|
1492
|
+
scenario = Scenario(response_dict)
|
1493
|
+
human_response_scenarios.append(scenario)
|
1494
|
+
return ScenarioList(human_response_scenarios)
|
1495
|
+
|
1394
1496
|
def __repr__(self):
|
1395
1497
|
"""Return a string representation of the client."""
|
1396
1498
|
return f"Client(api_key='{self.api_key}', url='{self.url}')"
|
edsl/dataset/dataset.py
CHANGED
@@ -93,6 +93,38 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
93
93
|
"""
|
94
94
|
_, values = list(self.data[0].items())[0]
|
95
95
|
return len(values)
|
96
|
+
|
97
|
+
def drop(self, field_name):
|
98
|
+
"""
|
99
|
+
Returns a new Dataset with the specified field removed.
|
100
|
+
|
101
|
+
Args:
|
102
|
+
field_name (str): The name of the field to remove.
|
103
|
+
|
104
|
+
Returns:
|
105
|
+
Dataset: A new Dataset instance without the specified field.
|
106
|
+
|
107
|
+
Raises:
|
108
|
+
KeyError: If the field_name doesn't exist in the dataset.
|
109
|
+
|
110
|
+
Examples:
|
111
|
+
>>> from .dataset import Dataset
|
112
|
+
>>> d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
|
113
|
+
>>> d.drop('a')
|
114
|
+
Dataset([{'b': [4, 5, 6]}])
|
115
|
+
|
116
|
+
>>> # Testing drop with nonexistent field raises DatasetKeyError - tested in unit tests
|
117
|
+
"""
|
118
|
+
from .dataset import Dataset
|
119
|
+
|
120
|
+
# Check if field exists in the dataset
|
121
|
+
if field_name not in self.relevant_columns():
|
122
|
+
raise DatasetKeyError(f"Field '{field_name}' not found in dataset")
|
123
|
+
|
124
|
+
# Create a new dataset without the specified field
|
125
|
+
new_data = [entry for entry in self.data if field_name not in entry]
|
126
|
+
return Dataset(new_data)
|
127
|
+
|
96
128
|
|
97
129
|
def tail(self, n: int = 5) -> Dataset:
|
98
130
|
"""Return the last n observations in the dataset.
|
@@ -1054,6 +1086,48 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
1054
1086
|
|
1055
1087
|
return Dataset(new_data)
|
1056
1088
|
|
1089
|
+
def unique(self) -> "Dataset":
|
1090
|
+
"""Return a new dataset with only unique observations.
|
1091
|
+
|
1092
|
+
Examples:
|
1093
|
+
>>> d = Dataset([{'a': [1, 2, 2, 3]}, {'b': [4, 5, 5, 6]}])
|
1094
|
+
>>> d.unique().data
|
1095
|
+
[{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]
|
1096
|
+
|
1097
|
+
>>> d = Dataset([{'x': ['a', 'a', 'b']}, {'y': [1, 1, 2]}])
|
1098
|
+
>>> d.unique().data
|
1099
|
+
[{'x': ['a', 'b']}, {'y': [1, 2]}]
|
1100
|
+
"""
|
1101
|
+
# Get all column names and values
|
1102
|
+
headers, data = self._tabular()
|
1103
|
+
|
1104
|
+
# Create a list of unique rows
|
1105
|
+
unique_rows = []
|
1106
|
+
seen = set()
|
1107
|
+
|
1108
|
+
for row in data:
|
1109
|
+
# Convert the row to a hashable representation for comparison
|
1110
|
+
# We need to handle potential unhashable types
|
1111
|
+
try:
|
1112
|
+
row_key = tuple(map(lambda x: str(x) if isinstance(x, (list, dict)) else x, row))
|
1113
|
+
if row_key not in seen:
|
1114
|
+
seen.add(row_key)
|
1115
|
+
unique_rows.append(row)
|
1116
|
+
except:
|
1117
|
+
# Fallback for complex objects: compare based on string representation
|
1118
|
+
row_str = str(row)
|
1119
|
+
if row_str not in seen:
|
1120
|
+
seen.add(row_str)
|
1121
|
+
unique_rows.append(row)
|
1122
|
+
|
1123
|
+
# Create a new dataset with unique combinations
|
1124
|
+
new_data = []
|
1125
|
+
for i, header in enumerate(headers):
|
1126
|
+
values = [row[i] for row in unique_rows]
|
1127
|
+
new_data.append({header: values})
|
1128
|
+
|
1129
|
+
return Dataset(new_data)
|
1130
|
+
|
1057
1131
|
|
1058
1132
|
if __name__ == "__main__":
|
1059
1133
|
import doctest
|
@@ -1070,7 +1070,6 @@ class DataOperationsBase:
|
|
1070
1070
|
- All dictionaries in the field must have compatible structures
|
1071
1071
|
- If a dictionary is missing a key, the corresponding value will be None
|
1072
1072
|
- Non-dictionary values in the field will cause a warning
|
1073
|
-
|
1074
1073
|
Examples:
|
1075
1074
|
>>> from edsl.dataset import Dataset
|
1076
1075
|
|
@@ -1086,48 +1085,85 @@ class DataOperationsBase:
|
|
1086
1085
|
>>> d = Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5]}])
|
1087
1086
|
>>> d.flatten('a', keep_original=True)
|
1088
1087
|
Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5]}, {'a.a': [1]}, {'a.b': [2]}])
|
1088
|
+
|
1089
|
+
# Can also use unambiguous unprefixed field name
|
1090
|
+
>>> result = Dataset([{'answer.pros_cons': [{'pros': ['Safety'], 'cons': ['Cost']}]}]).flatten('pros_cons')
|
1091
|
+
>>> sorted(result.keys()) == ['answer.pros_cons.cons', 'answer.pros_cons.pros']
|
1092
|
+
True
|
1093
|
+
>>> sorted(result.to_dicts()[0].items()) == sorted({'cons': ['Cost'], 'pros': ['Safety']}.items())
|
1094
|
+
True
|
1089
1095
|
"""
|
1090
1096
|
from ..dataset import Dataset
|
1091
1097
|
|
1092
1098
|
# Ensure the dataset isn't empty
|
1093
1099
|
if not self.data:
|
1094
1100
|
return self.copy()
|
1095
|
-
|
1096
|
-
#
|
1097
|
-
matching_entries = []
|
1098
|
-
for entry in self.data:
|
1099
|
-
col_name = next(iter(entry.keys()))
|
1100
|
-
if field == col_name or (
|
1101
|
-
"." in col_name
|
1102
|
-
and (col_name.endswith("." + field) or col_name.startswith(field + "."))
|
1103
|
-
):
|
1104
|
-
matching_entries.append(entry)
|
1105
|
-
|
1106
|
-
# Check if the field is ambiguous
|
1107
|
-
if len(matching_entries) > 1:
|
1108
|
-
matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
|
1109
|
-
from .exceptions import DatasetValueError
|
1110
|
-
|
1111
|
-
raise DatasetValueError(
|
1112
|
-
f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
|
1113
|
-
f"Please specify the full column name to flatten."
|
1114
|
-
)
|
1115
|
-
|
1116
|
-
# Get the number of observations
|
1117
|
-
num_observations = self.num_observations()
|
1118
|
-
|
1119
|
-
# Find the column to flatten
|
1101
|
+
|
1102
|
+
# First try direct match with the exact field name
|
1120
1103
|
field_entry = None
|
1121
1104
|
for entry in self.data:
|
1122
|
-
|
1105
|
+
col_name = next(iter(entry.keys()))
|
1106
|
+
if field == col_name:
|
1123
1107
|
field_entry = entry
|
1124
1108
|
break
|
1109
|
+
|
1110
|
+
# If not found, try to match by unprefixed name
|
1111
|
+
if field_entry is None:
|
1112
|
+
# Find any columns that have field as their unprefixed name
|
1113
|
+
candidates = []
|
1114
|
+
for entry in self.data:
|
1115
|
+
col_name = next(iter(entry.keys()))
|
1116
|
+
if '.' in col_name:
|
1117
|
+
prefix, col_field = col_name.split('.', 1)
|
1118
|
+
if col_field == field:
|
1119
|
+
candidates.append(entry)
|
1120
|
+
|
1121
|
+
# If we found exactly one match by unprefixed name, use it
|
1122
|
+
if len(candidates) == 1:
|
1123
|
+
field_entry = candidates[0]
|
1124
|
+
# If we found multiple matches, it's ambiguous
|
1125
|
+
elif len(candidates) > 1:
|
1126
|
+
matching_cols = [next(iter(entry.keys())) for entry in candidates]
|
1127
|
+
from .exceptions import DatasetValueError
|
1128
|
+
raise DatasetValueError(
|
1129
|
+
f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
|
1130
|
+
f"Please specify the full column name to flatten."
|
1131
|
+
)
|
1132
|
+
# If no candidates by unprefixed name, check partial matches
|
1133
|
+
else:
|
1134
|
+
partial_matches = []
|
1135
|
+
for entry in self.data:
|
1136
|
+
col_name = next(iter(entry.keys()))
|
1137
|
+
if '.' in col_name and (
|
1138
|
+
col_name.endswith('.' + field) or
|
1139
|
+
col_name.startswith(field + '.')
|
1140
|
+
):
|
1141
|
+
partial_matches.append(entry)
|
1142
|
+
|
1143
|
+
# If we found exactly one partial match, use it
|
1144
|
+
if len(partial_matches) == 1:
|
1145
|
+
field_entry = partial_matches[0]
|
1146
|
+
# If we found multiple partial matches, it's ambiguous
|
1147
|
+
elif len(partial_matches) > 1:
|
1148
|
+
matching_cols = [next(iter(entry.keys())) for entry in partial_matches]
|
1149
|
+
from .exceptions import DatasetValueError
|
1150
|
+
raise DatasetValueError(
|
1151
|
+
f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
|
1152
|
+
f"Please specify the full column name to flatten."
|
1153
|
+
)
|
1154
|
+
|
1155
|
+
# Get the number of observations
|
1156
|
+
num_observations = self.num_observations()
|
1125
1157
|
|
1158
|
+
# If we still haven't found the field, it's not in the dataset
|
1126
1159
|
if field_entry is None:
|
1127
1160
|
warnings.warn(
|
1128
1161
|
f"Field '{field}' not found in dataset, returning original dataset"
|
1129
1162
|
)
|
1130
1163
|
return self.copy()
|
1164
|
+
|
1165
|
+
# Get the actual field name as it appears in the data
|
1166
|
+
actual_field = next(iter(field_entry.keys()))
|
1131
1167
|
|
1132
1168
|
# Create new dictionary for flattened data
|
1133
1169
|
flattened_data = []
|
@@ -1135,14 +1171,14 @@ class DataOperationsBase:
|
|
1135
1171
|
# Copy all existing columns except the one we're flattening (if keep_original is False)
|
1136
1172
|
for entry in self.data:
|
1137
1173
|
col_name = next(iter(entry.keys()))
|
1138
|
-
if col_name != field or keep_original:
|
1174
|
+
if col_name != actual_field or keep_original:
|
1139
1175
|
flattened_data.append(entry.copy())
|
1140
1176
|
|
1141
1177
|
# Get field data and make sure it's valid
|
1142
|
-
field_values = field_entry[field]
|
1178
|
+
field_values = field_entry[actual_field]
|
1143
1179
|
if not all(isinstance(item, dict) for item in field_values if item is not None):
|
1144
1180
|
warnings.warn(
|
1145
|
-
f"Field '{field}' contains non-dictionary values that cannot be flattened"
|
1181
|
+
f"Field '{actual_field}' contains non-dictionary values that cannot be flattened"
|
1146
1182
|
)
|
1147
1183
|
return self.copy()
|
1148
1184
|
|
@@ -1162,7 +1198,7 @@ class DataOperationsBase:
|
|
1162
1198
|
new_values.append(value)
|
1163
1199
|
|
1164
1200
|
# Add this as a new column
|
1165
|
-
flattened_data.append({f"{field}.{key}": new_values})
|
1201
|
+
flattened_data.append({f"{actual_field}.{key}": new_values})
|
1166
1202
|
|
1167
1203
|
# Return a new Dataset with the flattened data
|
1168
1204
|
return Dataset(flattened_data)
|
@@ -1244,37 +1280,6 @@ class DataOperationsBase:
|
|
1244
1280
|
|
1245
1281
|
return result
|
1246
1282
|
|
1247
|
-
def drop(self, field_name):
|
1248
|
-
"""
|
1249
|
-
Returns a new Dataset with the specified field removed.
|
1250
|
-
|
1251
|
-
Args:
|
1252
|
-
field_name (str): The name of the field to remove.
|
1253
|
-
|
1254
|
-
Returns:
|
1255
|
-
Dataset: A new Dataset instance without the specified field.
|
1256
|
-
|
1257
|
-
Raises:
|
1258
|
-
KeyError: If the field_name doesn't exist in the dataset.
|
1259
|
-
|
1260
|
-
Examples:
|
1261
|
-
>>> from .dataset import Dataset
|
1262
|
-
>>> d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
|
1263
|
-
>>> d.drop('a')
|
1264
|
-
Dataset([{'b': [4, 5, 6]}])
|
1265
|
-
|
1266
|
-
>>> # Testing drop with nonexistent field raises DatasetKeyError - tested in unit tests
|
1267
|
-
"""
|
1268
|
-
from .dataset import Dataset
|
1269
|
-
|
1270
|
-
# Check if field exists in the dataset
|
1271
|
-
if field_name not in self.relevant_columns():
|
1272
|
-
raise DatasetKeyError(f"Field '{field_name}' not found in dataset")
|
1273
|
-
|
1274
|
-
# Create a new dataset without the specified field
|
1275
|
-
new_data = [entry for entry in self.data if field_name not in entry]
|
1276
|
-
return Dataset(new_data)
|
1277
|
-
|
1278
1283
|
def remove_prefix(self):
|
1279
1284
|
"""Returns a new Dataset with the prefix removed from all column names.
|
1280
1285
|
|