maxframe-1.0.0rc4-cp38-cp38-win32.whl → maxframe-1.1.1-cp38-cp38-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. (The original page linked to further details; that link was not preserved in this copy.)

Files changed (88):
  1. maxframe/_utils.cp38-win32.pyd +0 -0
  2. maxframe/config/__init__.py +1 -1
  3. maxframe/config/config.py +26 -0
  4. maxframe/config/tests/test_config.py +20 -1
  5. maxframe/conftest.py +17 -4
  6. maxframe/core/graph/core.cp38-win32.pyd +0 -0
  7. maxframe/core/operator/base.py +2 -0
  8. maxframe/dataframe/arithmetic/tests/test_arithmetic.py +17 -16
  9. maxframe/dataframe/core.py +24 -2
  10. maxframe/dataframe/datasource/read_odps_query.py +65 -35
  11. maxframe/dataframe/datasource/read_odps_table.py +4 -2
  12. maxframe/dataframe/datasource/tests/test_datasource.py +59 -7
  13. maxframe/dataframe/extensions/__init__.py +5 -0
  14. maxframe/dataframe/extensions/apply_chunk.py +649 -0
  15. maxframe/dataframe/extensions/flatjson.py +131 -0
  16. maxframe/dataframe/extensions/flatmap.py +28 -40
  17. maxframe/dataframe/extensions/reshuffle.py +1 -1
  18. maxframe/dataframe/extensions/tests/test_apply_chunk.py +186 -0
  19. maxframe/dataframe/extensions/tests/test_extensions.py +46 -2
  20. maxframe/dataframe/groupby/__init__.py +1 -0
  21. maxframe/dataframe/groupby/aggregation.py +1 -0
  22. maxframe/dataframe/groupby/apply.py +9 -1
  23. maxframe/dataframe/groupby/core.py +1 -1
  24. maxframe/dataframe/groupby/fill.py +4 -1
  25. maxframe/dataframe/groupby/getitem.py +6 -0
  26. maxframe/dataframe/groupby/tests/test_groupby.py +1 -1
  27. maxframe/dataframe/groupby/transform.py +8 -2
  28. maxframe/dataframe/indexing/loc.py +6 -4
  29. maxframe/dataframe/merge/__init__.py +9 -1
  30. maxframe/dataframe/merge/concat.py +41 -31
  31. maxframe/dataframe/merge/merge.py +1 -1
  32. maxframe/dataframe/merge/tests/test_merge.py +3 -1
  33. maxframe/dataframe/misc/apply.py +3 -0
  34. maxframe/dataframe/misc/drop_duplicates.py +5 -1
  35. maxframe/dataframe/misc/map.py +3 -1
  36. maxframe/dataframe/misc/tests/test_misc.py +24 -2
  37. maxframe/dataframe/misc/transform.py +22 -13
  38. maxframe/dataframe/reduction/__init__.py +3 -0
  39. maxframe/dataframe/reduction/aggregation.py +1 -0
  40. maxframe/dataframe/reduction/median.py +56 -0
  41. maxframe/dataframe/reduction/tests/test_reduction.py +17 -7
  42. maxframe/dataframe/statistics/quantile.py +8 -2
  43. maxframe/dataframe/statistics/tests/test_statistics.py +4 -4
  44. maxframe/dataframe/tests/test_utils.py +60 -0
  45. maxframe/dataframe/utils.py +110 -7
  46. maxframe/dataframe/window/expanding.py +5 -3
  47. maxframe/dataframe/window/tests/test_expanding.py +2 -2
  48. maxframe/io/objects/tests/test_object_io.py +39 -12
  49. maxframe/io/odpsio/__init__.py +1 -1
  50. maxframe/io/odpsio/arrow.py +51 -2
  51. maxframe/io/odpsio/schema.py +23 -5
  52. maxframe/io/odpsio/tableio.py +80 -124
  53. maxframe/io/odpsio/tests/test_schema.py +40 -0
  54. maxframe/io/odpsio/tests/test_tableio.py +5 -5
  55. maxframe/io/odpsio/tests/test_volumeio.py +35 -11
  56. maxframe/io/odpsio/volumeio.py +27 -3
  57. maxframe/learn/contrib/__init__.py +3 -2
  58. maxframe/learn/contrib/llm/__init__.py +16 -0
  59. maxframe/learn/contrib/llm/core.py +54 -0
  60. maxframe/learn/contrib/llm/models/__init__.py +14 -0
  61. maxframe/learn/contrib/llm/models/dashscope.py +73 -0
  62. maxframe/learn/contrib/llm/multi_modal.py +42 -0
  63. maxframe/learn/contrib/llm/text.py +42 -0
  64. maxframe/lib/mmh3.cp38-win32.pyd +0 -0
  65. maxframe/lib/sparse/tests/test_sparse.py +15 -15
  66. maxframe/opcodes.py +7 -1
  67. maxframe/serialization/core.cp38-win32.pyd +0 -0
  68. maxframe/serialization/core.pyx +13 -1
  69. maxframe/serialization/pandas.py +50 -20
  70. maxframe/serialization/serializables/core.py +70 -15
  71. maxframe/serialization/serializables/field_type.py +4 -1
  72. maxframe/serialization/serializables/tests/test_serializable.py +12 -2
  73. maxframe/serialization/tests/test_serial.py +2 -1
  74. maxframe/tensor/__init__.py +19 -7
  75. maxframe/tensor/merge/vstack.py +1 -1
  76. maxframe/tests/utils.py +16 -0
  77. maxframe/udf.py +27 -0
  78. maxframe/utils.py +42 -8
  79. {maxframe-1.0.0rc4.dist-info → maxframe-1.1.1.dist-info}/METADATA +2 -2
  80. {maxframe-1.0.0rc4.dist-info → maxframe-1.1.1.dist-info}/RECORD +88 -77
  81. {maxframe-1.0.0rc4.dist-info → maxframe-1.1.1.dist-info}/WHEEL +1 -1
  82. maxframe_client/clients/framedriver.py +4 -1
  83. maxframe_client/fetcher.py +23 -8
  84. maxframe_client/session/odps.py +40 -11
  85. maxframe_client/session/task.py +6 -25
  86. maxframe_client/session/tests/test_task.py +35 -6
  87. maxframe_client/tests/test_session.py +30 -10
  88. {maxframe-1.0.0rc4.dist-info → maxframe-1.1.1.dist-info}/top_level.txt +0 -0
