lionagi 0.15.4__py3-none-any.whl → 0.15.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,17 +16,23 @@ from collections.abc import (
16
16
  )
17
17
  from functools import wraps
18
18
  from pathlib import Path
19
- from typing import Any, Generic, TypeVar
19
+ from typing import Any, ClassVar, Generic, Literal, TypeVar
20
20
 
21
21
  import pandas as pd
22
- from pydantic import Field
22
+ from pydantic import Field, field_serializer, field_validator, model_validator
23
23
  from pydantic.fields import FieldInfo
24
24
  from pydapter import Adaptable, AsyncAdaptable
25
- from typing_extensions import Self, override
25
+ from typing_extensions import Self, deprecated, override
26
26
 
27
- from lionagi._errors import ItemExistsError, ItemNotFoundError
27
+ from lionagi._errors import ItemExistsError, ItemNotFoundError, ValidationError
28
28
  from lionagi.libs.concurrency import Lock as ConcurrencyLock
29
- from lionagi.utils import UNDEFINED, is_same_dtype, to_list
29
+ from lionagi.utils import (
30
+ UNDEFINED,
31
+ is_same_dtype,
32
+ is_union_type,
33
+ to_list,
34
+ union_members,
35
+ )
30
36
 
31
37
  from .._concepts import Observable
32
38
  from .element import ID, Collective, E, Element, IDType, validate_order
@@ -35,6 +41,8 @@ from .progression import Progression
35
41
  D = TypeVar("D")
36
42
  T = TypeVar("T", bound=E)
37
43
 
44
+ _ADAPATER_REGISTERED = False
45
+
38
46
 
39
47
  def synchronized(func: Callable):
40
48
  @wraps(func)
@@ -54,6 +62,125 @@ def async_synchronized(func: Callable):
54
62
  return wrapper
55
63
 
56
64
 
65
def _validate_item_type(value, /) -> set[type[T]] | None:
    """Normalize an ``item_type`` specification into a set of classes.

    Args:
        value: ``None``, or one or more entries (normalized with
            ``to_list_type``). Each entry may be a class, a union of
            classes, or a dotted import path string naming a class.

    Returns:
        The set of resolved classes, or ``None`` when ``value`` is ``None``
        or normalizes to an empty list.

    Raises:
        ValidationError: If a string entry cannot be imported, if an entry
            (or a union member) is not a subclass of ``Observable``, or if
            the specification contains duplicate entries.
    """
    if value is None:
        return None

    from lionagi.utils import import_module

    entries = to_list_type(value)
    resolved = set()

    for entry in entries:
        candidate = entry
        if isinstance(entry, str):
            try:
                module_path, attr = entry.rsplit(".", 1)
                candidate = import_module(module_path, import_name=attr)
            except Exception as e:
                raise ValidationError.from_value(
                    entry,
                    expected="A subclass of Observable.",
                    cause=e,
                ) from e

        # Union objects (``A | B`` / ``Union[A, B]``) are not instances of
        # ``type``, so they must be recognised before the plain-class check;
        # otherwise the union branch can never be reached.
        if is_union_type(candidate):
            members = union_members(candidate)
        elif isinstance(candidate, type):
            members = (candidate,)
        else:
            raise ValidationError.from_value(
                entry, expected="A subclass of Observable."
            )

        for member in members:
            # ``issubclass`` raises TypeError on non-classes, so guard first
            # and report every bad member through ValidationError.
            if not (
                isinstance(member, type) and issubclass(member, Observable)
            ):
                raise ValidationError.from_value(
                    member, expected="A subclass of Observable."
                )
            resolved.add(member)

    if len(entries) != len(set(entries)):
        raise ValidationError("Detected duplicated item types in item_type.")

    return resolved if entries else None
