maxframe 0.1.0b4__cp38-cp38-win32.whl → 1.0.0__cp38-cp38-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/__init__.py +1 -0
- maxframe/_utils.cp38-win32.pyd +0 -0
- maxframe/codegen.py +56 -5
- maxframe/config/config.py +78 -10
- maxframe/config/validators.py +42 -11
- maxframe/conftest.py +58 -14
- maxframe/core/__init__.py +2 -16
- maxframe/core/entity/__init__.py +1 -12
- maxframe/core/entity/executable.py +1 -1
- maxframe/core/entity/objects.py +46 -45
- maxframe/core/entity/output_types.py +0 -3
- maxframe/core/entity/tests/test_objects.py +43 -0
- maxframe/core/entity/tileables.py +5 -78
- maxframe/core/graph/__init__.py +2 -2
- maxframe/core/graph/builder/__init__.py +0 -1
- maxframe/core/graph/builder/base.py +5 -4
- maxframe/core/graph/builder/tileable.py +4 -4
- maxframe/core/graph/builder/utils.py +4 -8
- maxframe/core/graph/core.cp38-win32.pyd +0 -0
- maxframe/core/graph/core.pyx +4 -4
- maxframe/core/graph/entity.py +9 -33
- maxframe/core/operator/__init__.py +2 -9
- maxframe/core/operator/base.py +3 -5
- maxframe/core/operator/objects.py +0 -9
- maxframe/core/operator/utils.py +55 -0
- maxframe/dataframe/__init__.py +2 -1
- maxframe/dataframe/arithmetic/around.py +5 -17
- maxframe/dataframe/arithmetic/core.py +15 -7
- maxframe/dataframe/arithmetic/docstring.py +7 -33
- maxframe/dataframe/arithmetic/equal.py +4 -2
- maxframe/dataframe/arithmetic/greater.py +4 -2
- maxframe/dataframe/arithmetic/greater_equal.py +4 -2
- maxframe/dataframe/arithmetic/less.py +2 -2
- maxframe/dataframe/arithmetic/less_equal.py +4 -2
- maxframe/dataframe/arithmetic/not_equal.py +4 -2
- maxframe/dataframe/arithmetic/tests/test_arithmetic.py +39 -16
- maxframe/dataframe/core.py +58 -12
- maxframe/dataframe/datasource/date_range.py +2 -2
- maxframe/dataframe/datasource/read_odps_query.py +120 -24
- maxframe/dataframe/datasource/read_odps_table.py +9 -4
- maxframe/dataframe/datasource/tests/test_datasource.py +103 -8
- maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
- maxframe/dataframe/datastore/to_odps.py +28 -0
- maxframe/dataframe/extensions/__init__.py +5 -0
- maxframe/dataframe/extensions/flatjson.py +131 -0
- maxframe/dataframe/extensions/flatmap.py +317 -0
- maxframe/dataframe/extensions/reshuffle.py +1 -1
- maxframe/dataframe/extensions/tests/test_extensions.py +108 -3
- maxframe/dataframe/groupby/core.py +1 -1
- maxframe/dataframe/groupby/cum.py +0 -1
- maxframe/dataframe/groupby/fill.py +4 -1
- maxframe/dataframe/groupby/getitem.py +6 -0
- maxframe/dataframe/groupby/tests/test_groupby.py +5 -1
- maxframe/dataframe/groupby/transform.py +5 -1
- maxframe/dataframe/indexing/align.py +1 -1
- maxframe/dataframe/indexing/loc.py +6 -4
- maxframe/dataframe/indexing/rename.py +5 -28
- maxframe/dataframe/indexing/sample.py +0 -1
- maxframe/dataframe/indexing/set_index.py +68 -1
- maxframe/dataframe/initializer.py +11 -1
- maxframe/dataframe/merge/__init__.py +9 -1
- maxframe/dataframe/merge/concat.py +41 -31
- maxframe/dataframe/merge/merge.py +237 -3
- maxframe/dataframe/merge/tests/test_merge.py +126 -1
- maxframe/dataframe/misc/__init__.py +4 -0
- maxframe/dataframe/misc/apply.py +6 -11
- maxframe/dataframe/misc/case_when.py +141 -0
- maxframe/dataframe/misc/describe.py +2 -2
- maxframe/dataframe/misc/drop_duplicates.py +8 -8
- maxframe/dataframe/misc/eval.py +4 -0
- maxframe/dataframe/misc/memory_usage.py +2 -2
- maxframe/dataframe/misc/pct_change.py +1 -83
- maxframe/dataframe/misc/pivot_table.py +262 -0
- maxframe/dataframe/misc/tests/test_misc.py +93 -1
- maxframe/dataframe/misc/transform.py +1 -30
- maxframe/dataframe/misc/value_counts.py +4 -17
- maxframe/dataframe/missing/dropna.py +1 -1
- maxframe/dataframe/missing/fillna.py +5 -5
- maxframe/dataframe/operators.py +1 -17
- maxframe/dataframe/plotting/core.py +2 -2
- maxframe/dataframe/reduction/core.py +4 -3
- maxframe/dataframe/reduction/tests/test_reduction.py +2 -4
- maxframe/dataframe/sort/sort_values.py +1 -11
- maxframe/dataframe/statistics/corr.py +3 -3
- maxframe/dataframe/statistics/quantile.py +13 -19
- maxframe/dataframe/statistics/tests/test_statistics.py +4 -4
- maxframe/dataframe/tests/test_initializer.py +33 -2
- maxframe/dataframe/utils.py +33 -11
- maxframe/dataframe/window/expanding.py +5 -3
- maxframe/dataframe/window/tests/test_expanding.py +2 -2
- maxframe/errors.py +13 -0
- maxframe/extension.py +12 -0
- maxframe/io/__init__.py +13 -0
- maxframe/io/objects/__init__.py +24 -0
- maxframe/io/objects/core.py +140 -0
- maxframe/io/objects/tensor.py +76 -0
- maxframe/io/objects/tests/__init__.py +13 -0
- maxframe/io/objects/tests/test_object_io.py +97 -0
- maxframe/{odpsio → io/odpsio}/__init__.py +3 -1
- maxframe/{odpsio → io/odpsio}/arrow.py +43 -12
- maxframe/{odpsio → io/odpsio}/schema.py +38 -16
- maxframe/io/odpsio/tableio.py +719 -0
- maxframe/io/odpsio/tests/__init__.py +13 -0
- maxframe/{odpsio → io/odpsio}/tests/test_schema.py +75 -33
- maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +50 -23
- maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
- maxframe/io/odpsio/volumeio.py +63 -0
- maxframe/learn/contrib/__init__.py +3 -1
- maxframe/learn/contrib/graph/__init__.py +15 -0
- maxframe/learn/contrib/graph/connected_components.py +215 -0
- maxframe/learn/contrib/graph/tests/__init__.py +13 -0
- maxframe/learn/contrib/graph/tests/test_connected_components.py +53 -0
- maxframe/learn/contrib/llm/__init__.py +16 -0
- maxframe/learn/contrib/llm/core.py +54 -0
- maxframe/learn/contrib/llm/models/__init__.py +14 -0
- maxframe/learn/contrib/llm/models/dashscope.py +73 -0
- maxframe/learn/contrib/llm/multi_modal.py +42 -0
- maxframe/learn/contrib/llm/text.py +42 -0
- maxframe/learn/contrib/utils.py +52 -0
- maxframe/learn/contrib/xgboost/__init__.py +26 -0
- maxframe/learn/contrib/xgboost/classifier.py +110 -0
- maxframe/learn/contrib/xgboost/core.py +241 -0
- maxframe/learn/contrib/xgboost/dmatrix.py +147 -0
- maxframe/learn/contrib/xgboost/predict.py +121 -0
- maxframe/learn/contrib/xgboost/regressor.py +71 -0
- maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
- maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
- maxframe/learn/contrib/xgboost/train.py +132 -0
- maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
- maxframe/learn/utils/__init__.py +15 -0
- maxframe/learn/utils/core.py +29 -0
- maxframe/lib/mmh3.cp38-win32.pyd +0 -0
- maxframe/lib/mmh3.pyi +43 -0
- maxframe/lib/sparse/tests/test_sparse.py +15 -15
- maxframe/lib/wrapped_pickle.py +2 -1
- maxframe/opcodes.py +11 -0
- maxframe/protocol.py +154 -27
- maxframe/remote/core.py +4 -8
- maxframe/serialization/__init__.py +1 -0
- maxframe/serialization/core.cp38-win32.pyd +0 -0
- maxframe/serialization/core.pxd +3 -0
- maxframe/serialization/core.pyi +64 -0
- maxframe/serialization/core.pyx +67 -26
- maxframe/serialization/exception.py +1 -1
- maxframe/serialization/pandas.py +52 -17
- maxframe/serialization/serializables/core.py +180 -15
- maxframe/serialization/serializables/field_type.py +4 -1
- maxframe/serialization/serializables/tests/test_serializable.py +54 -5
- maxframe/serialization/tests/test_serial.py +2 -1
- maxframe/session.py +37 -2
- maxframe/tensor/__init__.py +81 -2
- maxframe/tensor/arithmetic/isclose.py +1 -0
- maxframe/tensor/arithmetic/tests/test_arithmetic.py +22 -18
- maxframe/tensor/core.py +5 -136
- maxframe/tensor/datasource/array.py +7 -2
- maxframe/tensor/datasource/full.py +1 -1
- maxframe/tensor/datasource/scalar.py +1 -1
- maxframe/tensor/datasource/tests/test_datasource.py +1 -1
- maxframe/tensor/indexing/flatnonzero.py +1 -1
- maxframe/tensor/indexing/getitem.py +2 -0
- maxframe/tensor/merge/__init__.py +2 -0
- maxframe/tensor/merge/concatenate.py +101 -0
- maxframe/tensor/merge/tests/test_merge.py +30 -1
- maxframe/tensor/merge/vstack.py +74 -0
- maxframe/tensor/{base → misc}/__init__.py +4 -0
- maxframe/tensor/misc/atleast_1d.py +72 -0
- maxframe/tensor/misc/atleast_2d.py +70 -0
- maxframe/tensor/misc/atleast_3d.py +85 -0
- maxframe/tensor/misc/tests/__init__.py +13 -0
- maxframe/tensor/{base → misc}/transpose.py +22 -18
- maxframe/tensor/misc/unique.py +205 -0
- maxframe/tensor/operators.py +1 -7
- maxframe/tensor/random/core.py +1 -1
- maxframe/tensor/reduction/count_nonzero.py +2 -1
- maxframe/tensor/reduction/mean.py +1 -0
- maxframe/tensor/reduction/nanmean.py +1 -0
- maxframe/tensor/reduction/nanvar.py +2 -0
- maxframe/tensor/reduction/tests/test_reduction.py +12 -1
- maxframe/tensor/reduction/var.py +2 -0
- maxframe/tensor/statistics/quantile.py +2 -2
- maxframe/tensor/utils.py +2 -22
- maxframe/tests/test_protocol.py +34 -0
- maxframe/tests/test_utils.py +0 -12
- maxframe/tests/utils.py +17 -2
- maxframe/typing_.py +4 -1
- maxframe/udf.py +62 -3
- maxframe/utils.py +112 -86
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0.dist-info}/METADATA +25 -25
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0.dist-info}/RECORD +208 -167
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0.dist-info}/WHEEL +1 -1
- maxframe_client/__init__.py +0 -1
- maxframe_client/clients/framedriver.py +4 -1
- maxframe_client/fetcher.py +123 -54
- maxframe_client/session/consts.py +3 -0
- maxframe_client/session/graph.py +8 -2
- maxframe_client/session/odps.py +223 -40
- maxframe_client/session/task.py +108 -80
- maxframe_client/tests/test_fetcher.py +21 -3
- maxframe_client/tests/test_session.py +136 -8
- maxframe/core/entity/chunks.py +0 -68
- maxframe/core/entity/fuse.py +0 -73
- maxframe/core/graph/builder/chunk.py +0 -430
- maxframe/odpsio/tableio.py +0 -300
- maxframe/odpsio/volumeio.py +0 -95
- maxframe_client/clients/spe.py +0 -104
- /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
- /maxframe/{tensor/base → dataframe/datastore}/tests/__init__.py +0 -0
- /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
- /maxframe/tensor/{base → misc}/astype.py +0 -0
- /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
- /maxframe/tensor/{base → misc}/ravel.py +0 -0
- /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
- /maxframe/tensor/{base → misc}/where.py +0 -0
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0.dist-info}/top_level.txt +0 -0
maxframe/serialization/pandas.py
CHANGED
|
@@ -134,8 +134,10 @@ class ArraySerializer(Serializer):
|
|
|
134
134
|
data_parts = [obj.tolist()]
|
|
135
135
|
else:
|
|
136
136
|
data_parts = [obj.to_numpy().tolist()]
|
|
137
|
-
|
|
137
|
+
elif hasattr(obj, "_data"):
|
|
138
138
|
data_parts = [getattr(obj, "_data")]
|
|
139
|
+
else:
|
|
140
|
+
data_parts = [getattr(obj, "_pa_array")]
|
|
139
141
|
return [ser_type], [dtype] + data_parts, False
|
|
140
142
|
|
|
141
143
|
def deserial(self, serialized: List, context: Dict, subs: List):
|
|
@@ -155,33 +157,66 @@ class PdTimestampSerializer(Serializer):
|
|
|
155
157
|
else:
|
|
156
158
|
zone_info = []
|
|
157
159
|
ts = obj.to_pydatetime().timestamp()
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
)
|
|
160
|
+
elements = [int(ts), obj.microsecond, obj.nanosecond]
|
|
161
|
+
if hasattr(obj, "unit"):
|
|
162
|
+
elements.append(str(obj.unit))
|
|
163
|
+
return elements, zone_info, bool(zone_info)
|
|
163
164
|
|
|
164
165
|
def deserial(self, serialized: List, context: Dict, subs: List):
|
|
165
166
|
if subs:
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
167
|
+
pydt = datetime.datetime.utcfromtimestamp(serialized[0])
|
|
168
|
+
kwargs = {
|
|
169
|
+
"year": pydt.year,
|
|
170
|
+
"month": pydt.month,
|
|
171
|
+
"day": pydt.day,
|
|
172
|
+
"hour": pydt.hour,
|
|
173
|
+
"minute": pydt.minute,
|
|
174
|
+
"second": pydt.second,
|
|
175
|
+
"microsecond": serialized[1],
|
|
176
|
+
"nanosecond": serialized[2],
|
|
177
|
+
"tzinfo": datetime.timezone.utc,
|
|
178
|
+
}
|
|
179
|
+
if len(serialized) > 3:
|
|
180
|
+
kwargs["unit"] = serialized[3]
|
|
181
|
+
val = pd.Timestamp(**kwargs).tz_convert(subs[0])
|
|
170
182
|
else:
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
183
|
+
pydt = datetime.datetime.fromtimestamp(serialized[0])
|
|
184
|
+
kwargs = {
|
|
185
|
+
"year": pydt.year,
|
|
186
|
+
"month": pydt.month,
|
|
187
|
+
"day": pydt.day,
|
|
188
|
+
"hour": pydt.hour,
|
|
189
|
+
"minute": pydt.minute,
|
|
190
|
+
"second": pydt.second,
|
|
191
|
+
"microsecond": serialized[1],
|
|
192
|
+
"nanosecond": serialized[2],
|
|
193
|
+
}
|
|
194
|
+
if len(serialized) >= 4:
|
|
195
|
+
kwargs["unit"] = serialized[3]
|
|
196
|
+
val = pd.Timestamp(**kwargs)
|
|
174
197
|
return val
|
|
175
198
|
|
|
176
199
|
|
|
177
200
|
class PdTimedeltaSerializer(Serializer):
|
|
178
201
|
def serial(self, obj: pd.Timedelta, context: Dict):
|
|
179
|
-
|
|
202
|
+
elements = [int(obj.seconds), obj.microseconds, obj.nanoseconds, obj.days]
|
|
203
|
+
if hasattr(obj, "unit"):
|
|
204
|
+
elements.append(str(obj.unit))
|
|
205
|
+
return elements, [], True
|
|
180
206
|
|
|
181
207
|
def deserial(self, serialized: List, context: Dict, subs: List):
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
208
|
+
days = 0 if len(serialized) < 4 else serialized[3]
|
|
209
|
+
unit = None if len(serialized) < 5 else serialized[4]
|
|
210
|
+
seconds, microseconds, nanoseconds = serialized[:3]
|
|
211
|
+
kwargs = {
|
|
212
|
+
"days": days,
|
|
213
|
+
"seconds": seconds,
|
|
214
|
+
"microseconds": microseconds,
|
|
215
|
+
"nanoseconds": nanoseconds,
|
|
216
|
+
}
|
|
217
|
+
if unit is not None:
|
|
218
|
+
kwargs["unit"] = unit
|
|
219
|
+
return pd.Timedelta(**kwargs)
|
|
185
220
|
|
|
186
221
|
|
|
187
222
|
class NoDefaultSerializer(Serializer):
|
|
@@ -12,12 +12,14 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
import operator
|
|
16
15
|
import weakref
|
|
17
|
-
from
|
|
16
|
+
from collections import defaultdict
|
|
17
|
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
|
18
18
|
|
|
19
19
|
import msgpack
|
|
20
20
|
|
|
21
|
+
from ...lib.mmh3 import hash
|
|
22
|
+
from ...utils import no_default
|
|
21
23
|
from ..core import Placeholder, Serializer, buffered, load_type
|
|
22
24
|
from .field import Field
|
|
23
25
|
from .field_type import DictType, ListType, PrimitiveFieldType, TupleType
|
|
@@ -50,11 +52,19 @@ def _is_field_primitive_compound(field: Field):
|
|
|
50
52
|
class SerializableMeta(type):
|
|
51
53
|
def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
|
|
52
54
|
# All the fields including misc fields.
|
|
55
|
+
legacy_name_hash = hash(f"{properties.get('__module__')}.{name}")
|
|
56
|
+
name_hash = hash(
|
|
57
|
+
f"{properties.get('__module__')}.{properties.get('__qualname__')}"
|
|
58
|
+
)
|
|
53
59
|
all_fields = dict()
|
|
60
|
+
# mapping field names to base classes
|
|
61
|
+
field_to_cls_hash = dict()
|
|
54
62
|
|
|
55
63
|
for base in bases:
|
|
56
|
-
if hasattr(base, "_FIELDS"):
|
|
57
|
-
|
|
64
|
+
if not hasattr(base, "_FIELDS"):
|
|
65
|
+
continue
|
|
66
|
+
all_fields.update(base._FIELDS)
|
|
67
|
+
field_to_cls_hash.update(base._FIELD_TO_NAME_HASH)
|
|
58
68
|
|
|
59
69
|
properties_without_fields = {}
|
|
60
70
|
properties_field_slot_names = []
|
|
@@ -64,6 +74,8 @@ class SerializableMeta(type):
|
|
|
64
74
|
continue
|
|
65
75
|
|
|
66
76
|
field = all_fields.get(k)
|
|
77
|
+
# record the field for the class being created
|
|
78
|
+
field_to_cls_hash[k] = name_hash
|
|
67
79
|
if field is None:
|
|
68
80
|
properties_field_slot_names.append(k)
|
|
69
81
|
else:
|
|
@@ -75,23 +87,44 @@ class SerializableMeta(type):
|
|
|
75
87
|
|
|
76
88
|
# Make field order deterministic to serialize it as list instead of dict.
|
|
77
89
|
field_order = list(all_fields)
|
|
78
|
-
all_fields = dict(sorted(all_fields.items(), key=operator.itemgetter(0)))
|
|
79
90
|
primitive_fields = []
|
|
91
|
+
primitive_field_names = set()
|
|
80
92
|
non_primitive_fields = []
|
|
81
|
-
for v in all_fields.
|
|
93
|
+
for field_name, v in all_fields.items():
|
|
82
94
|
if _is_field_primitive_compound(v):
|
|
83
95
|
primitive_fields.append(v)
|
|
96
|
+
primitive_field_names.add(field_name)
|
|
84
97
|
else:
|
|
85
98
|
non_primitive_fields.append(v)
|
|
86
99
|
|
|
100
|
+
# count number of fields for every base class
|
|
101
|
+
cls_to_primitive_field_count = defaultdict(lambda: 0)
|
|
102
|
+
cls_to_non_primitive_field_count = defaultdict(lambda: 0)
|
|
103
|
+
for field_name in field_order:
|
|
104
|
+
cls_hash = field_to_cls_hash[field_name]
|
|
105
|
+
if field_name in primitive_field_names:
|
|
106
|
+
cls_to_primitive_field_count[cls_hash] += 1
|
|
107
|
+
else:
|
|
108
|
+
cls_to_non_primitive_field_count[cls_hash] += 1
|
|
109
|
+
|
|
87
110
|
slots = set(properties.pop("__slots__", set()))
|
|
88
111
|
slots.update(properties_field_slot_names)
|
|
89
112
|
|
|
90
113
|
properties = properties_without_fields
|
|
114
|
+
|
|
115
|
+
# todo remove this prop when all versions below v1.0.0rc1 is eliminated
|
|
116
|
+
properties["_LEGACY_NAME_HASH"] = legacy_name_hash
|
|
117
|
+
|
|
118
|
+
properties["_NAME_HASH"] = name_hash
|
|
91
119
|
properties["_FIELDS"] = all_fields
|
|
92
120
|
properties["_FIELD_ORDER"] = field_order
|
|
121
|
+
properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
|
|
93
122
|
properties["_PRIMITIVE_FIELDS"] = primitive_fields
|
|
123
|
+
properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = dict(cls_to_primitive_field_count)
|
|
94
124
|
properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
|
|
125
|
+
properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = dict(
|
|
126
|
+
cls_to_non_primitive_field_count
|
|
127
|
+
)
|
|
95
128
|
properties["__slots__"] = tuple(slots)
|
|
96
129
|
|
|
97
130
|
clz = type.__new__(mcs, name, bases, properties)
|
|
@@ -114,10 +147,14 @@ class Serializable(metaclass=SerializableMeta):
|
|
|
114
147
|
_cache_primitive_serial = False
|
|
115
148
|
_ignore_non_existing_keys = False
|
|
116
149
|
|
|
150
|
+
_NAME_HASH: int
|
|
117
151
|
_FIELDS: Dict[str, Field]
|
|
118
152
|
_FIELD_ORDER: List[str]
|
|
153
|
+
_FIELD_TO_NAME_HASH: Dict[str, int]
|
|
119
154
|
_PRIMITIVE_FIELDS: List[str]
|
|
155
|
+
_CLS_TO_PRIMITIVE_FIELD_COUNT: Dict[int, int]
|
|
120
156
|
_NON_PRIMITIVE_FIELDS: List[str]
|
|
157
|
+
_CLS_TO_NON_PRIMITIVE_FIELD_COUNT: Dict[int, int]
|
|
121
158
|
|
|
122
159
|
def __init__(self, *args, **kwargs):
|
|
123
160
|
fields = self._FIELDS
|
|
@@ -175,11 +212,31 @@ class _NoFieldValue:
|
|
|
175
212
|
_no_field_value = _NoFieldValue()
|
|
176
213
|
|
|
177
214
|
|
|
215
|
+
def _to_primitive_placeholder(v: Any) -> Any:
|
|
216
|
+
if v is _no_field_value or v is no_default:
|
|
217
|
+
return {}
|
|
218
|
+
return v
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _restore_primitive_placeholder(v: Any) -> Any:
|
|
222
|
+
if type(v) is dict:
|
|
223
|
+
if v == {}:
|
|
224
|
+
return _no_field_value
|
|
225
|
+
else:
|
|
226
|
+
return v
|
|
227
|
+
else:
|
|
228
|
+
return v
|
|
229
|
+
|
|
230
|
+
|
|
178
231
|
class SerializableSerializer(Serializer):
|
|
179
232
|
"""
|
|
180
233
|
Leverage DictSerializer to perform serde.
|
|
181
234
|
"""
|
|
182
235
|
|
|
236
|
+
@classmethod
|
|
237
|
+
def _get_obj_field_count_key(cls, obj: Serializable, legacy: bool = False):
|
|
238
|
+
return f"FC_{obj._NAME_HASH if not legacy else obj._LEGACY_NAME_HASH}"
|
|
239
|
+
|
|
183
240
|
@classmethod
|
|
184
241
|
def _get_field_values(cls, obj: Serializable, fields):
|
|
185
242
|
values = []
|
|
@@ -201,15 +258,25 @@ class SerializableSerializer(Serializer):
|
|
|
201
258
|
else:
|
|
202
259
|
primitive_vals = self._get_field_values(obj, obj._PRIMITIVE_FIELDS)
|
|
203
260
|
# replace _no_field_value as {} to make them msgpack-serializable
|
|
204
|
-
primitive_vals = [
|
|
205
|
-
v if v is not _no_field_value else {} for v in primitive_vals
|
|
206
|
-
]
|
|
261
|
+
primitive_vals = [_to_primitive_placeholder(v) for v in primitive_vals]
|
|
207
262
|
if obj._cache_primitive_serial:
|
|
208
263
|
primitive_vals = msgpack.dumps(primitive_vals)
|
|
209
264
|
_primitive_serial_cache[obj] = primitive_vals
|
|
210
265
|
|
|
211
266
|
compound_vals = self._get_field_values(obj, obj._NON_PRIMITIVE_FIELDS)
|
|
212
267
|
cls_module = f"{type(obj).__module__}#{type(obj).__qualname__}"
|
|
268
|
+
|
|
269
|
+
field_count_key = self._get_obj_field_count_key(obj)
|
|
270
|
+
if not self.is_public_data_exist(context, field_count_key):
|
|
271
|
+
# store field distribution for current Serializable
|
|
272
|
+
counts = [
|
|
273
|
+
list(obj._CLS_TO_PRIMITIVE_FIELD_COUNT.items()),
|
|
274
|
+
list(obj._CLS_TO_NON_PRIMITIVE_FIELD_COUNT.items()),
|
|
275
|
+
]
|
|
276
|
+
field_count_data = msgpack.dumps(counts)
|
|
277
|
+
self.put_public_data(
|
|
278
|
+
context, self._get_obj_field_count_key(obj), field_count_data
|
|
279
|
+
)
|
|
213
280
|
return [cls_module, primitive_vals], [compound_vals], False
|
|
214
281
|
|
|
215
282
|
@staticmethod
|
|
@@ -229,6 +296,92 @@ class SerializableSerializer(Serializer):
|
|
|
229
296
|
else:
|
|
230
297
|
field.set(obj, value)
|
|
231
298
|
|
|
299
|
+
@classmethod
|
|
300
|
+
def _set_field_values(
|
|
301
|
+
cls,
|
|
302
|
+
obj: Serializable,
|
|
303
|
+
values: List[Any],
|
|
304
|
+
client_cls_to_field_count: Optional[Dict[str, int]],
|
|
305
|
+
is_primitive: bool = True,
|
|
306
|
+
):
|
|
307
|
+
obj_class = type(obj)
|
|
308
|
+
if is_primitive:
|
|
309
|
+
server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
|
|
310
|
+
server_fields = obj_class._PRIMITIVE_FIELDS
|
|
311
|
+
else:
|
|
312
|
+
server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
|
|
313
|
+
server_fields = obj_class._NON_PRIMITIVE_FIELDS
|
|
314
|
+
|
|
315
|
+
legacy_to_new_hash = {
|
|
316
|
+
c._LEGACY_NAME_HASH: c._NAME_HASH
|
|
317
|
+
for c in obj_class.__mro__
|
|
318
|
+
if hasattr(c, "_NAME_HASH") and c._LEGACY_NAME_HASH != c._NAME_HASH
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
if client_cls_to_field_count:
|
|
322
|
+
field_num, server_field_num = 0, 0
|
|
323
|
+
for cls_hash, count in client_cls_to_field_count.items():
|
|
324
|
+
# cut values and fields given field distribution
|
|
325
|
+
# at client and server end
|
|
326
|
+
cls_fields = server_fields[server_field_num : field_num + count]
|
|
327
|
+
cls_values = values[field_num : field_num + count]
|
|
328
|
+
for field, value in zip(cls_fields, cls_values):
|
|
329
|
+
if is_primitive:
|
|
330
|
+
value = _restore_primitive_placeholder(value)
|
|
331
|
+
if not is_primitive or value is not _no_field_value:
|
|
332
|
+
cls._set_field_value(obj, field, value)
|
|
333
|
+
field_num += count
|
|
334
|
+
try:
|
|
335
|
+
server_field_num += server_cls_to_field_count[cls_hash]
|
|
336
|
+
except KeyError:
|
|
337
|
+
try:
|
|
338
|
+
# todo remove this fallback when all
|
|
339
|
+
# versions below v1.0.0rc1 is eliminated
|
|
340
|
+
server_field_num += server_cls_to_field_count[
|
|
341
|
+
legacy_to_new_hash[cls_hash]
|
|
342
|
+
]
|
|
343
|
+
except KeyError:
|
|
344
|
+
# it is possible that certain type of field does not exist
|
|
345
|
+
# at server side
|
|
346
|
+
pass
|
|
347
|
+
else:
|
|
348
|
+
# handle legacy serialization style, with all fields sorted by name
|
|
349
|
+
# todo remove this branch when all versions below v0.1.0b5 is eliminated
|
|
350
|
+
from .field import AnyField
|
|
351
|
+
|
|
352
|
+
if is_primitive:
|
|
353
|
+
new_field_attr = "_legacy_new_primitives"
|
|
354
|
+
deprecated_field_attr = "_legacy_deprecated_primitives"
|
|
355
|
+
else:
|
|
356
|
+
new_field_attr = "_legacy_new_non_primitives"
|
|
357
|
+
deprecated_field_attr = "_legacy_deprecated_non_primitives"
|
|
358
|
+
|
|
359
|
+
# remove fields added on later releases
|
|
360
|
+
new_names = set(getattr(obj_class, new_field_attr, None) or [])
|
|
361
|
+
server_fields = [f for f in server_fields if f.name not in new_names]
|
|
362
|
+
|
|
363
|
+
# fill fields deprecated on later releases
|
|
364
|
+
deprecated_fields = []
|
|
365
|
+
deprecated_names = set()
|
|
366
|
+
if hasattr(obj_class, deprecated_field_attr):
|
|
367
|
+
deprecated_names = set(getattr(obj_class, deprecated_field_attr))
|
|
368
|
+
for field_name in deprecated_names:
|
|
369
|
+
field = AnyField(tag=field_name)
|
|
370
|
+
field.name = field_name
|
|
371
|
+
deprecated_fields.append(field)
|
|
372
|
+
server_fields = sorted(
|
|
373
|
+
server_fields + deprecated_fields, key=lambda f: f.name
|
|
374
|
+
)
|
|
375
|
+
for field, value in zip(server_fields, values):
|
|
376
|
+
if is_primitive:
|
|
377
|
+
value = _restore_primitive_placeholder(value)
|
|
378
|
+
if not is_primitive or value is not _no_field_value:
|
|
379
|
+
try:
|
|
380
|
+
cls._set_field_value(obj, field, value)
|
|
381
|
+
except AttributeError: # pragma: no cover
|
|
382
|
+
if field.name not in deprecated_names:
|
|
383
|
+
raise
|
|
384
|
+
|
|
232
385
|
def deserial(self, serialized: List, context: Dict, subs: List) -> Serializable:
|
|
233
386
|
obj_class_name, primitives = serialized
|
|
234
387
|
obj_class = load_type(obj_class_name, Serializable)
|
|
@@ -238,14 +391,26 @@ class SerializableSerializer(Serializer):
|
|
|
238
391
|
|
|
239
392
|
obj = obj_class.__new__(obj_class)
|
|
240
393
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
394
|
+
field_count_data = self.get_public_data(
|
|
395
|
+
context, self._get_obj_field_count_key(obj)
|
|
396
|
+
)
|
|
397
|
+
if field_count_data is None:
|
|
398
|
+
# todo remove this fallback when all
|
|
399
|
+
# versions below v1.0.0rc1 is eliminated
|
|
400
|
+
field_count_data = self.get_public_data(
|
|
401
|
+
context, self._get_obj_field_count_key(obj, legacy=True)
|
|
402
|
+
)
|
|
403
|
+
if field_count_data is not None:
|
|
404
|
+
cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
|
|
405
|
+
cls_to_prim_key = dict(cls_to_prim_key)
|
|
406
|
+
cls_to_non_prim_key = dict(cls_to_non_prim_key)
|
|
407
|
+
else:
|
|
408
|
+
cls_to_prim_key, cls_to_non_prim_key = None, None
|
|
245
409
|
|
|
410
|
+
if primitives:
|
|
411
|
+
self._set_field_values(obj, primitives, cls_to_prim_key, True)
|
|
246
412
|
if obj_class._NON_PRIMITIVE_FIELDS:
|
|
247
|
-
|
|
248
|
-
self._set_field_value(obj, field, value)
|
|
413
|
+
self._set_field_values(obj, subs[0], cls_to_non_prim_key, False)
|
|
249
414
|
obj.__on_deserialize__()
|
|
250
415
|
return obj
|
|
251
416
|
|
|
@@ -46,6 +46,9 @@ class PrimitiveType(Enum):
|
|
|
46
46
|
complex128 = 25
|
|
47
47
|
|
|
48
48
|
|
|
49
|
+
_np_unicode = np.unicode_ if hasattr(np, "unicode_") else np.str_
|
|
50
|
+
|
|
51
|
+
|
|
49
52
|
_primitive_type_to_valid_types = {
|
|
50
53
|
PrimitiveType.bool: (bool, np.bool_),
|
|
51
54
|
PrimitiveType.int8: (int, np.int8),
|
|
@@ -60,7 +63,7 @@ _primitive_type_to_valid_types = {
|
|
|
60
63
|
PrimitiveType.float32: (float, np.float32),
|
|
61
64
|
PrimitiveType.float64: (float, np.float64),
|
|
62
65
|
PrimitiveType.bytes: (bytes, np.bytes_),
|
|
63
|
-
PrimitiveType.string: (str,
|
|
66
|
+
PrimitiveType.string: (str, _np_unicode),
|
|
64
67
|
PrimitiveType.complex64: (complex, np.complex64),
|
|
65
68
|
PrimitiveType.complex128: (complex, np.complex128),
|
|
66
69
|
}
|
|
@@ -21,6 +21,7 @@ import pytest
|
|
|
21
21
|
|
|
22
22
|
from ....core import EntityData
|
|
23
23
|
from ....lib.wrapped_pickle import switch_unpickle
|
|
24
|
+
from ....utils import no_default
|
|
24
25
|
from ... import deserialize, serialize
|
|
25
26
|
from .. import (
|
|
26
27
|
AnyField,
|
|
@@ -94,6 +95,11 @@ class MySimpleSerializable(Serializable):
|
|
|
94
95
|
_ref_val = ReferenceField("ref_val", "MySimpleSerializable")
|
|
95
96
|
|
|
96
97
|
|
|
98
|
+
class MySubSerializable(MySimpleSerializable):
|
|
99
|
+
_m_int_val = Int64Field("m_int_val", default=250)
|
|
100
|
+
_m_str_val = StringField("m_str_val", default="SUB_STR")
|
|
101
|
+
|
|
102
|
+
|
|
97
103
|
class MySerializable(Serializable):
|
|
98
104
|
_id = IdentityField("id")
|
|
99
105
|
_any_val = AnyField("any_val")
|
|
@@ -138,10 +144,11 @@ class MySerializable(Serializable):
|
|
|
138
144
|
oneof1_val=f"{__name__}.MySerializable",
|
|
139
145
|
oneof2_val=MySimpleSerializable,
|
|
140
146
|
)
|
|
147
|
+
_no_default_val = Float64Field("no_default_val", default=no_default)
|
|
141
148
|
|
|
142
149
|
|
|
143
150
|
@pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
|
|
144
|
-
@switch_unpickle
|
|
151
|
+
@switch_unpickle(forbidden=False)
|
|
145
152
|
def test_serializable(set_is_ci):
|
|
146
153
|
my_serializable = MySerializable(
|
|
147
154
|
_id="1",
|
|
@@ -165,7 +172,9 @@ def test_serializable(set_is_ci):
|
|
|
165
172
|
_key_val=MyHasKey("aaa"),
|
|
166
173
|
_ndarray_val=np.random.rand(4, 3),
|
|
167
174
|
_datetime64_val=pd.Timestamp(123),
|
|
168
|
-
_timedelta64_val=pd.Timedelta(
|
|
175
|
+
_timedelta64_val=pd.Timedelta(
|
|
176
|
+
days=1, seconds=123, microseconds=345, nanoseconds=132
|
|
177
|
+
),
|
|
169
178
|
_datatype_val=np.dtype(np.int32),
|
|
170
179
|
_index_val=pd.Index([1, 2]),
|
|
171
180
|
_series_val=pd.Series(["a", "b"]),
|
|
@@ -180,6 +189,7 @@ def test_serializable(set_is_ci):
|
|
|
180
189
|
_dict_val={"a": b"bytes_value"},
|
|
181
190
|
_ref_val=MySerializable(),
|
|
182
191
|
_oneof_val=MySerializable(_id="2"),
|
|
192
|
+
_no_default_val=no_default,
|
|
183
193
|
)
|
|
184
194
|
|
|
185
195
|
header, buffers = serialize(my_serializable)
|
|
@@ -187,12 +197,51 @@ def test_serializable(set_is_ci):
|
|
|
187
197
|
_assert_serializable_eq(my_serializable, my_serializable2)
|
|
188
198
|
|
|
189
199
|
|
|
200
|
+
@pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
|
|
201
|
+
@switch_unpickle
|
|
202
|
+
def test_compatible_serializable(set_is_ci):
|
|
203
|
+
global MySimpleSerializable, MySubSerializable
|
|
204
|
+
|
|
205
|
+
old_base, old_sub = MySimpleSerializable, MySubSerializable
|
|
206
|
+
|
|
207
|
+
try:
|
|
208
|
+
my_sub_serializable = MySubSerializable(
|
|
209
|
+
_id="id_val",
|
|
210
|
+
_list_val=["abcd", "wxyz"],
|
|
211
|
+
_ref_val=MyHasKey(),
|
|
212
|
+
_m_int_val=3412,
|
|
213
|
+
_m_str_val="dfgghj",
|
|
214
|
+
)
|
|
215
|
+
header, buffers = serialize(my_sub_serializable)
|
|
216
|
+
|
|
217
|
+
class MySimpleSerializable(Serializable):
|
|
218
|
+
_id = IdentityField("id")
|
|
219
|
+
_int_val = Int64Field("int_val", default=1000)
|
|
220
|
+
_list_val = ListField("list_val", default_factory=list)
|
|
221
|
+
_ref_val = ReferenceField("ref_val", "MySimpleSerializable")
|
|
222
|
+
_dict_val = DictField("dict_val")
|
|
223
|
+
|
|
224
|
+
class MySubSerializable(MySimpleSerializable):
|
|
225
|
+
_m_int_val = Int64Field("m_int_val", default=250)
|
|
226
|
+
_m_str_val = StringField("m_str_val", default="SUB_STR")
|
|
227
|
+
|
|
228
|
+
my_sub_serializable2 = deserialize(header, buffers)
|
|
229
|
+
assert type(my_sub_serializable) is not type(my_sub_serializable2)
|
|
230
|
+
_assert_serializable_eq(my_sub_serializable, my_sub_serializable2)
|
|
231
|
+
finally:
|
|
232
|
+
MySimpleSerializable, MySubSerializable = old_base, old_sub
|
|
233
|
+
|
|
234
|
+
|
|
190
235
|
def _assert_serializable_eq(my_serializable, my_serializable2):
|
|
191
236
|
for field_name, field in my_serializable._FIELDS.items():
|
|
192
|
-
if not hasattr(my_serializable, field.
|
|
237
|
+
if not hasattr(my_serializable, field.name):
|
|
193
238
|
continue
|
|
194
239
|
expect_value = getattr(my_serializable, field_name)
|
|
195
|
-
|
|
240
|
+
if expect_value is no_default:
|
|
241
|
+
assert not hasattr(my_serializable2, field.name)
|
|
242
|
+
continue
|
|
243
|
+
else:
|
|
244
|
+
actual_value = getattr(my_serializable2, field_name)
|
|
196
245
|
if isinstance(expect_value, np.ndarray):
|
|
197
246
|
np.testing.assert_array_equal(expect_value, actual_value)
|
|
198
247
|
elif isinstance(expect_value, pd.DataFrame):
|
|
@@ -208,7 +257,7 @@ def _assert_serializable_eq(my_serializable, my_serializable2):
|
|
|
208
257
|
elif callable(expect_value):
|
|
209
258
|
assert expect_value(1) == actual_value(1)
|
|
210
259
|
else:
|
|
211
|
-
assert expect_value == actual_value
|
|
260
|
+
assert expect_value == actual_value, f"Field {field_name}"
|
|
212
261
|
|
|
213
262
|
|
|
214
263
|
@pytest.mark.parametrize("set_is_ci", [True], indirect=True)
|
|
@@ -42,7 +42,7 @@ except ImportError:
|
|
|
42
42
|
from ...lib.sparse import SparseMatrix
|
|
43
43
|
from ...lib.wrapped_pickle import switch_unpickle
|
|
44
44
|
from ...tests.utils import require_cudf, require_cupy
|
|
45
|
-
from ...utils import lazy_import
|
|
45
|
+
from ...utils import lazy_import, no_default
|
|
46
46
|
from .. import (
|
|
47
47
|
PickleContainer,
|
|
48
48
|
RemoteException,
|
|
@@ -90,6 +90,7 @@ class CustomNamedTuple(NamedTuple):
|
|
|
90
90
|
pd.Timedelta(102.234154131),
|
|
91
91
|
{"abc": 5.6, "def": [3.4], "gh": None, "ijk": {}},
|
|
92
92
|
OrderedDict([("abcd", 5.6)]),
|
|
93
|
+
no_default,
|
|
93
94
|
],
|
|
94
95
|
)
|
|
95
96
|
@switch_unpickle
|
maxframe/session.py
CHANGED
|
@@ -150,6 +150,10 @@ class AbstractSession(ABC):
|
|
|
150
150
|
def session_id(self):
|
|
151
151
|
return self._session_id
|
|
152
152
|
|
|
153
|
+
@property
|
|
154
|
+
def closed(self) -> bool:
|
|
155
|
+
return self._closed
|
|
156
|
+
|
|
153
157
|
def __eq__(self, other):
|
|
154
158
|
return (
|
|
155
159
|
isinstance(other, AbstractSession)
|
|
@@ -365,6 +369,15 @@ class AbstractAsyncSession(AbstractSession, metaclass=ABCMeta):
|
|
|
365
369
|
Stop server.
|
|
366
370
|
"""
|
|
367
371
|
|
|
372
|
+
@abstractmethod
|
|
373
|
+
async def get_logview_address(self, hours=None) -> Optional[str]:
|
|
374
|
+
"""
|
|
375
|
+
Get Logview address
|
|
376
|
+
Returns
|
|
377
|
+
-------
|
|
378
|
+
Logview address
|
|
379
|
+
"""
|
|
380
|
+
|
|
368
381
|
def close(self):
|
|
369
382
|
asyncio.run(self.destroy())
|
|
370
383
|
|
|
@@ -549,6 +562,15 @@ class AbstractSyncSession(AbstractSession, metaclass=ABCMeta):
|
|
|
549
562
|
|
|
550
563
|
return fetch(tileables, self, offsets=offsets, sizes=sizes)
|
|
551
564
|
|
|
565
|
+
@abstractmethod
|
|
566
|
+
def get_logview_address(self, hours=None) -> Optional[str]:
|
|
567
|
+
"""
|
|
568
|
+
Get logview address
|
|
569
|
+
Returns
|
|
570
|
+
-------
|
|
571
|
+
logview address
|
|
572
|
+
"""
|
|
573
|
+
|
|
552
574
|
|
|
553
575
|
def _delegate_to_isolated_session(func: Union[Callable, Coroutine]):
|
|
554
576
|
if asyncio.iscoroutinefunction(func):
|
|
@@ -728,6 +750,11 @@ class AsyncSession(AbstractAsyncSession):
|
|
|
728
750
|
await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(coro, self._loop))
|
|
729
751
|
stop_isolation()
|
|
730
752
|
|
|
753
|
+
@implements(AbstractAsyncSession.get_logview_address)
|
|
754
|
+
@_delegate_to_isolated_session
|
|
755
|
+
async def get_logview_address(self, hours=None) -> Optional[str]:
|
|
756
|
+
pass # pragma: no cover
|
|
757
|
+
|
|
731
758
|
|
|
732
759
|
class ProgressBar:
|
|
733
760
|
def __init__(self, show_progress):
|
|
@@ -949,6 +976,11 @@ class SyncSession(AbstractSyncSession):
|
|
|
949
976
|
def get_cluster_versions(self) -> List[str]:
|
|
950
977
|
pass # pragma: no cover
|
|
951
978
|
|
|
979
|
+
@implements(AbstractSyncSession.get_logview_address)
|
|
980
|
+
@_delegate_to_isolated_session
|
|
981
|
+
def get_logview_address(self, hours=None) -> Optional[str]:
|
|
982
|
+
pass # pragma: no cover
|
|
983
|
+
|
|
952
984
|
def destroy(self):
|
|
953
985
|
coro = self._isolated_session.destroy()
|
|
954
986
|
asyncio.run_coroutine_threadsafe(coro, self._loop).result()
|
|
@@ -1255,9 +1287,12 @@ def get_default_or_create(**kwargs):
|
|
|
1255
1287
|
if session is None:
|
|
1256
1288
|
# no session attached, try to create one
|
|
1257
1289
|
warnings.warn(warning_msg)
|
|
1258
|
-
|
|
1259
|
-
|
|
1290
|
+
odps_entry = (
|
|
1291
|
+
kwargs.pop("odps_entry", None)
|
|
1292
|
+
or ODPS.from_global()
|
|
1293
|
+
or ODPS.from_environments()
|
|
1260
1294
|
)
|
|
1295
|
+
session = new_session(odps_entry=odps_entry, **kwargs)
|
|
1261
1296
|
session.as_default()
|
|
1262
1297
|
if isinstance(session, IsolatedAsyncSession):
|
|
1263
1298
|
session = SyncSession.from_isolated_session(session)
|