maxframe/learn/contrib/llm/models/dashscope.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict
15
+
16
+ from ..... import opcodes
17
+ from .....serialization.serializables.core import Serializable
18
+ from .....serialization.serializables.field import StringField
19
+ from ..core import LLMOperator
20
+ from ..multi_modal import MultiModalLLM
21
+ from ..text import TextLLM
22
+
23
+
24
+ class DashScopeLLMMixin(Serializable):
25
+ __slots__ = ()
26
+
27
+ _not_supported_params = {"stream", "incremental_output"}
28
+
29
+ def validate_params(self, params: Dict[str, Any]):
30
+ for k in params.keys():
31
+ if k in self._not_supported_params:
32
+ raise ValueError(f"{k} is not supported")
33
+
34
+
35
+ class DashScopeTextLLM(TextLLM, DashScopeLLMMixin):
36
+ api_key_resource = StringField("api_key_resource", default=None)
37
+
38
+ def generate(
39
+ self,
40
+ data,
41
+ prompt_template: Dict[str, Any],
42
+ params: Dict[str, Any] = None,
43
+ ):
44
+ return DashScopeTextGenerationOperator(
45
+ model=self,
46
+ prompt_template=prompt_template,
47
+ params=params,
48
+ )(data)
49
+
50
+
51
+ class DashScopeMultiModalLLM(MultiModalLLM, DashScopeLLMMixin):
52
+ api_key_resource = StringField("api_key_resource", default=None)
53
+
54
+ def generate(
55
+ self,
56
+ data,
57
+ prompt_template: Dict[str, Any],
58
+ params: Dict[str, Any] = None,
59
+ ):
60
+ # TODO add precheck here
61
+ return DashScopeMultiModalGenerationOperator(
62
+ model=self,
63
+ prompt_template=prompt_template,
64
+ params=params,
65
+ )(data)
66
+
67
+
68
+ class DashScopeTextGenerationOperator(LLMOperator):
69
+ _op_type_ = opcodes.DASHSCOPE_TEXT_GENERATION
70
+
71
+
72
+ class DashScopeMultiModalGenerationOperator(LLMOperator):
73
+ _op_type_ = opcodes.DASHSCOPE_MULTI_MODAL_GENERATION
maxframe/learn/contrib/llm/multi_modal.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict
15
+
16
+ from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
17
+ from .core import LLM
18
+
19
+
20
+ class MultiModalLLM(LLM):
21
+ def generate(
22
+ self,
23
+ data,
24
+ prompt_template: Dict[str, Any],
25
+ params: Dict[str, Any] = None,
26
+ ):
27
+ raise NotImplementedError
28
+
29
+
30
+ def generate(
31
+ data,
32
+ model: MultiModalLLM,
33
+ prompt_template: Dict[str, Any],
34
+ params: Dict[str, Any] = None,
35
+ ):
36
+ if not isinstance(data, DATAFRAME_TYPE) and not isinstance(data, SERIES_TYPE):
37
+ raise ValueError("data must be a maxframe dataframe or series object")
38
+ if not isinstance(model, MultiModalLLM):
39
+ raise ValueError("model must be a MultiModalLLM object")
40
+ params = params if params is not None else dict()
41
+ model.validate_params(params)
42
+ return model.generate(data, prompt_template, params)
maxframe/learn/contrib/llm/text.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict
15
+
16
+ from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
17
+ from .core import LLM
18
+
19
+
20
+ class TextLLM(LLM):
21
+ def generate(
22
+ self,
23
+ data,
24
+ prompt_template: Dict[str, Any],
25
+ params: Dict[str, Any] = None,
26
+ ):
27
+ raise NotImplementedError
28
+
29
+
30
+ def generate(
31
+ data,
32
+ model: TextLLM,
33
+ prompt_template: Dict[str, Any],
34
+ params: Dict[str, Any] = None,
35
+ ):
36
+ if not isinstance(data, DATAFRAME_TYPE) and not isinstance(data, SERIES_TYPE):
37
+ raise ValueError("data must be a maxframe dataframe or series object")
38
+ if not isinstance(model, TextLLM):
39
+ raise ValueError("model must be a TextLLM object")
40
+ params = params if params is not None else dict()
41
+ model.validate_params(params)
42
+ return model.generate(data, prompt_template, params)
maxframe/lib/mmh3.cp38-win32.pyd CHANGED
Binary file
maxframe/lib/sparse/tests/test_sparse.py CHANGED
@@ -55,13 +55,13 @@ def test_sparse_creation():
55
55
  s = SparseNDArray(s1_data)