111
+
112
+
113
def _validate_progression(
    value: Any, collections: dict[IDType, T], /
) -> Progression:
    """Validate an ordering against the items of a pile.

    Args:
        value: A ``Progression``, a dict accepted by
            ``Progression.from_dict`` (or carrying an ``"order"`` key), or
            anything ``to_list_type`` can turn into a list of references.
            A falsy value means "use the natural order of ``collections``".
        collections: Mapping of item id to item that the order must cover.

    Returns:
        The given/parsed ``Progression`` when one was supplied, otherwise a
        new ``Progression`` built from the validated order.

    Raises:
        ValueError: If the order contains duplicates, its length differs
            from the number of items, or it references an id that is not in
            ``collections``.
    """
    if not value:
        return Progression(order=list(collections.keys()))

    prog = None
    if isinstance(value, dict):
        try:
            parsed = Progression.from_dict(value)
            order = list(parsed)
        except Exception:
            # Best-effort fallback: the dict is not a full Progression
            # payload, so only its "order" field (if any) is used.
            order = to_list_type(value.get("order", []))
        else:
            # Only keep the parsed object when it was fully usable, so the
            # returned progression always matches the validated order.
            prog = parsed
    elif isinstance(value, Progression):
        prog = value
        order = list(prog)
    else:
        order = to_list_type(value)

    unique = set(order)
    if len(unique) != len(order):
        raise ValueError("There are duplicate elements in the order")
    if len(unique) != len(collections):
        raise ValueError(
            "The length of the order does not match the length of the pile"
        )

    for ref in unique:
        if ID.get_id(ref) not in collections:
            raise ValueError(
                f"The order does not match the pile. {ref} not found"
            )

    # ``prog or ...`` would throw away a valid but empty progression because
    # an empty one is falsy; test for None explicitly instead.
    return prog if prog is not None else Progression(order=order)
147
+
148
+
149
def _validate_collections(
    value: Any, item_type: set | None, strict_type: bool, /
) -> dict[str, T]:
    """Turn raw pile input into an ``{item.id: item}`` mapping.

    Args:
        value: Item(s) to validate; normalized with ``to_list_type``. Dict
            entries are rebuilt through ``Element.from_dict``.
        item_type: Allowed classes, or ``None``/empty for "any Observable".
        strict_type: With ``item_type`` set, require the exact class to be
            listed instead of accepting subclasses.

    Returns:
        Mapping of item id to item, in input order. Empty for falsy input.

    Raises:
        ValidationError: If an item violates the ``item_type`` constraint.
        ValueError: If no ``item_type`` is set and an item is not an
            ``Observable``.
    """
    if not value:
        return {}

    validated = {}
    for element in to_list_type(value):
        if isinstance(element, dict):
            element = Element.from_dict(element)

        if not item_type:
            # No explicit constraint: anything Observable is accepted.
            if not isinstance(element, Observable):
                raise ValueError(f"Invalid pile item {element}")
        elif strict_type:
            if type(element) not in item_type:
                raise ValidationError.from_value(
                    element,
                    expected=f"One of {item_type}, no subclasses allowed.",
                )
        else:
            element_cls = type(element)
            if not any(issubclass(element_cls, t) for t in item_type):
                raise ValidationError.from_value(
                    element,
                    expected=f"One of {item_type} or the subclasses",
                )

        validated[element.id] = element

    return validated
182
+
183
+
57
184
  class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
58
185
  """Thread-safe async-compatible, ordered collection of elements.
59
186
 
@@ -79,7 +206,6 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
79
206
  progression: Progression = Field(
80
207
  default_factory=Progression,
81
208
  description="Progression specifying the order of items in the pile.",
82
- exclude=True,
83
209
  )
84
210
  strict_type: bool = Field(
85
211
  default=False,
@@ -87,6 +213,13 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
87
213
  frozen=True,
88
214
  )
89
215
 
216
+ _EXTRA_FIELDS: ClassVar[set[str]] = {
217
+ "collections",
218
+ "item_type",
219
+ "progression",
220
+ "strict_type",
221
+ }
222
+
90
223
  def __pydantic_extra__(self) -> dict[str, FieldInfo]:
91
224
  return {
92
225
  "_lock": Field(default_factory=threading.Lock),
@@ -96,6 +229,29 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
96
229
  def __pydantic_private__(self) -> dict[str, FieldInfo]:
97
230
  return self.__pydantic_extra__()
98
231
 
232
@classmethod
def _validate_before(cls, data: dict[str, Any]) -> dict[str, Any]:
    """Validate raw constructor data and return normalized field values.

    The item type is validated first, then the items against it, and
    finally the order against the items. An ``"order"`` key, when present,
    takes precedence over ``"progression"``.

    Args:
        data: Raw keyword data for the pile.

    Returns:
        A dict with validated ``collections``, ``item_type``,
        ``progression`` and ``strict_type``, followed by every other key of
        ``data`` that is not listed in ``_EXTRA_FIELDS``.
    """
    item_type = _validate_item_type(data.get("item_type"))
    strict_type = data.get("strict_type", False)
    collections = _validate_collections(
        data.get("collections"), item_type, strict_type
    )

    raw_order = data["order"] if "order" in data else data.get("progression")
    progression = _validate_progression(raw_order, collections)

    validated = {
        "collections": collections,
        "item_type": item_type,
        "progression": progression,
        "strict_type": strict_type,
    }
    # Pass through everything that is not one of the validated fields.
    validated.update(
        (key, val)
        for key, val in data.items()
        if key not in cls._EXTRA_FIELDS
    )
    return validated
254
+
99
255
  @override
100
256
  def __init__(
101
257
  self,
@@ -113,22 +269,33 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
113
269
  order: Initial order of items (as Progression).
114
270
  strict_type: If True, enforce strict type checking.
115
271
  """
