edsl 0.1.58__py3-none-any.whl → 0.1.60__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (37)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/agent.py +23 -4
  3. edsl/agents/agent_list.py +36 -6
  4. edsl/base/data_transfer_models.py +5 -0
  5. edsl/base/enums.py +7 -2
  6. edsl/coop/coop.py +103 -1
  7. edsl/dataset/dataset.py +74 -0
  8. edsl/dataset/dataset_operations_mixin.py +69 -64
  9. edsl/inference_services/services/__init__.py +3 -1
  10. edsl/inference_services/services/open_ai_service_v2.py +243 -0
  11. edsl/inference_services/services/test_service.py +1 -1
  12. edsl/interviews/exception_tracking.py +66 -20
  13. edsl/invigilators/invigilators.py +5 -1
  14. edsl/invigilators/prompt_constructor.py +299 -136
  15. edsl/jobs/data_structures.py +3 -0
  16. edsl/jobs/html_table_job_logger.py +18 -1
  17. edsl/jobs/jobs_pricing_estimation.py +6 -2
  18. edsl/jobs/jobs_remote_inference_logger.py +2 -0
  19. edsl/jobs/remote_inference.py +34 -7
  20. edsl/key_management/key_lookup_builder.py +25 -3
  21. edsl/language_models/language_model.py +41 -3
  22. edsl/language_models/raw_response_handler.py +126 -7
  23. edsl/prompts/prompt.py +1 -0
  24. edsl/questions/question_list.py +76 -20
  25. edsl/results/result.py +37 -0
  26. edsl/results/results.py +9 -1
  27. edsl/scenarios/file_store.py +8 -12
  28. edsl/scenarios/scenario.py +50 -2
  29. edsl/scenarios/scenario_list.py +34 -12
  30. edsl/surveys/survey.py +4 -0
  31. edsl/tasks/task_history.py +180 -6
  32. edsl/utilities/wikipedia.py +194 -0
  33. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/METADATA +5 -4
  34. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/RECORD +37 -35
  35. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/LICENSE +0 -0
  36. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/WHEEL +0 -0
  37. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/entry_points.txt +0 -0
edsl/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.58"
+ __version__ = "0.1.60"
edsl/agents/agent.py CHANGED
@@ -426,6 +426,25 @@ class Agent(Base):
          self.traits_presentation_template = "Your traits: {{traits}}"
          self.set_traits_presentation_template = False
 
+     def drop(self, field_name: str) -> Agent:
+         """Drop a field from the agent.
+
+         Args:
+             field_name: The name of the field to drop.
+         """
+         d = self.to_dict()
+         if field_name in d['traits']:
+             d['traits'].pop(field_name)
+         elif field_name in d:
+             d.pop(field_name)
+         else:
+             raise AgentErrors((f"Field '{field_name}' not found in agent"
+                                f"Available fields: {d.keys()}"
+                                f"Available traits: {d['traits'].keys()}"
+                                ))
+         return Agent.from_dict(d)
+
      def duplicate(self) -> Agent:
          """Create a deep copy of this agent with all its traits and capabilities.
 
@@ -1213,7 +1232,7 @@ class Agent(Base):
          """
          return dict_hash(self.to_dict(add_edsl_version=False))
 
-     def to_dict(self, add_edsl_version=True) -> dict[str, Union[dict, bool]]:
+     def to_dict(self, add_edsl_version=True, full_dict=False) -> dict[str, Union[dict, bool]]:
          """Serialize to a dictionary with EDSL info.
 
          Example usage:
@@ -1230,11 +1249,11 @@ class Agent(Base):
          d["traits"] = copy.deepcopy(dict(self._traits))
          if self.name:
              d["name"] = self.name
-         if self.set_instructions:
+         if self.set_instructions or full_dict:
              d["instruction"] = self.instruction
-         if self.set_traits_presentation_template:
+         if self.set_traits_presentation_template or full_dict:
              d["traits_presentation_template"] = self.traits_presentation_template
-         if self.codebook:
+         if self.codebook or full_dict:
              d["codebook"] = self.codebook
          if add_edsl_version:
              from edsl import __version__
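Note: a minimal usage sketch of the new Agent.drop and the full_dict flag on to_dict, based only on the signatures in this diff (trait values are illustrative):

    from edsl import Agent

    a = Agent(traits={"age": 22, "hair": "brown"})

    # drop() round-trips through to_dict()/from_dict(), removing the named
    # trait (or top-level field) and raising AgentErrors if it is absent
    a2 = a.drop("age")  # Agent with traits {'hair': 'brown'}

    # full_dict=True forces instruction, traits_presentation_template, and
    # codebook into the serialized dict even when they were never set
    d = a.to_dict(full_dict=True)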
edsl/agents/agent_list.py CHANGED
@@ -47,13 +47,13 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
      with methods for filtering, transforming, and analyzing collections of agents.
 
 
-     >>> AgentList.example().to_scenario_list().drop('age')
-     ScenarioList([Scenario({'age': 22, 'hair': 'brown', 'height': 5.5}), Scenario({'age': 22, 'hair': 'brown', 'height': 5.5})])
-
+     >>> AgentList.example().to_scenario_list().drop('age')
+     ScenarioList([Scenario({'hair': 'brown', 'height': 5.5}), Scenario({'hair': 'brown', 'height': 5.5})])
+
      >>> AgentList.example().to_dataset()
      Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}])
 
-     >>> AgentList.example().to_pandas()
+     >>> AgentList.example().select('age', 'hair', 'height').to_pandas()
         age   hair  height
      0   22  brown     5.5
      1   22  brown     5.5
@@ -91,6 +91,28 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
          if codebook is not None:
              self.set_codebook(codebook)
 
+     def set_instruction(self, instruction: str) -> None:
+         """Set the instruction for all agents in the list.
+
+         Args:
+             instruction: The instruction to set.
+         """
+         for agent in self.data:
+             agent.instruction = instruction
+
+         return None
+
+     def set_traits_presentation_template(self, traits_presentation_template: str) -> None:
+         """Set the traits presentation template for all agents in the list.
+
+         Args:
+             traits_presentation_template: The traits presentation template to set.
+         """
+         for agent in self.data:
+             agent.traits_presentation_template = traits_presentation_template
+
+         return None
+
      def shuffle(self, seed: Optional[str] = None) -> AgentList:
          """Randomly shuffle the agents in place.
 
@@ -119,6 +141,14 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
          if seed:
              random.seed(seed)
          return AgentList(random.sample(self.data, n))
+
+     def drop(self, field_name: str) -> AgentList:
+         """Drop a field from the AgentList.
+
+         Args:
+             field_name: The name of the field to drop.
+         """
+         return AgentList([a.drop(field_name) for a in self.data])
 
      def duplicate(self) -> AgentList:
          """Create a deep copy of the AgentList.
@@ -478,7 +508,7 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
          >>> al.to_dataset()
          Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}])
          >>> al.to_dataset(traits_only=False)  # doctest: +NORMALIZE_WHITESPACE
-         Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}, {'agent_parameters': [{'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}, {'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}]}])
+         Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}, {'agent_parameters': [{'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None, 'traits_presentation_template': 'Your traits: {{traits}}'}, {'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None, 'traits_presentation_template': 'Your traits: {{traits}}'}]}])
          """
          from ..dataset import Dataset
 
@@ -495,7 +525,7 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
              data[trait_key].append(agent.traits.get(trait_key, None))
          if not traits_only:
              data["agent_parameters"].append(
-                 {"instruction": agent.instruction, "agent_name": agent.name}
+                 {"instruction": agent.instruction, "agent_name": agent.name, "traits_presentation_template": agent.traits_presentation_template}
              )
          return Dataset([{key: entry} for key, entry in data.items()])
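Note: a short sketch of the new AgentList helpers, using the example list from the doctests above:

    from edsl import AgentList

    al = AgentList.example()

    # Apply one instruction / presentation template to every agent in the list
    al.set_instruction("Answer as a survey respondent.")
    al.set_traits_presentation_template("You have these traits: {{traits}}")

    # drop() delegates to Agent.drop and returns a new AgentList
    al_without_age = al.drop("age")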
edsl/base/data_transfer_models.py CHANGED
@@ -17,6 +17,7 @@ class EDSLOutput(NamedTuple):
      answer: Any
      generated_tokens: str
      comment: Optional[str] = None
+     reasoning_summary: Optional[Any] = None
 
 
  class ModelResponse(NamedTuple):
@@ -49,6 +50,7 @@ class EDSLResultObjectInput(NamedTuple):
      cache_key: str
      answer: Any
      comment: str
+     reasoning_summary: Optional[Any] = None
      validated: bool = False
      exception_occurred: Exception = None
      input_tokens: Optional[int] = None
@@ -96,12 +98,15 @@ class Answers(UserDict):
          answer = response.answer
          comment = response.comment
          generated_tokens = response.generated_tokens
+         reasoning_summary = response.reasoning_summary
          # record the answer
          if generated_tokens:
              self[question.question_name + "_generated_tokens"] = generated_tokens
          self[question.question_name] = answer
          if comment:
              self[question.question_name + "_comment"] = comment
+         if reasoning_summary:
+             self[question.question_name + "_reasoning_summary"] = reasoning_summary
 
      def replace_missing_answers_with_none(self, survey: "Survey") -> None:
          """Replace missing answers with None. Answers can be missing if the agent skips a question."""
edsl/base/enums.py CHANGED
@@ -57,6 +57,7 @@ class InferenceServiceType(EnumWithChecks):
      DEEP_INFRA = "deep_infra"
      REPLICATE = "replicate"
      OPENAI = "openai"
+     OPENAI_V2 = "openai_v2"
      GOOGLE = "google"
      TEST = "test"
      ANTHROPIC = "anthropic"
@@ -77,6 +78,7 @@ InferenceServiceLiteral = Literal[
      "deep_infra",
      "replicate",
      "openai",
+     "openai_v2",
      "google",
      "test",
      "anthropic",
@@ -93,6 +95,7 @@ InferenceServiceLiteral = Literal[
  available_models_urls = {
      "anthropic": "https://docs.anthropic.com/en/docs/about-claude/models",
      "openai": "https://platform.openai.com/docs/models/gp",
+     "openai_v2": "https://platform.openai.com/docs/models/gp",
      "groq": "https://console.groq.com/docs/models",
      "google": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models",
  }
@@ -102,6 +105,7 @@ service_to_api_keyname = {
      InferenceServiceType.DEEP_INFRA.value: "DEEP_INFRA_API_KEY",
      InferenceServiceType.REPLICATE.value: "TBD",
      InferenceServiceType.OPENAI.value: "OPENAI_API_KEY",
+     InferenceServiceType.OPENAI_V2.value: "OPENAI_API_KEY",
      InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
      InferenceServiceType.TEST.value: "TBD",
      InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
@@ -135,7 +139,7 @@ class TokenPricing:
          and self.prompt_token_price == other.prompt_token_price
          and self.completion_token_price == other.completion_token_price
      )
-
+
      @classmethod
      def example(cls) -> "TokenPricing":
          """Return an example TokenPricing object."""
@@ -145,6 +149,7 @@ class TokenPricing:
          completion_token_price_per_k=0.03,
      )
 
+
  pricing = {
      "dbrx-instruct": TokenPricing(
          model_name="dbrx-instruct",
@@ -212,4 +217,4 @@ def get_token_pricing(model_name):
          model_name=model_name,
          prompt_token_price_per_k=0.0,
          completion_token_price_per_k=0.0,
-     )
+     )
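Note: the new service type plugs into the existing lookup tables and reuses the OpenAI key name. A quick check (assuming the module path edsl.base.enums mirrors the file path above):

    from edsl.base.enums import InferenceServiceType, service_to_api_keyname

    assert service_to_api_keyname[InferenceServiceType.OPENAI_V2.value] == "OPENAI_API_KEY"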
edsl/coop/coop.py CHANGED
@@ -2,6 +2,7 @@ import aiohttp
  import base64
  import json
  import requests
+ import time
 
  from typing import Any, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
  from uuid import UUID
@@ -13,7 +14,9 @@ from ..caching import CacheEntry
 
  if TYPE_CHECKING:
      from ..jobs import Jobs
+     from ..scenarios import ScenarioList
      from ..surveys import Survey
+     from ..results import Results
 
  from .exceptions import (
      CoopInvalidURLError,
@@ -567,6 +570,7 @@ class Coop(CoopFunctionsMixin):
              json.dumps(
                  object_dict,
                  default=self._json_handle_none,
+                 allow_nan=False,
              )
              if object_type != "scenario"
              else ""
@@ -585,6 +589,7 @@ class Coop(CoopFunctionsMixin):
          json_data = json.dumps(
              object_dict,
              default=self._json_handle_none,
+             allow_nan=False,
          )
          headers = {"Content-Type": "application/json"}
          if response_json.get("upload_signed_url"):
@@ -928,6 +933,7 @@ class Coop(CoopFunctionsMixin):
              json.dumps(
                  value.to_dict(),
                  default=self._json_handle_none,
+                 allow_nan=False,
              )
              if value
              else None
@@ -1385,12 +1391,108 @@ class Coop(CoopFunctionsMixin):
          self._resolve_server_response(response)
          response_json = response.json()
          return {
-             "name": response_json.get("project_name"),
+             "project_name": response_json.get("project_name"),
              "uuid": response_json.get("uuid"),
              "admin_url": f"{self.url}/home/projects/{response_json.get('uuid')}",
              "respondent_url": f"{self.url}/respond/{response_json.get('uuid')}",
          }
 
+     def get_project(
+         self,
+         project_uuid: str,
+     ) -> dict:
+         """
+         Get a project from Coop.
+         """
+         response = self._send_server_request(
+             uri=f"api/v0/projects/{project_uuid}",
+             method="GET",
+         )
+         self._resolve_server_response(response)
+         response_json = response.json()
+         return {
+             "project_name": response_json.get("project_name"),
+             "project_job_uuids": response_json.get("job_uuids"),
+         }
+
+     def get_project_human_responses(
+         self,
+         project_uuid: str,
+     ) -> Union["Results", "ScenarioList"]:
+         """
+         Return a Results object with the human responses for a project.
+
+         If generating the Results object fails, a ScenarioList will be returned instead.
+         """
+         from ..agents import Agent, AgentList
+         from ..caching import Cache
+         from ..language_models import Model
+         from ..scenarios import Scenario, ScenarioList
+         from ..surveys import Survey
+
+         response = self._send_server_request(
+             uri=f"api/v0/projects/{project_uuid}/human-responses",
+             method="GET",
+         )
+         self._resolve_server_response(response)
+         response_json = response.json()
+         human_responses = response_json.get("human_responses", [])
+
+         try:
+             agent_list = AgentList()
+
+             for response in human_responses:
+                 response_uuid = response.get("response_uuid")
+                 if response_uuid is None:
+                     raise RuntimeError(
+                         "One of your responses is missing a unique identifier."
+                     )
+
+                 response_dict = json.loads(response.get("response_json_string"))
+
+                 a = Agent(name=response_uuid, instruction="")
+
+                 def create_answer_function(response_data):
+                     def f(self, question, scenario):
+                         return response_data.get(question.question_name, None)
+
+                     return f
+
+                 a.add_direct_question_answering_method(
+                     create_answer_function(response_dict)
+                 )
+                 agent_list.append(a)
+
+             survey_json_string = response_json.get("survey_json_string")
+             survey = Survey.from_dict(json.loads(survey_json_string))
+
+             model = Model("test")
+             results = (
+                 survey.by(agent_list)
+                 .by(model)
+                 .run(
+                     cache=Cache(),
+                     disable_remote_cache=True,
+                     disable_remote_inference=True,
+                     print_exceptions=False,
+                 )
+             )
+             return results
+         except Exception:
+             human_response_scenarios = []
+             for response in human_responses:
+                 response_uuid = response.get("response_uuid")
+                 if response_uuid is None:
+                     raise RuntimeError(
+                         "One of your responses is missing a unique identifier."
+                     )
+
+                 response_dict = json.loads(response.get("response_json_string"))
+                 response_dict["agent_name"] = response_uuid
+                 scenario = Scenario(response_dict)
+                 human_response_scenarios.append(scenario)
+             return ScenarioList(human_response_scenarios)
+
      def __repr__(self):
          """Return a string representation of the client."""
          return f"Client(api_key='{self.api_key}', url='{self.url}')"
edsl/dataset/dataset.py CHANGED
@@ -93,6 +93,38 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
          """
          _, values = list(self.data[0].items())[0]
          return len(values)
+
+     def drop(self, field_name):
+         """
+         Returns a new Dataset with the specified field removed.
+
+         Args:
+             field_name (str): The name of the field to remove.
+
+         Returns:
+             Dataset: A new Dataset instance without the specified field.
+
+         Raises:
+             KeyError: If the field_name doesn't exist in the dataset.
+
+         Examples:
+             >>> from .dataset import Dataset
+             >>> d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
+             >>> d.drop('a')
+             Dataset([{'b': [4, 5, 6]}])
+
+             >>> # Testing drop with nonexistent field raises DatasetKeyError - tested in unit tests
+         """
+         from .dataset import Dataset
+
+         # Check if field exists in the dataset
+         if field_name not in self.relevant_columns():
+             raise DatasetKeyError(f"Field '{field_name}' not found in dataset")
+
+         # Create a new dataset without the specified field
+         new_data = [entry for entry in self.data if field_name not in entry]
+         return Dataset(new_data)
+
 
      def tail(self, n: int = 5) -> Dataset:
          """Return the last n observations in the dataset.
@@ -1054,6 +1086,48 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
 
          return Dataset(new_data)
 
+     def unique(self) -> "Dataset":
+         """Return a new dataset with only unique observations.
+
+         Examples:
+             >>> d = Dataset([{'a': [1, 2, 2, 3]}, {'b': [4, 5, 5, 6]}])
+             >>> d.unique().data
+             [{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]
+
+             >>> d = Dataset([{'x': ['a', 'a', 'b']}, {'y': [1, 1, 2]}])
+             >>> d.unique().data
+             [{'x': ['a', 'b']}, {'y': [1, 2]}]
+         """
+         # Get all column names and values
+         headers, data = self._tabular()
+
+         # Create a list of unique rows
+         unique_rows = []
+         seen = set()
+
+         for row in data:
+             # Convert the row to a hashable representation for comparison
+             # We need to handle potential unhashable types
+             try:
+                 row_key = tuple(map(lambda x: str(x) if isinstance(x, (list, dict)) else x, row))
+                 if row_key not in seen:
+                     seen.add(row_key)
+                     unique_rows.append(row)
+             except:
+                 # Fallback for complex objects: compare based on string representation
+                 row_str = str(row)
+                 if row_str not in seen:
+                     seen.add(row_str)
+                     unique_rows.append(row)
+
+         # Create a new dataset with unique combinations
+         new_data = []
+         for i, header in enumerate(headers):
+             values = [row[i] for row in unique_rows]
+             new_data.append({header: values})
+
+         return Dataset(new_data)
+
 
  if __name__ == "__main__":
      import doctest
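Note: both new Dataset methods are exercised by their doctests; combined:

    from edsl.dataset import Dataset

    d = Dataset([{'a': [1, 2, 2, 3]}, {'b': [4, 5, 5, 6]}])
    d.unique()   # Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
    d.drop('a')  # Dataset([{'b': [4, 5, 5, 6]}])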
edsl/dataset/dataset_operations_mixin.py CHANGED
@@ -357,7 +357,7 @@ class DataOperationsBase:
          4
          >>> engine = Results.example()._db(shape = "long")
          >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
-         204
+         212
          """
          # Import needed for database connection
          from sqlalchemy import create_engine
@@ -442,7 +442,7 @@ class DataOperationsBase:
 
          # Using long format
          >>> len(r.sql("SELECT * FROM self", shape="long"))
-         204
+         212
          """
          import pandas as pd
 
@@ -1070,7 +1070,6 @@ class DataOperationsBase:
          - All dictionaries in the field must have compatible structures
          - If a dictionary is missing a key, the corresponding value will be None
          - Non-dictionary values in the field will cause a warning
-
          Examples:
              >>> from edsl.dataset import Dataset
 
@@ -1086,48 +1085,85 @@ class DataOperationsBase:
          >>> d = Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5]}])
          >>> d.flatten('a', keep_original=True)
          Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5]}, {'a.a': [1]}, {'a.b': [2]}])
+
+         # Can also use unambiguous unprefixed field name
+         >>> result = Dataset([{'answer.pros_cons': [{'pros': ['Safety'], 'cons': ['Cost']}]}]).flatten('pros_cons')
+         >>> sorted(result.keys()) == ['answer.pros_cons.cons', 'answer.pros_cons.pros']
+         True
+         >>> sorted(result.to_dicts()[0].items()) == sorted({'cons': ['Cost'], 'pros': ['Safety']}.items())
+         True
          """
          from ..dataset import Dataset
 
          # Ensure the dataset isn't empty
          if not self.data:
              return self.copy()
-
-         # Find all columns that contain the field
-         matching_entries = []
-         for entry in self.data:
-             col_name = next(iter(entry.keys()))
-             if field == col_name or (
-                 "." in col_name
-                 and (col_name.endswith("." + field) or col_name.startswith(field + "."))
-             ):
-                 matching_entries.append(entry)
-
-         # Check if the field is ambiguous
-         if len(matching_entries) > 1:
-             matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
-             from .exceptions import DatasetValueError
-
-             raise DatasetValueError(
-                 f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
-                 f"Please specify the full column name to flatten."
-             )
-
-         # Get the number of observations
-         num_observations = self.num_observations()
-
-         # Find the column to flatten
+
+         # First try direct match with the exact field name
          field_entry = None
          for entry in self.data:
-             if field in entry:
+             col_name = next(iter(entry.keys()))
+             if field == col_name:
                  field_entry = entry
                  break
+
+         # If not found, try to match by unprefixed name
+         if field_entry is None:
+             # Find any columns that have field as their unprefixed name
+             candidates = []
+             for entry in self.data:
+                 col_name = next(iter(entry.keys()))
+                 if '.' in col_name:
+                     prefix, col_field = col_name.split('.', 1)
+                     if col_field == field:
+                         candidates.append(entry)
+
+             # If we found exactly one match by unprefixed name, use it
+             if len(candidates) == 1:
+                 field_entry = candidates[0]
+             # If we found multiple matches, it's ambiguous
+             elif len(candidates) > 1:
+                 matching_cols = [next(iter(entry.keys())) for entry in candidates]
+                 from .exceptions import DatasetValueError
+                 raise DatasetValueError(
+                     f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
+                     f"Please specify the full column name to flatten."
+                 )
+             # If no candidates by unprefixed name, check partial matches
+             else:
+                 partial_matches = []
+                 for entry in self.data:
+                     col_name = next(iter(entry.keys()))
+                     if '.' in col_name and (
+                         col_name.endswith('.' + field) or
+                         col_name.startswith(field + '.')
+                     ):
+                         partial_matches.append(entry)
+
+                 # If we found exactly one partial match, use it
+                 if len(partial_matches) == 1:
+                     field_entry = partial_matches[0]
+                 # If we found multiple partial matches, it's ambiguous
+                 elif len(partial_matches) > 1:
+                     matching_cols = [next(iter(entry.keys())) for entry in partial_matches]
+                     from .exceptions import DatasetValueError
+                     raise DatasetValueError(
+                         f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
+                         f"Please specify the full column name to flatten."
+                     )
+
+         # Get the number of observations
+         num_observations = self.num_observations()
 
+         # If we still haven't found the field, it's not in the dataset
          if field_entry is None:
              warnings.warn(
                  f"Field '{field}' not found in dataset, returning original dataset"
              )
              return self.copy()
+
+         # Get the actual field name as it appears in the data
+         actual_field = next(iter(field_entry.keys()))
 
          # Create new dictionary for flattened data
          flattened_data = []
@@ -1135,14 +1171,14 @@ class DataOperationsBase:
          # Copy all existing columns except the one we're flattening (if keep_original is False)
          for entry in self.data:
              col_name = next(iter(entry.keys()))
-             if col_name != field or keep_original:
+             if col_name != actual_field or keep_original:
                  flattened_data.append(entry.copy())
 
          # Get field data and make sure it's valid
-         field_values = field_entry[field]
+         field_values = field_entry[actual_field]
          if not all(isinstance(item, dict) for item in field_values if item is not None):
              warnings.warn(
-                 f"Field '{field}' contains non-dictionary values that cannot be flattened"
+                 f"Field '{actual_field}' contains non-dictionary values that cannot be flattened"
              )
              return self.copy()
 
@@ -1162,7 +1198,7 @@ class DataOperationsBase:
              new_values.append(value)
 
          # Add this as a new column
-         flattened_data.append({f"{field}.{key}": new_values})
+         flattened_data.append({f"{actual_field}.{key}": new_values})
 
          # Return a new Dataset with the flattened data
          return Dataset(flattened_data)
@@ -1244,37 +1280,6 @@ class DataOperationsBase:
 
          return result
 
-     def drop(self, field_name):
-         """
-         Returns a new Dataset with the specified field removed.
-
-         Args:
-             field_name (str): The name of the field to remove.
-
-         Returns:
-             Dataset: A new Dataset instance without the specified field.
-
-         Raises:
-             KeyError: If the field_name doesn't exist in the dataset.
-
-         Examples:
-             >>> from .dataset import Dataset
-             >>> d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
-             >>> d.drop('a')
-             Dataset([{'b': [4, 5, 6]}])
-
-             >>> # Testing drop with nonexistent field raises DatasetKeyError - tested in unit tests
-         """
-         from .dataset import Dataset
-
-         # Check if field exists in the dataset
-         if field_name not in self.relevant_columns():
-             raise DatasetKeyError(f"Field '{field_name}' not found in dataset")
-
-         # Create a new dataset without the specified field
-         new_data = [entry for entry in self.data if field_name not in entry]
-         return Dataset(new_data)
-
      def remove_prefix(self):
          """Returns a new Dataset with the prefix removed from all column names.
 
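Note: the reworked flatten() resolves a field by exact column name first, then by a unique unprefixed name, then by a unique partial match, so the shorthand from the new doctest works:

    from edsl.dataset import Dataset

    d = Dataset([{'answer.pros_cons': [{'pros': ['Safety'], 'cons': ['Cost']}]}])
    flat = d.flatten('pros_cons')  # uniquely resolves to 'answer.pros_cons'
    sorted(flat.keys())  # ['answer.pros_cons.cons', 'answer.pros_cons.pros']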
edsl/inference_services/services/__init__.py CHANGED
@@ -8,6 +8,7 @@ from .groq_service import GroqService
  from .mistral_ai_service import MistralAIService
  from .ollama_service import OllamaService
  from .open_ai_service import OpenAIService
+ from .open_ai_service_v2 import OpenAIServiceV2
  from .perplexity_service import PerplexityService
  from .test_service import TestService
  from .together_ai_service import TogetherAIService
@@ -24,8 +25,9 @@ __all__ = [
      "MistralAIService",
      "OllamaService",
      "OpenAIService",
+     "OpenAIServiceV2",
      "PerplexityService",
      "TestService",
      "TogetherAIService",
      "XAIService",
-  ]
+  ]
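Note: with this re-export, the new service (defined in the added open_ai_service_v2.py and registered as "openai_v2", using OPENAI_API_KEY per enums.py) is importable from the package namespace:

    from edsl.inference_services.services import OpenAIServiceV2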