56
56
  assert s.ndim == 2
57
57
  assert isinstance(s, SparseMatrix)
58
- assert_array_equal(s.toarray(), s1_data.A)
59
- assert_array_equal(s.todense(), s1_data.A)
58
+ assert_array_equal(s.toarray(), s1_data.toarray())
59
+ assert_array_equal(s.todense(), s1_data.toarray())
60
60
 
61
61
  ss = pickle.loads(pickle.dumps(s))
62
62
  assert s == ss
63
- assert_array_equal(ss.toarray(), s1_data.A)
64
- assert_array_equal(ss.todense(), s1_data.A)
63
+ assert_array_equal(ss.toarray(), s1_data.toarray())
64
+ assert_array_equal(ss.todense(), s1_data.toarray())
65
65
 
66
66
  v = SparseNDArray(v1, shape=(3,))
67
67
  assert s.ndim
@@ -331,12 +331,12 @@ def test_sparse_dot():
331
331
 
332
332
  assert_array_equal(mls.dot(s1, v1_s), s1.dot(v1_data))
333
333
  assert_array_equal(mls.dot(s2, v1_s), s2.dot(v1_data))
334
- assert_array_equal(mls.dot(v2_s, s1), v2_data.dot(s1_data.A))
335
- assert_array_equal(mls.dot(v2_s, s2), v2_data.dot(s2_data.A))
334
+ assert_array_equal(mls.dot(v2_s, s1), v2_data.dot(s1_data.toarray()))
335
+ assert_array_equal(mls.dot(v2_s, s2), v2_data.dot(s2_data.toarray()))
336
336
  assert_array_equal(mls.dot(v1_s, v1_s), v1_data.dot(v1_data), almost=True)
337
337
  assert_array_equal(mls.dot(v2_s, v2_s), v2_data.dot(v2_data), almost=True)
338
338
 
339
- assert_array_equal(mls.dot(v2_s, s1, sparse=False), v2_data.dot(s1_data.A))
339
+ assert_array_equal(mls.dot(v2_s, s1, sparse=False), v2_data.dot(s1_data.toarray()))
340
340
  assert_array_equal(mls.dot(v1_s, v1_s, sparse=False), v1_data.dot(v1_data))
341
341
 
342
342
 