116
- _config = {}
117
- if "id" in kwargs:
118
- _config["id"] = kwargs["id"]
119
- if "created_at" in kwargs:
120
- _config["created_at"] = kwargs["created_at"]
121
-
122
- super().__init__(strict_type=strict_type, **_config)
123
- self.item_type = self._validate_item_type(item_type)
124
-
125
- if isinstance(collections, list) and is_same_dtype(collections, dict):
126
- collections = [Element.from_dict(i) for i in collections]
127
-
128
- self.collections = self._validate_pile(
129
- collections or kwargs.get("collections", {})
272
+ data = Pile._validate_before(
273
+ {
274
+ "collections": collections,
275
+ "item_type": item_type,
276
+ "progression": order,
277
+ "strict_type": strict_type,
278
+ **kwargs,
279
+ }
130
280
  )
131
- self.progression = self._validate_order(order)
281
+ super().__init__(**data)
282
+
283
@field_serializer("collections")
def _serialize_collections(
    self, v: dict[IDType, T]
) -> list[dict[str, Any]]:
    """Serialize the stored items as a list of their ``to_dict()`` output."""
    return [element.to_dict() for element in v.values()]
288
+
289
@field_serializer("progression")
def _serialize_progression(self, v: Progression) -> dict[str, Any]:
    """Serialize the progression through its own ``to_dict()``."""
    payload = v.to_dict()
    return payload
292
+
293
@field_serializer("item_type")
def _serialize_item_type(self, v: set[type[T]] | None) -> list[str] | None:
    """Serialize ``item_type`` to a sorted list of class names.

    Args:
        v: The set of allowed classes, or ``None``.

    Returns:
        ``None`` when no item type is set, otherwise the names produced by
        each class's ``class_name(full=True)``, sorted so that the output
        does not depend on set iteration order.
    """
    if v is None:
        return None
    return sorted(c.class_name(full=True) for c in v)
132
299
 
133
300
  # Sync Interface methods
134
301
  @override
@@ -147,11 +314,9 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
147
314
  A new Pile instance created from the provided data.
148
315
 
149
316
  Raises:
150
- ValueError: If the dictionary format is invalid.
317
+ ValidationError: If the dictionary format is invalid.
151
318
  """
152
- items = data.pop("collections", [])
153
- items = [Element.from_dict(i) for i in items]
154
- return cls(collections=items, **data)
319
+ return cls(**data)
155
320
 
156
321
  def __setitem__(
157
322
  self,
@@ -204,22 +369,32 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
204
369
  Raises:
205
370
  ValueError: If item not found.
206
371
  """
207
- self._remove(item)
372
+ if isinstance(item, int | slice):
373
+ raise TypeError(
374
+ "Invalid item type for remove, should be ID or Item(s)"
375
+ )
376
+ if item in self:
377
+ self.pop(item)
378
+ return
379
+ raise ItemNotFoundError(f"{item}")
208
380
 
209
def include(self, item: ID.ItemSeq | ID.Item, /) -> None:
    """Include item(s) if not present.

    Items are validated against the pile's ``item_type`` / ``strict_type``.
    Ids not yet in the progression are appended to it; the validated items
    are then merged into ``collections``.

    Args:
        item: Item(s) to include.
    """
    validated = _validate_collections(item, self.item_type, self.strict_type)

    # Only ids that are not ordered yet get a new slot in the progression.
    new_ids = [uid for uid in validated if uid not in self.progression]

    self.progression.append(new_ids)
    self.collections.update(validated)
223
398
 
224
399
  def exclude(
225
400
  self,
@@ -231,7 +406,13 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
231
406
  Args:
232
407
  item: Item(s) to exclude.
233
408
  """
234
- self._exclude(item)
409
+ item = to_list_type(item)
410
+ exclude_list = []
411
+ for i in item:
412
+ if i in self:
413
+ exclude_list.append(i)
414
+ if exclude_list:
415
+ self.pop(exclude_list)
235
416
 
236
417
  @synchronized
237
418
  def clear(self) -> None:
@@ -251,7 +432,12 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
251
432
  Raises:
252
433
  TypeError: If item types not allowed.
253
434
  """
254
- self._update(other)
435
+ others = _validate_collections(other, self.item_type, self.strict_type)
436
+ for i in others.keys():
437
+ if i in self.collections:
438
+ self.collections[i] = others[i]
439
+ else:
440
+ self.include(others[i])
255
441
 
256
442
  @synchronized
257
443
  def insert(self, index: int, item: T, /) -> None:
@@ -370,7 +556,9 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
370
556
  raise TypeError(
371
557
  f"Invalid type for Pile operation. expected <Pile>, got {type(other)}"
372
558
  )
373
- other = self._validate_pile(list(other))
559
+ other = _validate_collections(
560
+ list(other), self.item_type, self.strict_type
561
+ )
374
562
  self.include(other)
375
563
  return self
376
564
 
@@ -526,7 +714,7 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
526
714
  /,
527
715
  ) -> None:
528
716
  """Async remove item."""
529
- self._remove(item)
717
+ self.remove(item)
530
718
 
531
719
  @async_synchronized
532
720
  async def ainclude(
@@ -535,7 +723,7 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
535
723
  /,
536
724
  ) -> None:
537
725
  """Async include item(s)."""
538
- self._include(item)
726
+ self.include(item)
539
727
  if item not in self:
540
728
  raise TypeError(f"Item {item} is not of allowed types")
541
729
 
@@ -546,7 +734,7 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
546
734
  /,
547
735
  ) -> None:
548
736
  """Async exclude item(s)."""
549
- self._exclude(item)
737
+ self.exclude(item)
550
738
 
551
739
  @async_synchronized
552
740
  async def aclear(self) -> None:
@@ -560,7 +748,7 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
560
748
  /,
561
749
  ) -> None:
562
750
  """Async update with items."""
563
- self._update(other)
751
+ self.update(other)
564
752
 
565
753
  @async_synchronized
566
754
  async def aget(
@@ -635,7 +823,9 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
635
823
  key: ID.Ref | ID.RefSeq | int | slice,
636
824
  item: ID.Item | ID.ItemSeq,
637
825
  ) -> None:
638
- item_dict = self._validate_pile(item)
826
+ item_dict = _validate_collections(
827
+ item, self.item_type, self.strict_type
828
+ )
639
829
 
640
830
  item_order = []
641
831
  for i in item_dict.keys():
@@ -745,128 +935,14 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
745
935
  raise ItemNotFoundError(f"Item not found. Error: {e}")
746
936
  return default
747
937
 
748
- def _remove(self, item: ID.Ref | ID.RefSeq):
749
- if isinstance(item, int | slice):
750
- raise TypeError(
751
- "Invalid item type for remove, should be ID or Item(s)"
752
- )
753
- if item in self:
754
- self.pop(item)
755
- return
756
- raise ItemNotFoundError(f"{item}")
757
-
758
- def _include(self, item: ID.ItemSeq | ID.Item):
759
- item_dict = self._validate_pile(item)
760
-
761
- item_order = []
762
- for i in item_dict.keys():
763
- if i not in self.progression:
764
- item_order.append(i)
765
-
766
- self.progression.append(item_order)
767
- self.collections.update(item_dict)
768
-
769
- def _exclude(self, item: ID.Ref | ID.RefSeq):
770
- item = to_list_type(item)
771
- exclude_list = []
772
- for i in item:
773
- if i in self:
774
- exclude_list.append(i)
775
- if exclude_list:
776
- self.pop(exclude_list)
777
-
778
938
  def _clear(self) -> None:
779
939
  self.collections.clear()
780
940
  self.progression.clear()
781
941
 
782
- def _update(self, other: ID.ItemSeq | ID.Item):
783
- others = self._validate_pile(other)
784
- for i in others.keys():
785
- if i in self.collections:
786
- self.collections[i] = others[i]
787
- else:
788
- self.include(others[i])
789
-
790
- def _validate_item_type(self, value) -> set[type[T]] | None:
791
- if value is None:
792
- return None
793
-
794
- value = to_list_type(value)
795
-
796
- for i in value:
797
- if not issubclass(i, Observable):
798
- raise TypeError(
799
- f"Item type must be a subclass of Observable. Got {i}"
800
- )
801
-
802
- if len(value) != len(set(value)):
803
- raise ValueError(
804
- "Detected duplicated item types in item_type.",
805
- )
806
-
807
- if len(value) > 0:
808
- return set(value)
809
-
810
- def _validate_pile(self, value: Any) -> dict[str, T]:
811
- if not value:
812
- return {}
813
-
814
- value = to_list_type(value)
815
-
816
- result = {}
817
- for i in value:
818
- if isinstance(i, dict):
819
- i = Element.from_dict(i)
820
-
821
- if self.item_type:
822
- if self.strict_type:
823
- if type(i) not in self.item_type:
824
- raise TypeError(
825
- f"Invalid item type in pile. Expected {self.item_type}",
826
- )
827
- else:
828
- if not any(issubclass(type(i), t) for t in self.item_type):
829
- raise TypeError(
830
- "Invalid item type in pile. Expected "
831
- f"{self.item_type} or the subclasses",
832
- )
833
- else:
834
- if not isinstance(i, Observable):
835
- raise ValueError(f"Invalid pile item {i}")
836
-
837
- result[i.id] = i
838
-
839
- return result
840
-
841
- def _validate_order(self, value: Any) -> Progression:
842
- if not value:
843
- return self.progression.__class__(
844
- order=list(self.collections.keys())
845
- )
846
-
847
- if isinstance(value, Progression):
848
- value = list(value)
849
- else:
850
- value = to_list_type(value)
851
-
852
- value_set = set(value)
853
- if len(value_set) != len(value):
854
- raise ValueError("There are duplicate elements in the order")
855
- if len(value_set) != len(self.collections.keys()):
856
- raise ValueError(
857
- "The length of the order does not match the length of the pile"
858
- )
859
-
860
- for i in value_set:
861
- if ID.get_id(i) not in self.collections.keys():
862
- raise ValueError(
863
- f"The order does not match the pile. {i} not found"
864
- )
865
-
866
- return self.progression.__class__(order=value)
867
-
868
942
  def _insert(self, index: int, item: ID.Item):
869
- item_dict = self._validate_pile(item)
943
+ item_dict = _validate_collections(
944
+ item, self.item_type, self.strict_type
945
+ )
870
946
 
871
947
  item_order = []
872
948
  for i in item_dict.keys():
@@ -876,24 +952,6 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
876
952
  self.progression.insert(index, item_order)
877
953
  self.collections.update(item_dict)
878
954
 
879
- def to_dict(self) -> dict[str, Any]:
880
- """Convert pile to dictionary, properly handling collections."""
881
- # Get base dict from parent class
882
- dict_ = super().to_dict()
883
-
884
- # Manually serialize collections
885
- collections_list = []
886
- for item in self.collections.values():
887
- if hasattr(item, "to_dict"):
888
- collections_list.append(item.to_dict())
889
- elif hasattr(item, "model_dump"):
890
- collections_list.append(item.model_dump())
891
- else:
892
- collections_list.append(str(item))
893
-
894
- dict_["collections"] = collections_list
895
- return dict_
896
-
897
955
  class AsyncPileIterator:
898
956
  def __init__(self, pile: Pile):
899
957
  self.pile = pile
@@ -930,58 +988,86 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
930
988
  is_same_dtype(self.collections.values())
931
989
  )
932
990
 
