maxframe 1.0.0rc4__cp310-cp310-win_amd64.whl → 1.1.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/_utils.cp310-win_amd64.pyd +0 -0
- maxframe/config/__init__.py +1 -1
- maxframe/config/config.py +26 -0
- maxframe/config/tests/test_config.py +20 -1
- maxframe/conftest.py +17 -4
- maxframe/core/graph/core.cp310-win_amd64.pyd +0 -0
- maxframe/core/operator/base.py +2 -0
- maxframe/dataframe/arithmetic/tests/test_arithmetic.py +17 -16
- maxframe/dataframe/core.py +24 -2
- maxframe/dataframe/datasource/read_odps_query.py +65 -35
- maxframe/dataframe/datasource/read_odps_table.py +4 -2
- maxframe/dataframe/datasource/tests/test_datasource.py +59 -7
- maxframe/dataframe/extensions/__init__.py +5 -0
- maxframe/dataframe/extensions/apply_chunk.py +649 -0
- maxframe/dataframe/extensions/flatjson.py +131 -0
- maxframe/dataframe/extensions/flatmap.py +28 -40
- maxframe/dataframe/extensions/reshuffle.py +1 -1
- maxframe/dataframe/extensions/tests/test_apply_chunk.py +186 -0
- maxframe/dataframe/extensions/tests/test_extensions.py +46 -2
- maxframe/dataframe/groupby/__init__.py +1 -0
- maxframe/dataframe/groupby/aggregation.py +1 -0
- maxframe/dataframe/groupby/apply.py +9 -1
- maxframe/dataframe/groupby/core.py +1 -1
- maxframe/dataframe/groupby/fill.py +4 -1
- maxframe/dataframe/groupby/getitem.py +6 -0
- maxframe/dataframe/groupby/tests/test_groupby.py +1 -1
- maxframe/dataframe/groupby/transform.py +8 -2
- maxframe/dataframe/indexing/loc.py +6 -4
- maxframe/dataframe/merge/__init__.py +9 -1
- maxframe/dataframe/merge/concat.py +41 -31
- maxframe/dataframe/merge/merge.py +1 -1
- maxframe/dataframe/merge/tests/test_merge.py +3 -1
- maxframe/dataframe/misc/apply.py +3 -0
- maxframe/dataframe/misc/drop_duplicates.py +5 -1
- maxframe/dataframe/misc/map.py +3 -1
- maxframe/dataframe/misc/tests/test_misc.py +24 -2
- maxframe/dataframe/misc/transform.py +22 -13
- maxframe/dataframe/reduction/__init__.py +3 -0
- maxframe/dataframe/reduction/aggregation.py +1 -0
- maxframe/dataframe/reduction/median.py +56 -0
- maxframe/dataframe/reduction/tests/test_reduction.py +17 -7
- maxframe/dataframe/statistics/quantile.py +8 -2
- maxframe/dataframe/statistics/tests/test_statistics.py +4 -4
- maxframe/dataframe/tests/test_utils.py +60 -0
- maxframe/dataframe/utils.py +110 -7
- maxframe/dataframe/window/expanding.py +5 -3
- maxframe/dataframe/window/tests/test_expanding.py +2 -2
- maxframe/io/objects/tests/test_object_io.py +39 -12
- maxframe/io/odpsio/__init__.py +1 -1
- maxframe/io/odpsio/arrow.py +51 -2
- maxframe/io/odpsio/schema.py +23 -5
- maxframe/io/odpsio/tableio.py +80 -124
- maxframe/io/odpsio/tests/test_schema.py +40 -0
- maxframe/io/odpsio/tests/test_tableio.py +5 -5
- maxframe/io/odpsio/tests/test_volumeio.py +35 -11
- maxframe/io/odpsio/volumeio.py +27 -3
- maxframe/learn/contrib/__init__.py +3 -2
- maxframe/learn/contrib/llm/__init__.py +16 -0
- maxframe/learn/contrib/llm/core.py +54 -0
- maxframe/learn/contrib/llm/models/__init__.py +14 -0
- maxframe/learn/contrib/llm/models/dashscope.py +73 -0
- maxframe/learn/contrib/llm/multi_modal.py +42 -0
- maxframe/learn/contrib/llm/text.py +42 -0
- maxframe/lib/mmh3.cp310-win_amd64.pyd +0 -0
- maxframe/lib/sparse/tests/test_sparse.py +15 -15
- maxframe/opcodes.py +7 -1
- maxframe/serialization/core.cp310-win_amd64.pyd +0 -0
- maxframe/serialization/core.pyx +13 -1
- maxframe/serialization/pandas.py +50 -20
- maxframe/serialization/serializables/core.py +70 -15
- maxframe/serialization/serializables/field_type.py +4 -1
- maxframe/serialization/serializables/tests/test_serializable.py +12 -2
- maxframe/serialization/tests/test_serial.py +2 -1
- maxframe/tensor/__init__.py +19 -7
- maxframe/tensor/merge/vstack.py +1 -1
- maxframe/tests/utils.py +16 -0
- maxframe/udf.py +27 -0
- maxframe/utils.py +42 -8
- {maxframe-1.0.0rc4.dist-info → maxframe-1.1.1.dist-info}/METADATA +4 -4
- {maxframe-1.0.0rc4.dist-info → maxframe-1.1.1.dist-info}/RECORD +88 -77
- {maxframe-1.0.0rc4.dist-info → maxframe-1.1.1.dist-info}/WHEEL +1 -1
- maxframe_client/clients/framedriver.py +4 -1
- maxframe_client/fetcher.py +23 -8
- maxframe_client/session/odps.py +40 -11
- maxframe_client/session/task.py +6 -25
- maxframe_client/session/tests/test_task.py +35 -6
- maxframe_client/tests/test_session.py +30 -10
- {maxframe-1.0.0rc4.dist-info → maxframe-1.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from typing import Any, Dict
|
|
15
|
+
|
|
16
|
+
from ..... import opcodes
|
|
17
|
+
from .....serialization.serializables.core import Serializable
|
|
18
|
+
from .....serialization.serializables.field import StringField
|
|
19
|
+
from ..core import LLMOperator
|
|
20
|
+
from ..multi_modal import MultiModalLLM
|
|
21
|
+
from ..text import TextLLM
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class DashScopeLLMMixin(Serializable):
    """
    Mixin providing parameter validation for the DashScope LLM wrappers.

    Any parameter whose name appears in ``_not_supported_params`` is
    rejected by ``validate_params``.
    """

    __slots__ = ()

    # Generation parameter names that ``validate_params`` refuses.
    _not_supported_params = {"stream", "incremental_output"}

    def validate_params(self, params: Dict[str, Any]):
        """
        Reject unsupported generation parameters.

        :param params: generation parameters keyed by name.
        :raises ValueError: for the first key (in iteration order) that is
            listed in ``_not_supported_params``.
        """
        unsupported = self._not_supported_params
        for name in params.keys():
            if name in unsupported:
                raise ValueError(f"{name} is not supported")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class DashScopeTextLLM(TextLLM, DashScopeLLMMixin):
    """
    Text LLM whose ``generate`` builds a ``DashScopeTextGenerationOperator``.
    """

    # NOTE(review): by its name, presumably identifies a resource holding the
    # DashScope API key — confirm against the operator implementation.
    api_key_resource = StringField("api_key_resource", default=None)

    def generate(
        self,
        data,
        prompt_template: Dict[str, Any],
        params: Dict[str, Any] = None,
    ):
        """
        Create a text-generation operator for this model and apply it.

        :param data: the input passed to the operator call.
        :param prompt_template: prompt template forwarded to the operator.
        :param params: generation parameters forwarded as-is (may be None).
        :return: whatever the operator call on ``data`` returns.
        """
        operator = DashScopeTextGenerationOperator(
            model=self, prompt_template=prompt_template, params=params
        )
        return operator(data)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class DashScopeMultiModalLLM(MultiModalLLM, DashScopeLLMMixin):
    """
    Multi-modal LLM whose ``generate`` builds a
    ``DashScopeMultiModalGenerationOperator``.
    """

    # NOTE(review): by its name, presumably identifies a resource holding the
    # DashScope API key — confirm against the operator implementation.
    api_key_resource = StringField("api_key_resource", default=None)

    def generate(
        self,
        data,
        prompt_template: Dict[str, Any],
        params: Dict[str, Any] = None,
    ):
        """
        Create a multi-modal generation operator for this model and apply it.

        :param data: the input passed to the operator call.
        :param prompt_template: prompt template forwarded to the operator.
        :param params: generation parameters forwarded as-is (may be None).
        :return: whatever the operator call on ``data`` returns.
        """
        # TODO add precheck here
        operator = DashScopeMultiModalGenerationOperator(
            model=self, prompt_template=prompt_template, params=params
        )
        return operator(data)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class DashScopeTextGenerationOperator(LLMOperator):
    """LLM operator registered under ``opcodes.DASHSCOPE_TEXT_GENERATION``."""

    _op_type_ = opcodes.DASHSCOPE_TEXT_GENERATION
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class DashScopeMultiModalGenerationOperator(LLMOperator):
    """LLM operator registered under ``opcodes.DASHSCOPE_MULTI_MODAL_GENERATION``."""

    _op_type_ = opcodes.DASHSCOPE_MULTI_MODAL_GENERATION
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from typing import Any, Dict
|
|
15
|
+
|
|
16
|
+
from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
|
|
17
|
+
from .core import LLM
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class MultiModalLLM(LLM):
    """
    Base class for multi-modal LLMs.

    ``generate`` raises ``NotImplementedError`` here; concrete models are
    expected to override it.
    """

    def generate(
        self,
        data,
        prompt_template: Dict[str, Any],
        params: Dict[str, Any] = None,
    ):
        raise NotImplementedError
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def generate(
    data,
    model: MultiModalLLM,
    prompt_template: Dict[str, Any],
    params: Dict[str, Any] = None,
):
    """
    Validate the inputs and delegate generation to ``model.generate``.

    :param data: a MaxFrame DataFrame or Series.
    :param model: the multi-modal model to run.
    :param prompt_template: prompt template forwarded to the model.
    :param params: optional generation parameters; ``None`` becomes ``{}``.
    :raises ValueError: if ``data`` is not a MaxFrame DataFrame/Series, if
        ``model`` is not a ``MultiModalLLM``, or if ``model.validate_params``
        rejects ``params``.
    """
    is_frame_like = isinstance(data, DATAFRAME_TYPE) or isinstance(data, SERIES_TYPE)
    if not is_frame_like:
        raise ValueError("data must be a maxframe dataframe or series object")
    if not isinstance(model, MultiModalLLM):
        raise ValueError("model must be a MultiModalLLM object")

    if params is None:
        params = {}
    model.validate_params(params)
    return model.generate(data, prompt_template, params)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from typing import Any, Dict
|
|
15
|
+
|
|
16
|
+
from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
|
|
17
|
+
from .core import LLM
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TextLLM(LLM):
    """
    Base class for text LLMs.

    ``generate`` raises ``NotImplementedError`` here; concrete models are
    expected to override it.
    """

    def generate(
        self,
        data,
        prompt_template: Dict[str, Any],
        params: Dict[str, Any] = None,
    ):
        raise NotImplementedError
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def generate(
    data,
    model: TextLLM,
    prompt_template: Dict[str, Any],
    params: Dict[str, Any] = None,
):
    """
    Validate the inputs and delegate generation to ``model.generate``.

    :param data: a MaxFrame DataFrame or Series.
    :param model: the text model to run.
    :param prompt_template: prompt template forwarded to the model.
    :param params: optional generation parameters; ``None`` becomes ``{}``.
    :raises ValueError: if ``data`` is not a MaxFrame DataFrame/Series, if
        ``model`` is not a ``TextLLM``, or if ``model.validate_params``
        rejects ``params``.
    """
    is_frame_like = isinstance(data, DATAFRAME_TYPE) or isinstance(data, SERIES_TYPE)
    if not is_frame_like:
        raise ValueError("data must be a maxframe dataframe or series object")
    if not isinstance(model, TextLLM):
        raise ValueError("model must be a TextLLM object")

    if params is None:
        params = {}
    model.validate_params(params)
    return model.generate(data, prompt_template, params)
|
|
Binary file
|
|
@@ -55,13 +55,13 @@ def test_sparse_creation():
|
|
|
55
55
|
s = SparseNDArray(s1_data)
|
|
56
56
|
assert s.ndim == 2
|
|
57
57
|
assert isinstance(s, SparseMatrix)
|
|
58
|
-
assert_array_equal(s.toarray(), s1_data.
|
|
59
|
-
assert_array_equal(s.todense(), s1_data.
|
|
58
|
+
assert_array_equal(s.toarray(), s1_data.toarray())
|
|
59
|
+
assert_array_equal(s.todense(), s1_data.toarray())
|
|
60
60
|
|
|
61
61
|
ss = pickle.loads(pickle.dumps(s))
|
|
62
62
|
assert s == ss
|
|
63
|
-
assert_array_equal(ss.toarray(), s1_data.
|
|
64
|
-
assert_array_equal(ss.todense(), s1_data.
|
|
63
|
+
assert_array_equal(ss.toarray(), s1_data.toarray())
|
|
64
|
+
assert_array_equal(ss.todense(), s1_data.toarray())
|
|
65
65
|
|
|
66
66
|
v = SparseNDArray(v1, shape=(3,))
|
|
67
67
|
assert s.ndim
|
|
@@ -331,12 +331,12 @@ def test_sparse_dot():
|
|
|
331
331
|
|
|
332
332
|
assert_array_equal(mls.dot(s1, v1_s), s1.dot(v1_data))
|
|
333
333
|
assert_array_equal(mls.dot(s2, v1_s), s2.dot(v1_data))
|
|
334
|
-
assert_array_equal(mls.dot(v2_s, s1), v2_data.dot(s1_data.
|
|
335
|
-
assert_array_equal(mls.dot(v2_s, s2), v2_data.dot(s2_data.
|
|
334
|
+
assert_array_equal(mls.dot(v2_s, s1), v2_data.dot(s1_data.toarray()))
|
|
335
|
+
assert_array_equal(mls.dot(v2_s, s2), v2_data.dot(s2_data.toarray()))
|
|
336
336
|
assert_array_equal(mls.dot(v1_s, v1_s), v1_data.dot(v1_data), almost=True)
|
|
337
337
|
assert_array_equal(mls.dot(v2_s, v2_s), v2_data.dot(v2_data), almost=True)
|
|
338
338
|
|
|
339
|
-
assert_array_equal(mls.dot(v2_s, s1, sparse=False), v2_data.dot(s1_data.
|
|
339
|
+
assert_array_equal(mls.dot(v2_s, s1, sparse=False), v2_data.dot(s1_data.toarray()))
|
|
340
340
|
assert_array_equal(mls.dot(v1_s, v1_s, sparse=False), v1_data.dot(v1_data))
|
|
341
341
|
|
|
342
342
|
|
|
@@ -390,7 +390,7 @@ def test_sparse_fill_diagonal():
|
|
|
390
390
|
arr = SparseNDArray(s1)
|
|
391
391
|
arr.fill_diagonal(3)
|
|
392
392
|
|
|
393
|
-
expected = s1.copy().
|
|
393
|
+
expected = s1.copy().toarray()
|
|
394
394
|
np.fill_diagonal(expected, 3)
|
|
395
395
|
|
|
396
396
|
np.testing.assert_array_equal(arr.toarray(), expected)
|
|
@@ -399,7 +399,7 @@ def test_sparse_fill_diagonal():
|
|
|
399
399
|
arr = SparseNDArray(s1)
|
|
400
400
|
arr.fill_diagonal(3, wrap=True)
|
|
401
401
|
|
|
402
|
-
expected = s1.copy().
|
|
402
|
+
expected = s1.copy().toarray()
|
|
403
403
|
np.fill_diagonal(expected, 3, wrap=True)
|
|
404
404
|
|
|
405
405
|
np.testing.assert_array_equal(arr.toarray(), expected)
|
|
@@ -408,7 +408,7 @@ def test_sparse_fill_diagonal():
|
|
|
408
408
|
arr = SparseNDArray(s1)
|
|
409
409
|
arr.fill_diagonal([1, 2, 3])
|
|
410
410
|
|
|
411
|
-
expected = s1.copy().
|
|
411
|
+
expected = s1.copy().toarray()
|
|
412
412
|
np.fill_diagonal(expected, [1, 2, 3])
|
|
413
413
|
|
|
414
414
|
np.testing.assert_array_equal(arr.toarray(), expected)
|
|
@@ -417,7 +417,7 @@ def test_sparse_fill_diagonal():
|
|
|
417
417
|
arr = SparseNDArray(s1)
|
|
418
418
|
arr.fill_diagonal([1, 2, 3], wrap=True)
|
|
419
419
|
|
|
420
|
-
expected = s1.copy().
|
|
420
|
+
expected = s1.copy().toarray()
|
|
421
421
|
np.fill_diagonal(expected, [1, 2, 3], wrap=True)
|
|
422
422
|
|
|
423
423
|
np.testing.assert_array_equal(arr.toarray(), expected)
|
|
@@ -427,7 +427,7 @@ def test_sparse_fill_diagonal():
|
|
|
427
427
|
arr = SparseNDArray(s1)
|
|
428
428
|
arr.fill_diagonal(val)
|
|
429
429
|
|
|
430
|
-
expected = s1.copy().
|
|
430
|
+
expected = s1.copy().toarray()
|
|
431
431
|
np.fill_diagonal(expected, val)
|
|
432
432
|
|
|
433
433
|
np.testing.assert_array_equal(arr.toarray(), expected)
|
|
@@ -437,7 +437,7 @@ def test_sparse_fill_diagonal():
|
|
|
437
437
|
arr = SparseNDArray(s1)
|
|
438
438
|
arr.fill_diagonal(val, wrap=True)
|
|
439
439
|
|
|
440
|
-
expected = s1.copy().
|
|
440
|
+
expected = s1.copy().toarray()
|
|
441
441
|
np.fill_diagonal(expected, val, wrap=True)
|
|
442
442
|
|
|
443
443
|
np.testing.assert_array_equal(arr.toarray(), expected)
|
|
@@ -447,7 +447,7 @@ def test_sparse_fill_diagonal():
|
|
|
447
447
|
arr = SparseNDArray(s1)
|
|
448
448
|
arr.fill_diagonal(val)
|
|
449
449
|
|
|
450
|
-
expected = s1.copy().
|
|
450
|
+
expected = s1.copy().toarray()
|
|
451
451
|
np.fill_diagonal(expected, val)
|
|
452
452
|
|
|
453
453
|
np.testing.assert_array_equal(arr.toarray(), expected)
|
|
@@ -457,7 +457,7 @@ def test_sparse_fill_diagonal():
|
|
|
457
457
|
arr = SparseNDArray(s1)
|
|
458
458
|
arr.fill_diagonal(val, wrap=True)
|
|
459
459
|
|
|
460
|
-
expected = s1.copy().
|
|
460
|
+
expected = s1.copy().toarray()
|
|
461
461
|
np.fill_diagonal(expected, val, wrap=True)
|
|
462
462
|
|
|
463
463
|
np.testing.assert_array_equal(arr.toarray(), expected)
|
maxframe/opcodes.py
CHANGED
|
@@ -270,6 +270,7 @@ KURTOSIS = 351
|
|
|
270
270
|
SEM = 352
|
|
271
271
|
STR_CONCAT = 353
|
|
272
272
|
MAD = 354
|
|
273
|
+
MEDIAN = 355
|
|
273
274
|
|
|
274
275
|
# tensor operator
|
|
275
276
|
RESHAPE = 401
|
|
@@ -377,7 +378,6 @@ DROP_DUPLICATES = 728
|
|
|
377
378
|
MELT = 729
|
|
378
379
|
RENAME = 731
|
|
379
380
|
INSERT = 732
|
|
380
|
-
MAP_CHUNK = 733
|
|
381
381
|
CARTESIAN_CHUNK = 734
|
|
382
382
|
EXPLODE = 735
|
|
383
383
|
REPLACE = 736
|
|
@@ -392,6 +392,10 @@ PIVOT_TABLE = 744
|
|
|
392
392
|
|
|
393
393
|
FUSE = 801
|
|
394
394
|
|
|
395
|
+
# LLM
|
|
396
|
+
DASHSCOPE_TEXT_GENERATION = 810
|
|
397
|
+
DASHSCOPE_MULTI_MODAL_GENERATION = 811
|
|
398
|
+
|
|
395
399
|
# table like input for tensor
|
|
396
400
|
TABLE_COO = 1003
|
|
397
401
|
# store tensor as coo format
|
|
@@ -569,6 +573,8 @@ CHOLESKY_FUSE = 999988
|
|
|
569
573
|
# MaxFrame-dedicated functions
|
|
570
574
|
DATAFRAME_RESHUFFLE = 10001
|
|
571
575
|
FLATMAP = 10002
|
|
576
|
+
FLATJSON = 10003
|
|
577
|
+
APPLY_CHUNK = 10004
|
|
572
578
|
|
|
573
579
|
# MaxFrame internal operators
|
|
574
580
|
DATAFRAME_PROJECTION_SAME_INDEX_MERGE = 100001
|
|
Binary file
|
maxframe/serialization/core.pyx
CHANGED
|
@@ -37,7 +37,7 @@ from .._utils import NamedType
|
|
|
37
37
|
from .._utils cimport TypeDispatcher
|
|
38
38
|
|
|
39
39
|
from ..lib import wrapped_pickle as pickle
|
|
40
|
-
from ..utils import arrow_type_from_str
|
|
40
|
+
from ..utils import NoDefault, arrow_type_from_str, no_default
|
|
41
41
|
|
|
42
42
|
try:
|
|
43
43
|
from pandas import ArrowDtype
|
|
@@ -94,6 +94,7 @@ cdef:
|
|
|
94
94
|
int COMPLEX_SERIALIZER = 12
|
|
95
95
|
int SLICE_SERIALIZER = 13
|
|
96
96
|
int REGEX_SERIALIZER = 14
|
|
97
|
+
int NO_DEFAULT_SERIALIZER = 15
|
|
97
98
|
int PLACEHOLDER_SERIALIZER = 4096
|
|
98
99
|
|
|
99
100
|
|
|
@@ -803,6 +804,16 @@ cdef class RegexSerializer(Serializer):
|
|
|
803
804
|
return re.compile((<bytes>(subs[0])).decode(), serialized[0])
|
|
804
805
|
|
|
805
806
|
|
|
807
|
+
cdef class NoDefaultSerializer(Serializer):
    """
    Serializer for the ``no_default`` sentinel.

    No state is written for the object; deserialization always returns the
    ``no_default`` instance imported from ``..utils``.
    """
    serializer_id = NO_DEFAULT_SERIALIZER

    cpdef serial(self, object obj, dict context):
        # The sentinel has nothing to encode: empty header and no sub-objects.
        return [], [], True

    cpdef deserial(self, list obj, dict context, list subs):
        # Hand back the shared sentinel so identity checks keep working.
        return no_default
|
|
815
|
+
|
|
816
|
+
|
|
806
817
|
cdef class Placeholder:
|
|
807
818
|
"""
|
|
808
819
|
Placeholder object to reduce duplicated serialization
|
|
@@ -857,6 +868,7 @@ DtypeSerializer.register(ExtensionDtype)
|
|
|
857
868
|
ComplexSerializer.register(complex)
|
|
858
869
|
SliceSerializer.register(slice)
|
|
859
870
|
RegexSerializer.register(re.Pattern)
|
|
871
|
+
NoDefaultSerializer.register(NoDefault)
|
|
860
872
|
PlaceholderSerializer.register(Placeholder)
|
|
861
873
|
|
|
862
874
|
|
maxframe/serialization/pandas.py
CHANGED
|
@@ -134,8 +134,10 @@ class ArraySerializer(Serializer):
|
|
|
134
134
|
data_parts = [obj.tolist()]
|
|
135
135
|
else:
|
|
136
136
|
data_parts = [obj.to_numpy().tolist()]
|
|
137
|
-
|
|
137
|
+
elif hasattr(obj, "_data"):
|
|
138
138
|
data_parts = [getattr(obj, "_data")]
|
|
139
|
+
else:
|
|
140
|
+
data_parts = [getattr(obj, "_pa_array")]
|
|
139
141
|
return [ser_type], [dtype] + data_parts, False
|
|
140
142
|
|
|
141
143
|
def deserial(self, serialized: List, context: Dict, subs: List):
|
|
@@ -155,38 +157,66 @@ class PdTimestampSerializer(Serializer):
|
|
|
155
157
|
else:
|
|
156
158
|
zone_info = []
|
|
157
159
|
ts = obj.to_pydatetime().timestamp()
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
)
|
|
160
|
+
elements = [int(ts), obj.microsecond, obj.nanosecond]
|
|
161
|
+
if hasattr(obj, "unit"):
|
|
162
|
+
elements.append(str(obj.unit))
|
|
163
|
+
return elements, zone_info, bool(zone_info)
|
|
163
164
|
|
|
164
165
|
def deserial(self, serialized: List, context: Dict, subs: List):
    """
    Rebuild a ``pd.Timestamp`` from the parts produced by ``serial``.

    :param serialized: ``[whole_seconds, microsecond, nanosecond]`` optionally
        followed by a unit string.
    :param subs: empty for naive timestamps; otherwise ``subs[0]`` is the zone
        info the value is converted into.
    :return: the reconstructed ``pd.Timestamp``.
    """
    # Aware timestamps were encoded as POSIX seconds, so rebuild their fields
    # in UTC. ``datetime.utcfromtimestamp`` is deprecated since Python 3.12;
    # an aware UTC datetime yields the same wall-clock fields. Naive values
    # use local time (``tz=None``), matching ``datetime.fromtimestamp(ts)``.
    tz = datetime.timezone.utc if subs else None
    pydt = datetime.datetime.fromtimestamp(serialized[0], tz=tz)

    kwargs = {
        "year": pydt.year,
        "month": pydt.month,
        "day": pydt.day,
        "hour": pydt.hour,
        "minute": pydt.minute,
        "second": pydt.second,
        # sub-second parts are stored separately because the float
        # timestamp cannot carry nanoseconds
        "microsecond": serialized[1],
        "nanosecond": serialized[2],
    }
    if tz is not None:
        kwargs["tzinfo"] = tz
    if len(serialized) > 3:
        # the optional unit element is only appended by newer serializers
        kwargs["unit"] = serialized[3]

    val = pd.Timestamp(**kwargs)
    return val.tz_convert(subs[0]) if subs else val
|
|
175
198
|
|
|
176
199
|
|
|
177
200
|
class PdTimedeltaSerializer(Serializer):
    """
    Serialize ``pd.Timedelta`` as a flat list of its components:
    ``[seconds, microseconds, nanoseconds, days]`` plus an optional unit.
    """

    def serial(self, obj: pd.Timedelta, context: Dict):
        # Positions matter: deserial reads 0-2 directly, 3 as days, 4 as unit.
        parts = [int(obj.seconds), obj.microseconds, obj.nanoseconds, obj.days]
        if hasattr(obj, "unit"):
            # guarded because ``unit`` is not available on every pandas
            # version (presumably pandas >= 2 — confirm)
            parts.append(str(obj.unit))
        return parts, [], True

    def deserial(self, serialized: List, context: Dict, subs: List):
        seconds, microseconds, nanoseconds = serialized[:3]
        n_parts = len(serialized)
        # older payloads may omit days (index 3) and unit (index 4)
        kwargs = dict(
            days=serialized[3] if n_parts >= 4 else 0,
            seconds=seconds,
            microseconds=microseconds,
            nanoseconds=nanoseconds,
        )
        if n_parts >= 5 and serialized[4] is not None:
            kwargs["unit"] = serialized[4]
        return pd.Timedelta(**kwargs)
|
|
190
220
|
|
|
191
221
|
|
|
192
222
|
class NoDefaultSerializer(Serializer):
|
|
@@ -13,12 +13,13 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import weakref
|
|
16
|
-
from collections import
|
|
16
|
+
from collections import OrderedDict
|
|
17
17
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
|
18
18
|
|
|
19
19
|
import msgpack
|
|
20
20
|
|
|
21
21
|
from ...lib.mmh3 import hash
|
|
22
|
+
from ...utils import no_default
|
|
22
23
|
from ..core import Placeholder, Serializer, buffered, load_type
|
|
23
24
|
from .field import Field
|
|
24
25
|
from .field_type import DictType, ListType, PrimitiveFieldType, TupleType
|
|
@@ -97,14 +98,18 @@ class SerializableMeta(type):
|
|
|
97
98
|
non_primitive_fields.append(v)
|
|
98
99
|
|
|
99
100
|
# count number of fields for every base class
|
|
100
|
-
cls_to_primitive_field_count =
|
|
101
|
-
cls_to_non_primitive_field_count =
|
|
101
|
+
cls_to_primitive_field_count = OrderedDict()
|
|
102
|
+
cls_to_non_primitive_field_count = OrderedDict()
|
|
102
103
|
for field_name in field_order:
|
|
103
104
|
cls_hash = field_to_cls_hash[field_name]
|
|
104
105
|
if field_name in primitive_field_names:
|
|
105
|
-
cls_to_primitive_field_count[cls_hash]
|
|
106
|
+
cls_to_primitive_field_count[cls_hash] = (
|
|
107
|
+
cls_to_primitive_field_count.get(cls_hash, 0) + 1
|
|
108
|
+
)
|
|
106
109
|
else:
|
|
107
|
-
cls_to_non_primitive_field_count[cls_hash]
|
|
110
|
+
cls_to_non_primitive_field_count[cls_hash] = (
|
|
111
|
+
cls_to_non_primitive_field_count.get(cls_hash, 0) + 1
|
|
112
|
+
)
|
|
108
113
|
|
|
109
114
|
slots = set(properties.pop("__slots__", set()))
|
|
110
115
|
slots.update(properties_field_slot_names)
|
|
@@ -119,9 +124,11 @@ class SerializableMeta(type):
|
|
|
119
124
|
properties["_FIELD_ORDER"] = field_order
|
|
120
125
|
properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
|
|
121
126
|
properties["_PRIMITIVE_FIELDS"] = primitive_fields
|
|
122
|
-
properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] =
|
|
127
|
+
properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = OrderedDict(
|
|
128
|
+
cls_to_primitive_field_count
|
|
129
|
+
)
|
|
123
130
|
properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
|
|
124
|
-
properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] =
|
|
131
|
+
properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = OrderedDict(
|
|
125
132
|
cls_to_non_primitive_field_count
|
|
126
133
|
)
|
|
127
134
|
properties["__slots__"] = tuple(slots)
|
|
@@ -211,6 +218,22 @@ class _NoFieldValue:
|
|
|
211
218
|
_no_field_value = _NoFieldValue()
|
|
212
219
|
|
|
213
220
|
|
|
221
|
+
def _to_primitive_placeholder(v: Any) -> Any:
    """
    Swap the ``_no_field_value`` / ``no_default`` sentinels for an empty dict
    so the primitive value list stays msgpack-serializable; any other value
    is returned unchanged.
    """
    is_sentinel = v is _no_field_value or v is no_default
    return {} if is_sentinel else v
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _restore_primitive_placeholder(v: Any) -> Any:
    """
    Map an exact, empty ``dict`` back to ``_no_field_value``; return every
    other value unchanged.

    Because ``no_default`` is also encoded as ``{}``, it comes back as
    ``_no_field_value``, i.e. the field is left unset.
    NOTE(review): a primitive field whose real value is ``{}`` would be
    treated the same way — confirm no primitive field can legitimately hold
    an empty dict.
    """
    if type(v) is dict and not v:
        return _no_field_value
    return v
|
|
235
|
+
|
|
236
|
+
|
|
214
237
|
class SerializableSerializer(Serializer):
|
|
215
238
|
"""
|
|
216
239
|
Leverage DictSerializer to perform serde.
|
|
@@ -241,9 +264,7 @@ class SerializableSerializer(Serializer):
|
|
|
241
264
|
else:
|
|
242
265
|
primitive_vals = self._get_field_values(obj, obj._PRIMITIVE_FIELDS)
|
|
243
266
|
# replace _no_field_value as {} to make them msgpack-serializable
|
|
244
|
-
primitive_vals = [
|
|
245
|
-
v if v is not _no_field_value else {} for v in primitive_vals
|
|
246
|
-
]
|
|
267
|
+
primitive_vals = [_to_primitive_placeholder(v) for v in primitive_vals]
|
|
247
268
|
if obj._cache_primitive_serial:
|
|
248
269
|
primitive_vals = msgpack.dumps(primitive_vals)
|
|
249
270
|
_primitive_serial_cache[obj] = primitive_vals
|
|
@@ -281,21 +302,51 @@ class SerializableSerializer(Serializer):
|
|
|
281
302
|
else:
|
|
282
303
|
field.set(obj, value)
|
|
283
304
|
|
|
305
|
+
@classmethod
|
|
306
|
+
def _prune_server_fields(
|
|
307
|
+
cls,
|
|
308
|
+
client_cls_to_field_count: Optional[Dict[int, int]],
|
|
309
|
+
server_cls_to_field_count: Dict[int, int],
|
|
310
|
+
server_fields: list,
|
|
311
|
+
) -> list:
|
|
312
|
+
if not client_cls_to_field_count: # pragma: no cover
|
|
313
|
+
# todo remove this branch when all versions below v0.1.0b5 is eliminated
|
|
314
|
+
return server_fields
|
|
315
|
+
if set(client_cls_to_field_count.keys()) == set(
|
|
316
|
+
server_cls_to_field_count.keys()
|
|
317
|
+
):
|
|
318
|
+
return server_fields
|
|
319
|
+
ret_server_fields = []
|
|
320
|
+
server_pos = 0
|
|
321
|
+
for cls_hash, count in server_cls_to_field_count.items():
|
|
322
|
+
if cls_hash in client_cls_to_field_count:
|
|
323
|
+
ret_server_fields.extend(server_fields[server_pos : server_pos + count])
|
|
324
|
+
server_pos += count
|
|
325
|
+
return ret_server_fields
|
|
326
|
+
|
|
284
327
|
@classmethod
|
|
285
328
|
def _set_field_values(
|
|
286
329
|
cls,
|
|
287
330
|
obj: Serializable,
|
|
288
331
|
values: List[Any],
|
|
289
|
-
client_cls_to_field_count: Optional[Dict[
|
|
332
|
+
client_cls_to_field_count: Optional[Dict[int, int]],
|
|
290
333
|
is_primitive: bool = True,
|
|
291
334
|
):
|
|
292
335
|
obj_class = type(obj)
|
|
293
336
|
if is_primitive:
|
|
294
337
|
server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
|
|
295
|
-
server_fields =
|
|
338
|
+
server_fields = cls._prune_server_fields(
|
|
339
|
+
client_cls_to_field_count,
|
|
340
|
+
server_cls_to_field_count,
|
|
341
|
+
obj_class._PRIMITIVE_FIELDS,
|
|
342
|
+
)
|
|
296
343
|
else:
|
|
297
344
|
server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
|
|
298
|
-
server_fields =
|
|
345
|
+
server_fields = cls._prune_server_fields(
|
|
346
|
+
client_cls_to_field_count,
|
|
347
|
+
server_cls_to_field_count,
|
|
348
|
+
obj_class._NON_PRIMITIVE_FIELDS,
|
|
349
|
+
)
|
|
299
350
|
|
|
300
351
|
legacy_to_new_hash = {
|
|
301
352
|
c._LEGACY_NAME_HASH: c._NAME_HASH
|
|
@@ -311,7 +362,9 @@ class SerializableSerializer(Serializer):
|
|
|
311
362
|
cls_fields = server_fields[server_field_num : field_num + count]
|
|
312
363
|
cls_values = values[field_num : field_num + count]
|
|
313
364
|
for field, value in zip(cls_fields, cls_values):
|
|
314
|
-
if
|
|
365
|
+
if is_primitive:
|
|
366
|
+
value = _restore_primitive_placeholder(value)
|
|
367
|
+
if not is_primitive or value is not _no_field_value:
|
|
315
368
|
cls._set_field_value(obj, field, value)
|
|
316
369
|
field_num += count
|
|
317
370
|
try:
|
|
@@ -356,7 +409,9 @@ class SerializableSerializer(Serializer):
|
|
|
356
409
|
server_fields + deprecated_fields, key=lambda f: f.name
|
|
357
410
|
)
|
|
358
411
|
for field, value in zip(server_fields, values):
|
|
359
|
-
if
|
|
412
|
+
if is_primitive:
|
|
413
|
+
value = _restore_primitive_placeholder(value)
|
|
414
|
+
if not is_primitive or value is not _no_field_value:
|
|
360
415
|
try:
|
|
361
416
|
cls._set_field_value(obj, field, value)
|
|
362
417
|
except AttributeError: # pragma: no cover
|
|
@@ -46,6 +46,9 @@ class PrimitiveType(Enum):
|
|
|
46
46
|
complex128 = 25
|
|
47
47
|
|
|
48
48
|
|
|
49
|
+
_np_unicode = np.unicode_ if hasattr(np, "unicode_") else np.str_
|
|
50
|
+
|
|
51
|
+
|
|
49
52
|
_primitive_type_to_valid_types = {
|
|
50
53
|
PrimitiveType.bool: (bool, np.bool_),
|
|
51
54
|
PrimitiveType.int8: (int, np.int8),
|
|
@@ -60,7 +63,7 @@ _primitive_type_to_valid_types = {
|
|
|
60
63
|
PrimitiveType.float32: (float, np.float32),
|
|
61
64
|
PrimitiveType.float64: (float, np.float64),
|
|
62
65
|
PrimitiveType.bytes: (bytes, np.bytes_),
|
|
63
|
-
PrimitiveType.string: (str,
|
|
66
|
+
PrimitiveType.string: (str, _np_unicode),
|
|
64
67
|
PrimitiveType.complex64: (complex, np.complex64),
|
|
65
68
|
PrimitiveType.complex128: (complex, np.complex128),
|
|
66
69
|
}
|
|
@@ -21,6 +21,7 @@ import pytest
|
|
|
21
21
|
|
|
22
22
|
from ....core import EntityData
|
|
23
23
|
from ....lib.wrapped_pickle import switch_unpickle
|
|
24
|
+
from ....utils import no_default
|
|
24
25
|
from ... import deserialize, serialize
|
|
25
26
|
from .. import (
|
|
26
27
|
AnyField,
|
|
@@ -143,6 +144,7 @@ class MySerializable(Serializable):
|
|
|
143
144
|
oneof1_val=f"{__name__}.MySerializable",
|
|
144
145
|
oneof2_val=MySimpleSerializable,
|
|
145
146
|
)
|
|
147
|
+
_no_default_val = Float64Field("no_default_val", default=no_default)
|
|
146
148
|
|
|
147
149
|
|
|
148
150
|
@pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
|
|
@@ -187,6 +189,7 @@ def test_serializable(set_is_ci):
|
|
|
187
189
|
_dict_val={"a": b"bytes_value"},
|
|
188
190
|
_ref_val=MySerializable(),
|
|
189
191
|
_oneof_val=MySerializable(_id="2"),
|
|
192
|
+
_no_default_val=no_default,
|
|
190
193
|
)
|
|
191
194
|
|
|
192
195
|
header, buffers = serialize(my_serializable)
|
|
@@ -218,7 +221,10 @@ def test_compatible_serializable(set_is_ci):
|
|
|
218
221
|
_ref_val = ReferenceField("ref_val", "MySimpleSerializable")
|
|
219
222
|
_dict_val = DictField("dict_val")
|
|
220
223
|
|
|
221
|
-
class
|
|
224
|
+
class MyMidSerializable(MySimpleSerializable):
|
|
225
|
+
_i_bool_val = Int64Field("i_bool_val", default=True)
|
|
226
|
+
|
|
227
|
+
class MySubSerializable(MyMidSerializable):
|
|
222
228
|
_m_int_val = Int64Field("m_int_val", default=250)
|
|
223
229
|
_m_str_val = StringField("m_str_val", default="SUB_STR")
|
|
224
230
|
|
|
@@ -234,7 +240,11 @@ def _assert_serializable_eq(my_serializable, my_serializable2):
|
|
|
234
240
|
if not hasattr(my_serializable, field.name):
|
|
235
241
|
continue
|
|
236
242
|
expect_value = getattr(my_serializable, field_name)
|
|
237
|
-
|
|
243
|
+
if expect_value is no_default:
|
|
244
|
+
assert not hasattr(my_serializable2, field.name)
|
|
245
|
+
continue
|
|
246
|
+
else:
|
|
247
|
+
actual_value = getattr(my_serializable2, field_name)
|
|
238
248
|
if isinstance(expect_value, np.ndarray):
|
|
239
249
|
np.testing.assert_array_equal(expect_value, actual_value)
|
|
240
250
|
elif isinstance(expect_value, pd.DataFrame):
|