@@ -390,7 +390,7 @@ def test_sparse_fill_diagonal():
390
390
  arr = SparseNDArray(s1)
391
391
  arr.fill_diagonal(3)
392
392
 
393
- expected = s1.copy().A
393
+ expected = s1.copy().toarray()
394
394
  np.fill_diagonal(expected, 3)
395
395
 
396
396
  np.testing.assert_array_equal(arr.toarray(), expected)
@@ -399,7 +399,7 @@ def test_sparse_fill_diagonal():
399
399
  arr = SparseNDArray(s1)
400
400
  arr.fill_diagonal(3, wrap=True)
401
401
 
402
- expected = s1.copy().A
402
+ expected = s1.copy().toarray()
403
403
  np.fill_diagonal(expected, 3, wrap=True)
404
404
 
405
405
  np.testing.assert_array_equal(arr.toarray(), expected)
@@ -408,7 +408,7 @@ def test_sparse_fill_diagonal():
408
408
  arr = SparseNDArray(s1)
409
409
  arr.fill_diagonal([1, 2, 3])
410
410
 
411
- expected = s1.copy().A
411
+ expected = s1.copy().toarray()
412
412
  np.fill_diagonal(expected, [1, 2, 3])
413
413
 
414
414
  np.testing.assert_array_equal(arr.toarray(), expected)
@@ -417,7 +417,7 @@ def test_sparse_fill_diagonal():
417
417
  arr = SparseNDArray(s1)
418
418
  arr.fill_diagonal([1, 2, 3], wrap=True)
419
419
 
420
- expected = s1.copy().A
420
+ expected = s1.copy().toarray()
421
421
  np.fill_diagonal(expected, [1, 2, 3], wrap=True)
422
422
 
423
423
  np.testing.assert_array_equal(arr.toarray(), expected)
@@ -427,7 +427,7 @@ def test_sparse_fill_diagonal():
427
427
  arr = SparseNDArray(s1)
428
428
  arr.fill_diagonal(val)
429
429
 
430
- expected = s1.copy().A
430
+ expected = s1.copy().toarray()
431
431
  np.fill_diagonal(expected, val)
432
432
 
433
433
  np.testing.assert_array_equal(arr.toarray(), expected)
@@ -437,7 +437,7 @@ def test_sparse_fill_diagonal():
437
437
  arr = SparseNDArray(s1)
438
438
  arr.fill_diagonal(val, wrap=True)
439
439
 
440
- expected = s1.copy().A
440
+ expected = s1.copy().toarray()
441
441
  np.fill_diagonal(expected, val, wrap=True)
442
442
 
443
443
  np.testing.assert_array_equal(arr.toarray(), expected)
@@ -447,7 +447,7 @@ def test_sparse_fill_diagonal():
447
447
  arr = SparseNDArray(s1)
448
448
  arr.fill_diagonal(val)
449
449
 
450
- expected = s1.copy().A
450
+ expected = s1.copy().toarray()
451
451
  np.fill_diagonal(expected, val)
452
452
 
453
453
  np.testing.assert_array_equal(arr.toarray(), expected)
@@ -457,7 +457,7 @@ def test_sparse_fill_diagonal():
457
457
  arr = SparseNDArray(s1)
458
458
  arr.fill_diagonal(val, wrap=True)
459
459
 
460
- expected = s1.copy().A
460
+ expected = s1.copy().toarray()
461
461
  np.fill_diagonal(expected, val, wrap=True)
462
462
 
463
463
  np.testing.assert_array_equal(arr.toarray(), expected)
maxframe/opcodes.py CHANGED
@@ -270,6 +270,7 @@ KURTOSIS = 351
270
270
  SEM = 352
271
271
  STR_CONCAT = 353
272
272
  MAD = 354
273
+ MEDIAN = 355
273
274
 
274
275
  # tensor operator
275
276
  RESHAPE = 401
@@ -377,7 +378,6 @@ DROP_DUPLICATES = 728
377
378
  MELT = 729
378
379
  RENAME = 731
379
380
  INSERT = 732
380
- MAP_CHUNK = 733
381
381
  CARTESIAN_CHUNK = 734
382
382
  EXPLODE = 735
383
383
  REPLACE = 736
@@ -392,6 +392,10 @@ PIVOT_TABLE = 744
392
392
 
393
393
  FUSE = 801
394
394
 
395
+ # LLM
396
+ DASHSCOPE_TEXT_GENERATION = 810
397
+ DASHSCOPE_MULTI_MODAL_GENERATION = 811
398
+
395
399
  # table like input for tensor