933
def adapt_to(self, obj_key: str, many=False, **kw: Any) -> Any:
    """Adapt this pile to another format.

    The call is forwarded to the parent ``adapt_to`` with ``adapt_meth``
    forced to ``"to_dict"``.

    Args:
        obj_key: Key indicating the format (e.g., 'json', 'csv').
        many: If True, interpret to receive list of items in the collection.
        **kw: Additional keyword arguments for adaptation.

    Example:
        >>> str_ = pile.adapt_to('json')
        >>> df = pile.adapt_to('pd.DataFrame', many=True)
        >>> csv_str = pile.adapt_to('csv', many=True)

    Pile registers the `json`, `csv` and `pd.DataFrame` adapters. More
    adapters (such as `qdrant`, `neo4j`, `postgres`) are available from
    pydapter; see https://khive-ai.github.io/pydapter/ for details.
    """
    options = {**kw, "adapt_meth": "to_dict"}
    return super().adapt_to(obj_key=obj_key, many=many, **options)
936
1010
 
937
1011
  @classmethod
938
def adapt_from(cls, obj: Any, obj_key: str, many=False, **kw: Any):
    """Create a pile from another format.

    The call is forwarded to the parent ``adapt_from`` with ``adapt_meth``
    forced to ``"from_dict"``.

    Args:
        obj: Object to adapt from.
        obj_key: Key indicating the format (e.g., 'json', 'csv').
        many: If True, interpret to receive list of items in the collection.
        **kw: Additional keyword arguments for adaptation.

    Example:
        >>> pile = Pile.adapt_from(str_, 'json')
        >>> pile = Pile.adapt_from(df, 'pd.DataFrame', many=True)

    Pile registers the `json`, `csv` and `pd.DataFrame` adapters. More
    adapters (such as `qdrant`, `neo4j`, `postgres`) are available from
    pydapter; see https://khive-ai.github.io/pydapter/ for details.
    """
    options = {**kw, "adapt_meth": "from_dict"}
    return super().adapt_from(obj, obj_key, many=many, **options)
942
1030
 
943
async def adapt_to_async(self, obj_key: str, many=False, **kw: Any) -> Any:
    """Asynchronously adapt to another format.

    Forwards to the parent ``adapt_to_async`` with ``adapt_meth`` forced to
    ``"to_dict"``.
    """
    options = {**kw, "adapt_meth": "to_dict"}
    return await super().adapt_to_async(
        obj_key=obj_key, many=many, **options
    )
948
1035
 
949
1036
  @classmethod
950
async def adapt_from_async(
    cls, obj: Any, obj_key: str, many=False, **kw: Any
):
    """Asynchronously create from another format.

    Forwards to the parent ``adapt_from_async`` with ``adapt_meth`` forced
    to ``"from_dict"``.
    """
    options = {**kw, "adapt_meth": "from_dict"}
    return await super().adapt_from_async(
        obj, obj_key, many=many, **options
    )
957
1043
 
958
def to_df(
    self, columns: list[str] | None = None, **kw: Any
) -> pd.DataFrame:
    """Convert the pile's items to a DataFrame.

    Args:
        columns: Optional column names; when given (and non-empty) only
            these columns are returned.
        **kw: Extra keyword arguments passed to ``DataFrameAdapter.to_obj``.

    Returns:
        The DataFrame built by ``DataFrameAdapter.to_obj`` from the stored
        items, optionally restricted to ``columns``.
    """
    from pydapter.extras.pandas_ import DataFrameAdapter

    items = list(self.collections.values())
    frame = DataFrameAdapter.to_obj(items, adapt_meth="to_dict", **kw)
    return frame[columns] if columns else frame
972
1056
 
973
@deprecated(
    "to_csv_file is deprecated, use `pile.dump(fp, 'csv')` instead"
)
def to_csv_file(self, fp: str | Path, **kw: Any) -> None:
    """Save the pile's items to a CSV file.

    Args:
        fp: Destination file path; the file is overwritten.
        **kw: Extra keyword arguments forwarded to ``adapt_to('csv', ...)``.
    """
    csv_str = self.adapt_to("csv", many=True, **kw)
    # newline="" stops the text layer from rewriting the line endings that
    # are already part of the CSV text; the explicit encoding keeps the
    # output independent of the platform default.
    with open(fp, "w", encoding="utf-8", newline="") as f:
        f.write(csv_str)
982
1065
 