396
400
  TABLE_COO = 1003
397
401
  # store tensor as coo format
@@ -569,6 +573,8 @@ CHOLESKY_FUSE = 999988
569
573
  # MaxFrame-dedicated functions
570
574
  DATAFRAME_RESHUFFLE = 10001
571
575
  FLATMAP = 10002
576
+ FLATJSON = 10003
577
+ APPLY_CHUNK = 10004
572
578
 
573
579
  # MaxFrame internal operators
574
580
  DATAFRAME_PROJECTION_SAME_INDEX_MERGE = 100001
maxframe/serialization/core.cp38-win32.pyd CHANGED
Binary file
maxframe/serialization/core.pyx CHANGED
@@ -37,7 +37,7 @@ from .._utils import NamedType
37
37
  from .._utils cimport TypeDispatcher
38
38
 
39
39
  from ..lib import wrapped_pickle as pickle
40
- from ..utils import arrow_type_from_str
40
+ from ..utils import NoDefault, arrow_type_from_str, no_default
41
41
 
42
42
  try:
43
43
  from pandas import ArrowDtype
@@ -94,6 +94,7 @@ cdef:
94
94
  int COMPLEX_SERIALIZER = 12
95
95
  int SLICE_SERIALIZER = 13
96
96
  int REGEX_SERIALIZER = 14
97
+ int NO_DEFAULT_SERIALIZER = 15
97
98
  int PLACEHOLDER_SERIALIZER = 4096
98
99
 
99
100
 
@@ -803,6 +804,16 @@ cdef class RegexSerializer(Serializer):
803
804
  return re.compile((<bytes>(subs[0])).decode(), serialized[0])
804
805
 
805
806
 
807
+ cdef class NoDefaultSerializer(Serializer):
808
+ serializer_id = NO_DEFAULT_SERIALIZER
809
+
810
+ cpdef serial(self, object obj, dict context):
811
+ return [], [], True
812
+
813
+ cpdef deserial(self, list obj, dict context, list subs):
814
+ return no_default
815
+
816
+
806
817
  cdef class Placeholder:
807
818
  """
808
819
  Placeholder object to reduce duplicated serialization
@@ -857,6 +868,7 @@ DtypeSerializer.register(ExtensionDtype)
857
868
  ComplexSerializer.register(complex)
858
869
  SliceSerializer.register(slice)
859
870
  RegexSerializer.register(re.Pattern)
871
+ NoDefaultSerializer.register(NoDefault)
860
872
  PlaceholderSerializer.register(Placeholder)
861
873
 
862
874
 
maxframe/serialization/pandas.py CHANGED
@@ -134,8 +134,10 @@ class ArraySerializer(Serializer):
134
134
  data_parts = [obj.tolist()]
135
135
  else:
136
136
  data_parts = [obj.to_numpy().tolist()]
137
- else:
137
+ elif hasattr(obj, "_data"):
138
138
  data_parts = [getattr(obj, "_data")]
139
+ else:
140
+ data_parts = [getattr(obj, "_pa_array")]
139
141
  return [ser_type], [dtype] + data_parts, False
140
142
 
141
143
  def deserial(self, serialized: List, context: Dict, subs: List):
@@ -155,38 +157,66 @@ class PdTimestampSerializer(Serializer):
155
157
  else:
156
158
  zone_info = []
157
159
  ts = obj.to_pydatetime().timestamp()
158
- return (
159
- [int(ts), obj.microsecond, obj.nanosecond],
160
- zone_info,
161
- bool(zone_info),
162
- )
160
+ elements = [int(ts), obj.microsecond, obj.nanosecond]
161
+ if hasattr(obj, "unit"):
162
+ elements.append(str(obj.unit))
163
+ return elements, zone_info, bool(zone_info)
163
164
 
164
165
  def deserial(self, serialized: List, context: Dict, subs: List):
165
166
  if subs:
166
- val = pd.Timestamp.utcfromtimestamp(serialized[0]).replace(
167
- microsecond=serialized[1], nanosecond=serialized[2]
168
- )
169
- val = val.replace(tzinfo=datetime.timezone.utc).tz_convert(subs[0])
167
+ pydt = datetime.datetime.utcfromtimestamp(serialized[0])
168
+ kwargs = {
169
+ "year": pydt.year,
170
+ "month": pydt.month,
171
+ "day": pydt.day,
172
+ "hour": pydt.hour,
173
+ "minute": pydt.minute,
174
+ "second": pydt.second,
175
+ "microsecond": serialized[1],
176
+ "nanosecond": serialized[2],
177
+ "tzinfo": datetime.timezone.utc,
178
+ }
179
+ if len(serialized) > 3:
180
+ kwargs["unit"] = serialized[3]
181
+ val = pd.Timestamp(**kwargs).tz_convert(subs[0])
170
182
  else:
171
- val = pd.Timestamp.fromtimestamp(serialized[0]).replace(
172
- microsecond=serialized[1], nanosecond=serialized[2]
173
- )
183
+ pydt = datetime.datetime.fromtimestamp(serialized[0])
184
+ kwargs = {
185
+ "year": pydt.year,
186
+ "month": pydt.month,
187
+ "day": pydt.day,
188
+ "hour": pydt.hour,
189
+ "minute": pydt.minute,
190
+ "second": pydt.second,
191
+ "microsecond": serialized[1],
192
+ "nanosecond": serialized[2],
193
+ }
194
+ if len(serialized) >= 4:
195
+ kwargs["unit"] = serialized[3]
196
+ val = pd.Timestamp(**kwargs)
174
197
  return val
175
198
 
176
199
 
177
200
  class PdTimedeltaSerializer(Serializer):
178
201
  def serial(self, obj: pd.Timedelta, context: Dict):
179
- return [int(obj.seconds), obj.microseconds, obj.nanoseconds, obj.days], [], True
202
+ elements = [int(obj.seconds), obj.microseconds, obj.nanoseconds, obj.days]
203
+ if hasattr(obj, "unit"):
204
+ elements.append(str(obj.unit))
205
+ return elements, [], True
180
206
 
181
207
  def deserial(self, serialized: List, context: Dict, subs: List):
182
208
  days = 0 if len(serialized) < 4 else serialized[3]
209
+ unit = None if len(serialized) < 5 else serialized[4]
183
210
  seconds, microseconds, nanoseconds = serialized[:3]
184
- return pd.Timedelta(
185
- days=days,
186
- seconds=seconds,
187
- microseconds=microseconds,
188
- nanoseconds=nanoseconds,
189
- )
211
+ kwargs = {
212
+ "days": days,
213
+ "seconds": seconds,
214
+ "microseconds": microseconds,
215
+ "nanoseconds": nanoseconds,
216
+ }
217
+ if unit is not None:
218
+ kwargs["unit"] = unit
219
+ return pd.Timedelta(**kwargs)
190
220
 
191
221
 
192
222
  class NoDefaultSerializer(Serializer):
maxframe/serialization/serializables/core.py CHANGED
@@ -13,12 +13,13 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import weakref
16
- from collections import defaultdict
16
+ from collections import OrderedDict
17
17
  from typing import Any, Dict, List, Optional, Tuple, Type
18
18
 
19
19
  import msgpack
20
20
 
21
21
  from ...lib.mmh3 import hash
22
+ from ...utils import no_default
22
23
  from ..core import Placeholder, Serializer, buffered, load_type
23
24
  from .field import Field
24
25
  from .field_type import DictType, ListType, PrimitiveFieldType, TupleType
@@ -97,14 +98,18 @@ class SerializableMeta(type):
97
98
  non_primitive_fields.append(v)
98
99
 
99
100
  # count number of fields for every base class
100
- cls_to_primitive_field_count = defaultdict(lambda: 0)
101
- cls_to_non_primitive_field_count = defaultdict(lambda: 0)
101
+ cls_to_primitive_field_count = OrderedDict()
102
+ cls_to_non_primitive_field_count = OrderedDict()
102
103
  for field_name in field_order:
103
104
  cls_hash = field_to_cls_hash[field_name]
104
105
  if field_name in primitive_field_names:
105
- cls_to_primitive_field_count[cls_hash] += 1
106
+ cls_to_primitive_field_count[cls_hash] = (
107
+ cls_to_primitive_field_count.get(cls_hash, 0) + 1
108
+ )
106
109
  else:
107
- cls_to_non_primitive_field_count[cls_hash] += 1
110
+ cls_to_non_primitive_field_count[cls_hash] = (
111
+ cls_to_non_primitive_field_count.get(cls_hash, 0) + 1
112
+ )
108
113
 
109
114
  slots = set(properties.pop("__slots__", set()))
110
115
  slots.update(properties_field_slot_names)
@@ -119,9 +124,11 @@ class SerializableMeta(type):
119
124
  properties["_FIELD_ORDER"] = field_order
120
125
  properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
121
126
  properties["_PRIMITIVE_FIELDS"] = primitive_fields
122
- properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = dict(cls_to_primitive_field_count)
127
+ properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = OrderedDict(
128
+ cls_to_primitive_field_count
129
+ )
123
130
  properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
124
- properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = dict(
131
+ properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = OrderedDict(
125
132
  cls_to_non_primitive_field_count
126
133
  )
127
134
  properties["__slots__"] = tuple(slots)
@@ -211,6 +218,22 @@ class _NoFieldValue:
211
218
  _no_field_value = _NoFieldValue()
212
219
 
213
220
 
221
+ def _to_primitive_placeholder(v: Any) -> Any:
222
+ if v is _no_field_value or v is no_default:
223
+ return {}
224
+ return v
225
+
226
+
227
+ def _restore_primitive_placeholder(v: Any) -> Any:
228
+ if type(v) is dict:
229
+ if v == {}:
230
+ return _no_field_value
231
+ else:
232
+ return v
233
+ else:
234
+ return v
235
+
236
+
214
237
  class SerializableSerializer(Serializer):
215
238
  """
216
239
  Leverage DictSerializer to perform serde.
@@ -241,9 +264,7 @@ class SerializableSerializer(Serializer):
241
264
  else:
242
265
  primitive_vals = self._get_field_values(obj, obj._PRIMITIVE_FIELDS)
243
266
  # replace _no_field_value as {} to make them msgpack-serializable
244
- primitive_vals = [
245
- v if v is not _no_field_value else {} for v in primitive_vals
246
- ]
267
+ primitive_vals = [_to_primitive_placeholder(v) for v in primitive_vals]
247
268
  if obj._cache_primitive_serial:
248
269
  primitive_vals = msgpack.dumps(primitive_vals)
249
270
  _primitive_serial_cache[obj] = primitive_vals
@@ -281,21 +302,51 @@ class SerializableSerializer(Serializer):
281
302
  else:
282
303
  field.set(obj, value)
283
304
 
305
+ @classmethod
306
+ def _prune_server_fields(
307
+ cls,
308
+ client_cls_to_field_count: Optional[Dict[int, int]],
309
+ server_cls_to_field_count: Dict[int, int],
310
+ server_fields: list,
311
+ ) -> list:
312
+ if not client_cls_to_field_count: # pragma: no cover
313
+ # todo remove this branch when all versions below v0.1.0b5 is eliminated
314
+ return server_fields
315
+ if set(client_cls_to_field_count.keys()) == set(
316
+ server_cls_to_field_count.keys()
317
+ ):
318
+ return server_fields
319
+ ret_server_fields = []
320
+ server_pos = 0
321
+ for cls_hash, count in server_cls_to_field_count.items():
322
+ if cls_hash in client_cls_to_field_count:
323
+ ret_server_fields.extend(server_fields[server_pos : server_pos + count])
324
+ server_pos += count
325
+ return ret_server_fields
326
+
284
327
  @classmethod
285
328
  def _set_field_values(
286
329
  cls,
287
330
  obj: Serializable,
288
331
  values: List[Any],
289
- client_cls_to_field_count: Optional[Dict[str, int]],
332
+ client_cls_to_field_count: Optional[Dict[int, int]],
290
333
  is_primitive: bool = True,
291
334
  ):
292
335
  obj_class = type(obj)
293
336
  if is_primitive:
294
337
  server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
295
- server_fields = obj_class._PRIMITIVE_FIELDS
338
+ server_fields = cls._prune_server_fields(
339
+ client_cls_to_field_count,
340
+ server_cls_to_field_count,
341
+ obj_class._PRIMITIVE_FIELDS,
342
+ )
296
343
  else:
297
344
  server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
298
- server_fields = obj_class._NON_PRIMITIVE_FIELDS
345
+ server_fields = cls._prune_server_fields(
346
+ client_cls_to_field_count,
347
+ server_cls_to_field_count,
348
+ obj_class._NON_PRIMITIVE_FIELDS,
349
+ )
299
350
 