1066
+ @deprecated(
1067
+ "to_json_file is deprecated, use `pile.dump(fp, 'json')` instead"
1068
+ )
983
1069
  def to_json_file(
984
- self, fp: str | Path, mode: str = "w", many: bool = False, **kwargs
1070
+ self, fp: str | Path, mode: str = "w", many: bool = False, **kw
985
1071
  ):
986
1072
  """Export collection to JSON file.
987
1073
 
@@ -991,14 +1077,61 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
991
1077
  mode: File mode ('w' for write, 'a' for append).
992
1078
  **kwargs: Additional arguments for json.dump() or DataFrame.to_json().
993
1079
  """
994
- from pydapter.adapters import JsonAdapter
995
-
996
- json_str = JsonAdapter.to_obj(
997
- self, many=many, adapt_meth="to_dict", **kwargs
998
- )
1080
+ json_str = self.adapt_to("json", many=many, **kw)
999
1081
  with open(fp, mode) as f:
1000
1082
  f.write(json_str)
1001
1083
 
1084
+ def dump(
1085
+ self,
1086
+ fp: str | Path | None,
1087
+ obj_key: Literal["json", "csv", "parquet"] = "json",
1088
+ *,
1089
+ mode: Literal["w", "a"] = "w",
1090
+ clear=False,
1091
+ **kw,
1092
+ ) -> None:
1093
+ """Export collection to file in specified format.
1094
+
1095
+ Args:
1096
+ fp: File path or buffer to write to. If None, returns string.
1097
+ Cannot be None if obj_key is 'parquet'.
1098
+ obj_key: Format to export ('json', 'csv', 'parquet').
1099
+ mode: File mode ('w' for write, 'a' for append).
1100
+ clear: If True, clear the collection after export.
1101
+ **kw: Additional arguments for the export method, pandas kwargs
1102
+ """
1103
+ df = self.to_df()
1104
+ match obj_key:
1105
+ case "parquet":
1106
+ df.to_parquet(fp, engine="pyarrow", index=False, **kw)
1107
+ case "json":
1108
+ out = df.to_json(
1109
+ fp, orient="records", lines=True, mode=mode, **kw
1110
+ )
1111
+ return out if out is not None else None
1112
+ case "csv":
1113
+ out = df.to_csv(fp, index=False, mode=mode, **kw)
1114
+ return out if out is not None else None
1115
+ case _:
1116
+ raise ValueError(
1117
+ f"Unsupported obj_key: {obj_key}. Supported keys are 'json', 'csv', 'parquet'."
1118
+ )
1119
+
1120
+ if clear:
1121
+ self.clear()
1122
+
1123
@async_synchronized
async def adump(
    self,
    fp: str | Path | None,
    obj_key: Literal["json", "csv", "parquet"] = "json",
    *,
    mode: Literal["w", "a"] = "w",
    clear=False,
    **kw,
) -> str | None:
    """Async wrapper around ``dump``.

    ``obj_key`` may be passed positionally or by keyword, matching the
    signature of ``dump``.

    Args:
        fp: File path or buffer to write to, or None.
        obj_key: Format to export ('json', 'csv', 'parquet').
        mode: File mode ('w' for write, 'a' for append).
        clear: If True, clear the collection after export.
        **kw: Additional keyword arguments forwarded to ``dump``.

    Returns:
        Whatever ``dump`` returns for the same arguments.
    """
    return self.dump(fp, obj_key, mode=mode, clear=clear, **kw)
1134
+
1002
1135
 
1003
1136
  def to_list_type(value: Any, /) -> list[Any]:
1004
1137
  """Convert input to a list format"""
@@ -1017,4 +1150,17 @@ def to_list_type(value: Any, /) -> list[Any]:
1017
1150
  return [value]
1018
1151
 
1019
1152
 
1153
# Register the built-in adapters once, at import time.
if not _ADAPATER_REGISTERED:
    from pydapter.adapters import CsvAdapter, JsonAdapter
    from pydapter.extras.pandas_ import DataFrameAdapter

    for _adapter in (CsvAdapter, JsonAdapter, DataFrameAdapter):
        Pile.register_adapter(_adapter)
    del _adapter  # keep the loop variable out of the module namespace
    _ADAPATER_REGISTERED = True

Pile = Pile

__all__ = ("Pile",)
1165
+
1020
1166
  # File: lionagi/protocols/generic/pile.py