300
351
  legacy_to_new_hash = {
301
352
  c._LEGACY_NAME_HASH: c._NAME_HASH
@@ -311,7 +362,9 @@ class SerializableSerializer(Serializer):
311
362
  cls_fields = server_fields[server_field_num : field_num + count]
312
363
  cls_values = values[field_num : field_num + count]
313
364
  for field, value in zip(cls_fields, cls_values):
314
- if not is_primitive or value != {}:
365
+ if is_primitive:
366
+ value = _restore_primitive_placeholder(value)
367
+ if not is_primitive or value is not _no_field_value:
315
368
  cls._set_field_value(obj, field, value)
316
369
  field_num += count
317
370
  try:
@@ -356,7 +409,9 @@ class SerializableSerializer(Serializer):
356
409
  server_fields + deprecated_fields, key=lambda f: f.name
357
410
  )
358
411
  for field, value in zip(server_fields, values):
359
- if not is_primitive or value != {}:
412
+ if is_primitive:
413
+ value = _restore_primitive_placeholder(value)
414
+ if not is_primitive or value is not _no_field_value:
360
415
  try:
361
416
  cls._set_field_value(obj, field, value)
362
417
  except AttributeError: # pragma: no cover
maxframe/serialization/serializables/field_type.py CHANGED
@@ -46,6 +46,9 @@ class PrimitiveType(Enum):
46
46
  complex128 = 25
47
47
 
48
48
 
49
+ _np_unicode = np.unicode_ if hasattr(np, "unicode_") else np.str_
50
+
51
+
49
52
  _primitive_type_to_valid_types = {
50
53
  PrimitiveType.bool: (bool, np.bool_),
51
54
  PrimitiveType.int8: (int, np.int8),
@@ -60,7 +63,7 @@ _primitive_type_to_valid_types = {
60
63
  PrimitiveType.float32: (float, np.float32),
61
64
  PrimitiveType.float64: (float, np.float64),
62
65
  PrimitiveType.bytes: (bytes, np.bytes_),
63
- PrimitiveType.string: (str, np.unicode_),
66
+ PrimitiveType.string: (str, _np_unicode),
64
67
  PrimitiveType.complex64: (complex, np.complex64),
65
68
  PrimitiveType.complex128: (complex, np.complex128),
66
69
  }
maxframe/serialization/serializables/tests/test_serializable.py CHANGED
@@ -21,6 +21,7 @@ import pytest
21
21
 
22
22
  from ....core import EntityData
23
23
  from ....lib.wrapped_pickle import switch_unpickle
24
+ from ....utils import no_default
24
25
  from ... import deserialize, serialize
25
26
  from .. import (
26
27
  AnyField,
@@ -143,6 +144,7 @@ class MySerializable(Serializable):
143
144
  oneof1_val=f"{__name__}.MySerializable",
144
145
  oneof2_val=MySimpleSerializable,
145
146
  )
147
+ _no_default_val = Float64Field("no_default_val", default=no_default)
146
148
 
147
149
 
148
150
  @pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
@@ -187,6 +189,7 @@ def test_serializable(set_is_ci):
187
189
  _dict_val={"a": b"bytes_value"},
188
190
  _ref_val=MySerializable(),
189
191
  _oneof_val=MySerializable(_id="2"),
192
+ _no_default_val=no_default,
190
193
  )
191
194
 
192
195
  header, buffers = serialize(my_serializable)
@@ -218,7 +221,10 @@ def test_compatible_serializable(set_is_ci):
218
221
  _ref_val = ReferenceField("ref_val", "MySimpleSerializable")
219
222
  _dict_val = DictField("dict_val")
220
223
 
221
- class MySubSerializable(MySimpleSerializable):
224
+ class MyMidSerializable(MySimpleSerializable):
225
+ _i_bool_val = Int64Field("i_bool_val", default=True)
226
+
227
+ class MySubSerializable(MyMidSerializable):
222
228
  _m_int_val = Int64Field("m_int_val", default=250)
223
229
  _m_str_val = StringField("m_str_val", default="SUB_STR")
224
230
 
@@ -234,7 +240,11 @@ def _assert_serializable_eq(my_serializable, my_serializable2):
234
240
  if not hasattr(my_serializable, field.name):
235
241
  continue
236
242
  expect_value = getattr(my_serializable, field_name)
237
- actual_value = getattr(my_serializable2, field_name)
243
+ if expect_value is no_default:
244
+ assert not hasattr(my_serializable2, field.name)
245
+ continue
246
+ else:
247
+ actual_value = getattr(my_serializable2, field_name)
238
248
  if isinstance(expect_value, np.ndarray):
239
249
  np.testing.assert_array_equal(expect_value, actual_value)
240
250
  elif isinstance(expect_value, pd.DataFrame):