maxframe 1.0.0rc1__cp311-cp311-macosx_10_9_x86_64.whl → 1.0.0rc3__cp311-cp311-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/_utils.cpython-311-darwin.so +0 -0
- maxframe/codegen.py +3 -6
- maxframe/config/config.py +49 -10
- maxframe/config/validators.py +42 -11
- maxframe/conftest.py +15 -2
- maxframe/core/__init__.py +2 -13
- maxframe/core/entity/__init__.py +0 -4
- maxframe/core/entity/objects.py +46 -3
- maxframe/core/entity/output_types.py +0 -3
- maxframe/core/entity/tests/test_objects.py +43 -0
- maxframe/core/entity/tileables.py +5 -78
- maxframe/core/graph/__init__.py +2 -2
- maxframe/core/graph/builder/__init__.py +0 -1
- maxframe/core/graph/builder/base.py +5 -4
- maxframe/core/graph/builder/tileable.py +4 -4
- maxframe/core/graph/builder/utils.py +4 -8
- maxframe/core/graph/core.cpython-311-darwin.so +0 -0
- maxframe/core/graph/entity.py +9 -33
- maxframe/core/operator/__init__.py +2 -9
- maxframe/core/operator/base.py +3 -5
- maxframe/core/operator/objects.py +0 -9
- maxframe/core/operator/utils.py +55 -0
- maxframe/dataframe/__init__.py +1 -1
- maxframe/dataframe/arithmetic/around.py +5 -17
- maxframe/dataframe/arithmetic/core.py +15 -7
- maxframe/dataframe/arithmetic/docstring.py +5 -55
- maxframe/dataframe/arithmetic/tests/test_arithmetic.py +22 -0
- maxframe/dataframe/core.py +5 -5
- maxframe/dataframe/datasource/date_range.py +2 -2
- maxframe/dataframe/datasource/read_odps_query.py +7 -1
- maxframe/dataframe/datasource/read_odps_table.py +3 -2
- maxframe/dataframe/datasource/tests/test_datasource.py +14 -0
- maxframe/dataframe/datastore/to_odps.py +1 -1
- maxframe/dataframe/groupby/cum.py +0 -1
- maxframe/dataframe/groupby/tests/test_groupby.py +4 -0
- maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
- maxframe/dataframe/indexing/rename.py +3 -37
- maxframe/dataframe/indexing/sample.py +0 -1
- maxframe/dataframe/indexing/set_index.py +68 -1
- maxframe/dataframe/merge/merge.py +236 -2
- maxframe/dataframe/merge/tests/test_merge.py +123 -0
- maxframe/dataframe/misc/apply.py +3 -10
- maxframe/dataframe/misc/case_when.py +1 -1
- maxframe/dataframe/misc/describe.py +2 -2
- maxframe/dataframe/misc/drop_duplicates.py +4 -25
- maxframe/dataframe/misc/eval.py +4 -0
- maxframe/dataframe/misc/pct_change.py +1 -83
- maxframe/dataframe/misc/transform.py +1 -30
- maxframe/dataframe/misc/value_counts.py +4 -17
- maxframe/dataframe/missing/dropna.py +1 -1
- maxframe/dataframe/missing/fillna.py +5 -5
- maxframe/dataframe/operators.py +1 -17
- maxframe/dataframe/reduction/core.py +2 -2
- maxframe/dataframe/sort/sort_values.py +1 -11
- maxframe/dataframe/statistics/quantile.py +5 -17
- maxframe/dataframe/utils.py +4 -7
- maxframe/io/objects/__init__.py +24 -0
- maxframe/io/objects/core.py +140 -0
- maxframe/io/objects/tensor.py +76 -0
- maxframe/io/objects/tests/__init__.py +13 -0
- maxframe/io/objects/tests/test_object_io.py +97 -0
- maxframe/{odpsio → io/odpsio}/__init__.py +3 -1
- maxframe/{odpsio → io/odpsio}/arrow.py +12 -8
- maxframe/{odpsio → io/odpsio}/schema.py +15 -12
- maxframe/io/odpsio/tableio.py +702 -0
- maxframe/io/odpsio/tests/__init__.py +13 -0
- maxframe/{odpsio → io/odpsio}/tests/test_schema.py +19 -18
- maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +50 -23
- maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
- maxframe/io/odpsio/volumeio.py +57 -0
- maxframe/learn/contrib/xgboost/classifier.py +26 -2
- maxframe/learn/contrib/xgboost/core.py +87 -2
- maxframe/learn/contrib/xgboost/dmatrix.py +3 -6
- maxframe/learn/contrib/xgboost/predict.py +21 -7
- maxframe/learn/contrib/xgboost/regressor.py +3 -10
- maxframe/learn/contrib/xgboost/train.py +27 -17
- maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
- maxframe/lib/mmh3.cpython-311-darwin.so +0 -0
- maxframe/protocol.py +41 -17
- maxframe/remote/core.py +4 -8
- maxframe/serialization/__init__.py +1 -0
- maxframe/serialization/core.cpython-311-darwin.so +0 -0
- maxframe/serialization/serializables/core.py +48 -9
- maxframe/tensor/__init__.py +69 -2
- maxframe/tensor/arithmetic/isclose.py +1 -0
- maxframe/tensor/arithmetic/tests/test_arithmetic.py +21 -17
- maxframe/tensor/core.py +5 -136
- maxframe/tensor/datasource/array.py +3 -0
- maxframe/tensor/datasource/full.py +1 -1
- maxframe/tensor/datasource/tests/test_datasource.py +1 -1
- maxframe/tensor/indexing/flatnonzero.py +1 -1
- maxframe/tensor/merge/__init__.py +2 -0
- maxframe/tensor/merge/concatenate.py +98 -0
- maxframe/tensor/merge/tests/test_merge.py +30 -1
- maxframe/tensor/merge/vstack.py +70 -0
- maxframe/tensor/{base → misc}/__init__.py +2 -0
- maxframe/tensor/{base → misc}/atleast_1d.py +0 -2
- maxframe/tensor/misc/atleast_2d.py +70 -0
- maxframe/tensor/misc/atleast_3d.py +85 -0
- maxframe/tensor/misc/tests/__init__.py +13 -0
- maxframe/tensor/{base → misc}/transpose.py +22 -18
- maxframe/tensor/{base → misc}/unique.py +2 -2
- maxframe/tensor/operators.py +1 -7
- maxframe/tensor/random/core.py +1 -1
- maxframe/tensor/reduction/count_nonzero.py +1 -0
- maxframe/tensor/reduction/mean.py +1 -0
- maxframe/tensor/reduction/nanmean.py +1 -0
- maxframe/tensor/reduction/nanvar.py +2 -0
- maxframe/tensor/reduction/tests/test_reduction.py +12 -1
- maxframe/tensor/reduction/var.py +2 -0
- maxframe/tensor/statistics/quantile.py +2 -2
- maxframe/tensor/utils.py +2 -22
- maxframe/tests/utils.py +11 -2
- maxframe/typing_.py +4 -1
- maxframe/udf.py +8 -9
- maxframe/utils.py +32 -70
- {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/METADATA +25 -25
- {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/RECORD +133 -123
- {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/WHEEL +1 -1
- maxframe_client/fetcher.py +60 -68
- maxframe_client/session/graph.py +8 -2
- maxframe_client/session/odps.py +58 -22
- maxframe_client/tests/test_fetcher.py +21 -3
- maxframe_client/tests/test_session.py +27 -4
- maxframe/core/entity/chunks.py +0 -68
- maxframe/core/entity/fuse.py +0 -73
- maxframe/core/graph/builder/chunk.py +0 -430
- maxframe/odpsio/tableio.py +0 -322
- maxframe/odpsio/volumeio.py +0 -95
- maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
- maxframe/{tensor/base/tests → io}/__init__.py +0 -0
- maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
- maxframe/tensor/{base → misc}/astype.py +0 -0
- maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
- maxframe/tensor/{base → misc}/ravel.py +0 -0
- maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
- maxframe/tensor/{base → misc}/where.py +0 -0
- {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/top_level.txt +0 -0
maxframe/_utils.cpython-311-darwin.so
CHANGED
|
Binary file
|
maxframe/codegen.py
CHANGED
|
@@ -26,9 +26,9 @@ from odps.utils import camel_to_underline
|
|
|
26
26
|
from .core import OperatorType, Tileable, TileableGraph
|
|
27
27
|
from .core.operator import Fetch
|
|
28
28
|
from .extension import iter_extensions
|
|
29
|
+
from .io.odpsio import build_dataframe_table_meta
|
|
30
|
+
from .io.odpsio.schema import pandas_to_odps_schema
|
|
29
31
|
from .lib import wrapped_pickle as pickle
|
|
30
|
-
from .odpsio import build_dataframe_table_meta
|
|
31
|
-
from .odpsio.schema import pandas_to_odps_schema
|
|
32
32
|
from .protocol import DataFrameTableMeta, ResultInfo
|
|
33
33
|
from .serialization import PickleContainer
|
|
34
34
|
from .serialization.serializables import Serializable, StringField
|
|
@@ -205,12 +205,8 @@ class BigDagCodeContext(metaclass=abc.ABCMeta):
|
|
|
205
205
|
return self._session_id
|
|
206
206
|
|
|
207
207
|
def register_udf(self, udf: AbstractUDF):
|
|
208
|
-
from maxframe_framedriver.services.session import SessionManager
|
|
209
|
-
|
|
210
208
|
udf.session_id = self._session_id
|
|
211
209
|
self._udfs[udf.name] = udf
|
|
212
|
-
if self._session_id and SessionManager.initialized():
|
|
213
|
-
SessionManager.instance().register_udf(self._session_id, udf)
|
|
214
210
|
|
|
215
211
|
def get_udfs(self) -> List[AbstractUDF]:
|
|
216
212
|
return list(self._udfs.values())
|
|
@@ -510,6 +506,7 @@ class BigDagCodeGenerator(metaclass=abc.ABCMeta):
|
|
|
510
506
|
prefer_binary=pack.prefer_binary,
|
|
511
507
|
pre_release=pack.pre_release,
|
|
512
508
|
force_rebuild=pack.force_rebuild,
|
|
509
|
+
no_audit_wheel=pack.no_audit_wheel,
|
|
513
510
|
python_tag=python_tag,
|
|
514
511
|
is_production=is_production,
|
|
515
512
|
schedule_id=schedule_id,
|
maxframe/config/config.py
CHANGED
|
@@ -19,29 +19,40 @@ import warnings
|
|
|
19
19
|
from copy import deepcopy
|
|
20
20
|
from typing import Any, Dict, Optional, Union
|
|
21
21
|
|
|
22
|
+
from odps.lib import tzlocal
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from zoneinfo import available_timezones
|
|
26
|
+
except ImportError:
|
|
27
|
+
from pytz import all_timezones
|
|
28
|
+
|
|
29
|
+
available_timezones = lambda: all_timezones
|
|
30
|
+
|
|
22
31
|
from ..utils import get_python_tag
|
|
23
32
|
from .validators import (
|
|
24
33
|
ValidatorType,
|
|
25
34
|
all_validator,
|
|
26
|
-
any_validator,
|
|
27
35
|
is_bool,
|
|
28
36
|
is_dict,
|
|
29
37
|
is_in,
|
|
30
38
|
is_integer,
|
|
39
|
+
is_non_negative_integer,
|
|
31
40
|
is_null,
|
|
32
41
|
is_numeric,
|
|
33
42
|
is_string,
|
|
43
|
+
is_valid_cache_path,
|
|
34
44
|
)
|
|
35
45
|
|
|
36
46
|
_DEFAULT_REDIRECT_WARN = "Option {source} has been replaced by {target} and might be removed in a future release."
|
|
37
47
|
_DEFAULT_MAX_ALIVE_SECONDS = 3 * 24 * 3600
|
|
38
48
|
_DEFAULT_MAX_IDLE_SECONDS = 3600
|
|
39
49
|
_DEFAULT_SPE_OPERATION_TIMEOUT_SECONDS = 120
|
|
50
|
+
_DEFAULT_SPE_FAILURE_RETRY_TIMES = 5
|
|
40
51
|
_DEFAULT_UPLOAD_BATCH_SIZE = 4096
|
|
41
52
|
_DEFAULT_TEMP_LIFECYCLE = 1
|
|
42
53
|
_DEFAULT_TASK_START_TIMEOUT = 60
|
|
43
54
|
_DEFAULT_TASK_RESTART_TIMEOUT = 300
|
|
44
|
-
_DEFAULT_LOGVIEW_HOURS = 24 *
|
|
55
|
+
_DEFAULT_LOGVIEW_HOURS = 24 * 30
|
|
45
56
|
|
|
46
57
|
|
|
47
58
|
class OptionError(Exception):
|
|
@@ -297,13 +308,28 @@ class Config:
|
|
|
297
308
|
return {k: v for k, v in res.items() if k in self._remote_options}
|
|
298
309
|
|
|
299
310
|
|
|
311
|
+
def _get_legal_local_tz_name() -> Optional[str]:
|
|
312
|
+
"""Sometimes we may get illegal tz name from tzlocal.get_localzone()"""
|
|
313
|
+
tz_name = str(tzlocal.get_localzone())
|
|
314
|
+
if tz_name not in available_timezones():
|
|
315
|
+
return None
|
|
316
|
+
return tz_name
|
|
317
|
+
|
|
318
|
+
|
|
300
319
|
default_options = Config()
|
|
301
320
|
default_options.register_option(
|
|
302
321
|
"execution_mode", "trigger", validator=is_in(["trigger", "eager"])
|
|
303
322
|
)
|
|
323
|
+
default_options.register_option("use_common_table", False, validator=is_bool)
|
|
304
324
|
default_options.register_option(
|
|
305
325
|
"python_tag", get_python_tag(), validator=is_string, remote=True
|
|
306
326
|
)
|
|
327
|
+
default_options.register_option(
|
|
328
|
+
"local_timezone",
|
|
329
|
+
_get_legal_local_tz_name(),
|
|
330
|
+
validator=is_null | is_in(set(available_timezones())),
|
|
331
|
+
remote=True,
|
|
332
|
+
)
|
|
307
333
|
default_options.register_option(
|
|
308
334
|
"session.logview_hours", _DEFAULT_LOGVIEW_HOURS, validator=is_integer, remote=True
|
|
309
335
|
)
|
|
@@ -322,6 +348,17 @@ default_options.register_option("sql.settings", {}, validator=is_dict, remote=Tr
|
|
|
322
348
|
default_options.register_option("is_production", False, validator=is_bool, remote=True)
|
|
323
349
|
default_options.register_option("schedule_id", "", validator=is_string, remote=True)
|
|
324
350
|
|
|
351
|
+
default_options.register_option(
|
|
352
|
+
"service_role_arn", None, validator=is_null | is_string, remote=True
|
|
353
|
+
)
|
|
354
|
+
default_options.register_option(
|
|
355
|
+
"object_cache_url", None, validator=is_null | is_valid_cache_path, remote=True
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
default_options.register_option(
|
|
359
|
+
"chunk_size", None, validator=is_null | is_integer, remote=True
|
|
360
|
+
)
|
|
361
|
+
|
|
325
362
|
default_options.register_option(
|
|
326
363
|
"session.max_alive_seconds",
|
|
327
364
|
_DEFAULT_MAX_ALIVE_SECONDS,
|
|
@@ -340,9 +377,7 @@ default_options.register_option(
|
|
|
340
377
|
validator=is_integer,
|
|
341
378
|
)
|
|
342
379
|
default_options.register_option(
|
|
343
|
-
"session.table_lifecycle",
|
|
344
|
-
None,
|
|
345
|
-
validator=any_validator(is_null, is_integer),
|
|
380
|
+
"session.table_lifecycle", None, validator=is_null | is_integer
|
|
346
381
|
)
|
|
347
382
|
default_options.register_option(
|
|
348
383
|
"session.temp_table_lifecycle",
|
|
@@ -353,7 +388,7 @@ default_options.register_option(
|
|
|
353
388
|
default_options.register_option(
|
|
354
389
|
"session.subinstance_priority",
|
|
355
390
|
None,
|
|
356
|
-
validator=
|
|
391
|
+
validator=is_null | is_integer,
|
|
357
392
|
remote=True,
|
|
358
393
|
)
|
|
359
394
|
|
|
@@ -365,9 +400,7 @@ default_options.register_option(
|
|
|
365
400
|
default_options.register_option(
|
|
366
401
|
"optimize.head_optimize_threshold", 1000, validator=is_integer
|
|
367
402
|
)
|
|
368
|
-
default_options.register_option(
|
|
369
|
-
"show_progress", "auto", validator=any_validator(is_bool, is_string)
|
|
370
|
-
)
|
|
403
|
+
default_options.register_option("show_progress", "auto", validator=is_bool | is_string)
|
|
371
404
|
default_options.register_option(
|
|
372
405
|
"dag.settings", value=dict(), validator=is_dict, remote=True
|
|
373
406
|
)
|
|
@@ -378,7 +411,13 @@ default_options.register_option(
|
|
|
378
411
|
default_options.register_option(
|
|
379
412
|
"spe.operation_timeout_seconds",
|
|
380
413
|
_DEFAULT_SPE_OPERATION_TIMEOUT_SECONDS,
|
|
381
|
-
validator=
|
|
414
|
+
validator=is_non_negative_integer,
|
|
415
|
+
remote=True,
|
|
416
|
+
)
|
|
417
|
+
default_options.register_option(
|
|
418
|
+
"spe.failure_retry_times",
|
|
419
|
+
_DEFAULT_SPE_FAILURE_RETRY_TIMES,
|
|
420
|
+
validator=is_non_negative_integer,
|
|
382
421
|
remote=True,
|
|
383
422
|
)
|
|
384
423
|
|
maxframe/config/validators.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
from typing import Callable
|
|
16
|
+
from urllib.parse import urlparse
|
|
16
17
|
|
|
17
18
|
ValidatorType = Callable[..., bool]
|
|
18
19
|
|
|
@@ -32,21 +33,51 @@ def all_validator(*validators: ValidatorType):
|
|
|
32
33
|
return validate
|
|
33
34
|
|
|
34
35
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
36
|
+
class Validator:
|
|
37
|
+
def __init__(self, func: ValidatorType):
|
|
38
|
+
self._func = func
|
|
39
|
+
|
|
40
|
+
def __call__(self, arg) -> bool:
|
|
41
|
+
return self._func(arg)
|
|
42
|
+
|
|
43
|
+
def __or__(self, other):
|
|
44
|
+
return OrValidator(self, other)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class OrValidator(Validator):
|
|
48
|
+
def __init__(self, lhs: Validator, rhs: Validator):
|
|
49
|
+
super().__init__(lambda x: lhs(x) or rhs(x))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
is_null = Validator(lambda x: x is None)
|
|
53
|
+
is_bool = Validator(lambda x: isinstance(x, bool))
|
|
54
|
+
is_float = Validator(lambda x: isinstance(x, float))
|
|
55
|
+
is_integer = Validator(lambda x: isinstance(x, int))
|
|
56
|
+
is_numeric = Validator(lambda x: isinstance(x, (int, float)))
|
|
57
|
+
is_string = Validator(lambda x: isinstance(x, str))
|
|
58
|
+
is_dict = Validator(lambda x: isinstance(x, dict))
|
|
59
|
+
is_positive_integer = Validator(lambda x: is_integer(x) and x > 0)
|
|
60
|
+
is_non_negative_integer = Validator(lambda x: is_integer(x) and x >= 0)
|
|
43
61
|
|
|
44
62
|
|
|
45
63
|
def is_in(vals):
|
|
46
|
-
|
|
47
|
-
return x in vals
|
|
64
|
+
return Validator(vals.__contains__)
|
|
48
65
|
|
|
49
|
-
|
|
66
|
+
|
|
67
|
+
def _is_valid_cache_path(path: str) -> bool:
|
|
68
|
+
"""
|
|
69
|
+
path should look like oss://oss_endpoint/oss_bucket/path
|
|
70
|
+
"""
|
|
71
|
+
parsed_url = urlparse(path)
|
|
72
|
+
return (
|
|
73
|
+
parsed_url.scheme == "oss"
|
|
74
|
+
and parsed_url.netloc
|
|
75
|
+
and parsed_url.path
|
|
76
|
+
and "/" in parsed_url.path
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
is_valid_cache_path = Validator(_is_valid_cache_path)
|
|
50
81
|
|
|
51
82
|
|
|
52
83
|
_invalid_char_in_yaml_str = {'"', "'", "\n", "\\"}
|
maxframe/conftest.py
CHANGED
|
@@ -19,6 +19,8 @@ from configparser import ConfigParser, NoOptionError
|
|
|
19
19
|
import pytest
|
|
20
20
|
from odps import ODPS
|
|
21
21
|
|
|
22
|
+
from .config import options
|
|
23
|
+
|
|
22
24
|
faulthandler.enable(all_threads=True)
|
|
23
25
|
_test_conf_file_name = os.path.join(
|
|
24
26
|
os.path.dirname(os.path.abspath(__file__)), "tests", "test.conf"
|
|
@@ -77,16 +79,23 @@ def odps_envs(test_config):
|
|
|
77
79
|
pass
|
|
78
80
|
|
|
79
81
|
|
|
80
|
-
@pytest.fixture
|
|
82
|
+
@pytest.fixture(scope="session")
|
|
81
83
|
def oss_config():
|
|
82
84
|
config = ConfigParser()
|
|
83
85
|
config.read(_test_conf_file_name)
|
|
84
86
|
|
|
87
|
+
old_role_arn = options.service_role_arn
|
|
88
|
+
old_cache_url = options.object_cache_url
|
|
89
|
+
|
|
85
90
|
try:
|
|
86
91
|
oss_access_id = config.get("oss", "access_id")
|
|
87
92
|
oss_secret_access_key = config.get("oss", "secret_access_key")
|
|
88
93
|
oss_bucket_name = config.get("oss", "bucket_name")
|
|
89
94
|
oss_endpoint = config.get("oss", "endpoint")
|
|
95
|
+
oss_rolearn = config.get("oss", "rolearn")
|
|
96
|
+
|
|
97
|
+
options.service_role_arn = oss_rolearn
|
|
98
|
+
options.object_cache_url = f"oss://{oss_endpoint}/{oss_bucket_name}"
|
|
90
99
|
|
|
91
100
|
config.oss_config = (
|
|
92
101
|
oss_access_id,
|
|
@@ -99,9 +108,13 @@ def oss_config():
|
|
|
99
108
|
|
|
100
109
|
auth = oss2.Auth(oss_access_id, oss_secret_access_key)
|
|
101
110
|
config.oss_bucket = oss2.Bucket(auth, oss_endpoint, oss_bucket_name)
|
|
102
|
-
|
|
111
|
+
config.oss_rolearn = oss_rolearn
|
|
112
|
+
yield config
|
|
103
113
|
except (ConfigParser.NoSectionError, ConfigParser.NoOptionError, ImportError):
|
|
104
114
|
return None
|
|
115
|
+
finally:
|
|
116
|
+
options.service_role_arn = old_role_arn
|
|
117
|
+
options.object_cache_url = old_cache_url
|
|
105
118
|
|
|
106
119
|
|
|
107
120
|
@pytest.fixture(autouse=True)
|
maxframe/core/__init__.py
CHANGED
|
@@ -14,20 +14,14 @@
|
|
|
14
14
|
|
|
15
15
|
# noinspection PyUnresolvedReferences
|
|
16
16
|
from ..typing_ import ChunkType, EntityType, OperatorType, TileableType
|
|
17
|
-
from .base import ExecutionError
|
|
17
|
+
from .base import Base, ExecutionError
|
|
18
18
|
from .entity import (
|
|
19
|
-
CHUNK_TYPE,
|
|
20
19
|
ENTITY_TYPE,
|
|
21
|
-
FUSE_CHUNK_TYPE,
|
|
22
20
|
OBJECT_TYPE,
|
|
23
21
|
TILEABLE_TYPE,
|
|
24
|
-
Chunk,
|
|
25
|
-
ChunkData,
|
|
26
22
|
Entity,
|
|
27
23
|
EntityData,
|
|
28
24
|
ExecutableTuple,
|
|
29
|
-
FuseChunk,
|
|
30
|
-
FuseChunkData,
|
|
31
25
|
HasShapeTileable,
|
|
32
26
|
HasShapeTileableData,
|
|
33
27
|
NotSupportTile,
|
|
@@ -40,23 +34,18 @@ from .entity import (
|
|
|
40
34
|
get_fetch_class,
|
|
41
35
|
get_output_types,
|
|
42
36
|
get_tileable_types,
|
|
43
|
-
register,
|
|
44
37
|
register_fetch_class,
|
|
45
38
|
register_output_types,
|
|
46
|
-
unregister,
|
|
47
39
|
)
|
|
48
40
|
|
|
49
41
|
# noinspection PyUnresolvedReferences
|
|
50
42
|
from .graph import (
|
|
51
43
|
DAG,
|
|
52
|
-
ChunkGraph,
|
|
53
|
-
ChunkGraphBuilder,
|
|
54
44
|
DirectedGraph,
|
|
55
45
|
GraphContainsCycleError,
|
|
56
46
|
GraphSerializer,
|
|
57
47
|
TileableGraph,
|
|
58
48
|
TileableGraphBuilder,
|
|
59
|
-
TileContext,
|
|
60
|
-
TileStatus,
|
|
61
49
|
)
|
|
62
50
|
from .mode import enter_mode, is_build_mode, is_eager_mode, is_kernel_mode
|
|
51
|
+
from .operator import build_fetch
|
maxframe/core/entity/__init__.py
CHANGED
|
@@ -12,10 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from .chunks import CHUNK_TYPE, Chunk, ChunkData
|
|
16
15
|
from .core import ENTITY_TYPE, Entity, EntityData
|
|
17
16
|
from .executable import ExecutableTuple, _ExecuteAndFetchMixin
|
|
18
|
-
from .fuse import FUSE_CHUNK_TYPE, FuseChunk, FuseChunkData
|
|
19
17
|
from .objects import OBJECT_TYPE, Object, ObjectData
|
|
20
18
|
from .output_types import (
|
|
21
19
|
OutputType,
|
|
@@ -32,6 +30,4 @@ from .tileables import (
|
|
|
32
30
|
NotSupportTile,
|
|
33
31
|
Tileable,
|
|
34
32
|
TileableData,
|
|
35
|
-
register,
|
|
36
|
-
unregister,
|
|
37
33
|
)
|
maxframe/core/entity/objects.py
CHANGED
|
@@ -12,8 +12,10 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any, Dict
|
|
15
|
+
from typing import Any, Dict, Type
|
|
16
16
|
|
|
17
|
+
from ...serialization import load_type
|
|
18
|
+
from ...serialization.serializables import StringField
|
|
17
19
|
from .core import Entity
|
|
18
20
|
from .executable import _ToObjectMixin
|
|
19
21
|
from .tileables import TileableData
|
|
@@ -23,11 +25,44 @@ class ObjectData(TileableData, _ToObjectMixin):
|
|
|
23
25
|
__slots__ = ()
|
|
24
26
|
type_name = "Object"
|
|
25
27
|
# workaround for removed field since v0.1.0b5
|
|
26
|
-
# todo remove this when all versions below
|
|
28
|
+
# todo remove this when all versions below v1.0.0rc1 is eliminated
|
|
27
29
|
_legacy_deprecated_non_primitives = ["_chunks"]
|
|
30
|
+
_legacy_new_non_primitives = ["object_class"]
|
|
31
|
+
|
|
32
|
+
object_class = StringField("object_class", default=None)
|
|
33
|
+
|
|
34
|
+
@classmethod
|
|
35
|
+
def get_entity_class(cls) -> Type["Object"]:
|
|
36
|
+
if getattr(cls, "_entity_class", None) is not None:
|
|
37
|
+
return cls._entity_class
|
|
38
|
+
assert cls.__qualname__[-4:] == "Data"
|
|
39
|
+
target_class_name = cls.__module__ + "#" + cls.__qualname__[:-4]
|
|
40
|
+
cls._entity_class = load_type(target_class_name, Object)
|
|
41
|
+
return cls._entity_class
|
|
42
|
+
|
|
43
|
+
def __new__(cls, op=None, nsplits=None, **kw):
|
|
44
|
+
if cls is ObjectData:
|
|
45
|
+
obj_cls = kw.get("object_class")
|
|
46
|
+
if isinstance(obj_cls, str):
|
|
47
|
+
obj_cls = load_type(obj_cls, (Object, ObjectData))
|
|
48
|
+
if isinstance(obj_cls, type) and issubclass(obj_cls, Object):
|
|
49
|
+
obj_cls = obj_cls.get_data_class()
|
|
50
|
+
|
|
51
|
+
if obj_cls is not None and cls is not obj_cls:
|
|
52
|
+
return obj_cls(op=op, nsplits=nsplits, **kw)
|
|
53
|
+
return super().__new__(cls)
|
|
28
54
|
|
|
29
55
|
def __init__(self, op=None, nsplits=None, **kw):
|
|
56
|
+
obj_cls = kw.pop("object_class", None)
|
|
57
|
+
if isinstance(obj_cls, type):
|
|
58
|
+
if isinstance(obj_cls, type) and issubclass(obj_cls, Object):
|
|
59
|
+
obj_cls = obj_cls.get_data_class()
|
|
60
|
+
kw["object_class"] = obj_cls.__module__ + "#" + obj_cls.__qualname__
|
|
61
|
+
|
|
30
62
|
super().__init__(_op=op, _nsplits=nsplits, **kw)
|
|
63
|
+
if self.object_class is None and type(self) is not ObjectData:
|
|
64
|
+
cls = type(self)
|
|
65
|
+
self.object_class = cls.__module__ + "#" + cls.__qualname__
|
|
31
66
|
|
|
32
67
|
def __repr__(self):
|
|
33
68
|
return f"Object <op={type(self.op).__name__}, key={self.key}>"
|
|
@@ -35,7 +70,7 @@ class ObjectData(TileableData, _ToObjectMixin):
|
|
|
35
70
|
@property
|
|
36
71
|
def params(self):
|
|
37
72
|
# params return the properties which useful to rebuild a new tileable object
|
|
38
|
-
return dict()
|
|
73
|
+
return dict(object_class=self.object_class)
|
|
39
74
|
|
|
40
75
|
@params.setter
|
|
41
76
|
def params(self, new_params: Dict[str, Any]):
|
|
@@ -54,5 +89,13 @@ class Object(Entity, _ToObjectMixin):
|
|
|
54
89
|
_allow_data_type_ = (ObjectData,)
|
|
55
90
|
type_name = "Object"
|
|
56
91
|
|
|
92
|
+
@classmethod
|
|
93
|
+
def get_data_class(cls) -> Type[ObjectData]:
|
|
94
|
+
if getattr(cls, "_data_class", None) is not None:
|
|
95
|
+
return cls._data_class
|
|
96
|
+
target_class_name = cls.__module__ + "#" + cls.__qualname__ + "Data"
|
|
97
|
+
cls._data_class = load_type(target_class_name, ObjectData)
|
|
98
|
+
return cls._data_class
|
|
99
|
+
|
|
57
100
|
|
|
58
101
|
OBJECT_TYPE = (Object, ObjectData)
|
|
maxframe/core/entity/output_types.py
CHANGED
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
import functools
|
|
16
16
|
from enum import Enum
|
|
17
17
|
|
|
18
|
-
from .fuse import FUSE_CHUNK_TYPE
|
|
19
18
|
from .objects import OBJECT_TYPE
|
|
20
19
|
|
|
21
20
|
|
|
@@ -77,8 +76,6 @@ def get_output_types(*objs, unknown_as=None):
|
|
|
77
76
|
for obj in objs:
|
|
78
77
|
if obj is None:
|
|
79
78
|
continue
|
|
80
|
-
elif isinstance(obj, FUSE_CHUNK_TYPE):
|
|
81
|
-
obj = obj.chunk
|
|
82
79
|
|
|
83
80
|
try:
|
|
84
81
|
output_types.append(_get_output_type_by_cls(type(obj)))
|
|
maxframe/core/entity/tests/test_objects.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ..objects import Object, ObjectData
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestSubObjectData(ObjectData):
|
|
19
|
+
__test__ = False
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class TestSubObject(Object):
|
|
23
|
+
__test__ = False
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_object_init():
|
|
27
|
+
assert TestSubObjectData.get_entity_class() is TestSubObject
|
|
28
|
+
|
|
29
|
+
obj = ObjectData(
|
|
30
|
+
object_class=TestSubObjectData.__module__ + "#" + TestSubObjectData.__name__
|
|
31
|
+
)
|
|
32
|
+
assert isinstance(obj, TestSubObjectData)
|
|
33
|
+
|
|
34
|
+
obj = ObjectData(object_class=TestSubObjectData)
|
|
35
|
+
assert isinstance(obj, TestSubObjectData)
|
|
36
|
+
|
|
37
|
+
obj = ObjectData(
|
|
38
|
+
object_class=TestSubObject.__module__ + "#" + TestSubObject.__name__
|
|
39
|
+
)
|
|
40
|
+
assert isinstance(obj, TestSubObjectData)
|
|
41
|
+
|
|
42
|
+
obj = ObjectData(object_class=TestSubObject)
|
|
43
|
+
assert isinstance(obj, TestSubObjectData)
|
|
maxframe/core/entity/tileables.py
CHANGED
|
@@ -15,17 +15,15 @@
|
|
|
15
15
|
import builtins
|
|
16
16
|
import itertools
|
|
17
17
|
from operator import attrgetter
|
|
18
|
-
from typing import Callable, List
|
|
19
18
|
from weakref import WeakKeyDictionary, WeakSet
|
|
20
19
|
|
|
21
20
|
import numpy as np
|
|
22
21
|
|
|
23
22
|
from ...serialization.serializables import BoolField, FieldTypes, TupleField
|
|
24
|
-
from ...typing_ import
|
|
23
|
+
from ...typing_ import TileableType
|
|
25
24
|
from ...utils import on_deserialize_shape, on_serialize_nsplits, on_serialize_shape
|
|
26
25
|
from ..base import Base
|
|
27
26
|
from ..mode import enter_mode
|
|
28
|
-
from .chunks import Chunk
|
|
29
27
|
from .core import Entity, EntityData
|
|
30
28
|
from .executable import _ExecutableMixin
|
|
31
29
|
|
|
@@ -34,79 +32,6 @@ class NotSupportTile(Exception):
|
|
|
34
32
|
pass
|
|
35
33
|
|
|
36
34
|
|
|
37
|
-
class OperatorTilesHandler:
|
|
38
|
-
_handlers = dict()
|
|
39
|
-
|
|
40
|
-
@classmethod
|
|
41
|
-
def _get_op_cls(cls, op: OperatorType):
|
|
42
|
-
if isinstance(op, type):
|
|
43
|
-
return op
|
|
44
|
-
return type(op)
|
|
45
|
-
|
|
46
|
-
@classmethod
|
|
47
|
-
def register(
|
|
48
|
-
cls, op: OperatorType, tile_handler: Callable[[OperatorType], TileableType]
|
|
49
|
-
):
|
|
50
|
-
cls._handlers[cls._get_op_cls(op)] = tile_handler
|
|
51
|
-
|
|
52
|
-
@classmethod
|
|
53
|
-
def unregister(cls, op: OperatorType):
|
|
54
|
-
del cls._handlers[cls._get_op_cls(op)]
|
|
55
|
-
|
|
56
|
-
@classmethod
|
|
57
|
-
def get_handler(
|
|
58
|
-
cls, op: OperatorType
|
|
59
|
-
) -> Callable[[OperatorType], List[TileableType]]:
|
|
60
|
-
op_cls = cls._get_op_cls(op)
|
|
61
|
-
return cls._handlers.get(op_cls, op_cls.tile)
|
|
62
|
-
|
|
63
|
-
@classmethod
|
|
64
|
-
def _assign_to(
|
|
65
|
-
cls,
|
|
66
|
-
tile_after_tensor_datas: List["TileableData"],
|
|
67
|
-
tile_before_tensor_datas: List["TileableData"],
|
|
68
|
-
):
|
|
69
|
-
assert len(tile_after_tensor_datas) == len(tile_before_tensor_datas)
|
|
70
|
-
|
|
71
|
-
for tile_after_tensor_data, tile_before_tensor_data in zip(
|
|
72
|
-
tile_after_tensor_datas, tile_before_tensor_datas
|
|
73
|
-
):
|
|
74
|
-
if tile_before_tensor_data is None:
|
|
75
|
-
# garbage collected
|
|
76
|
-
continue
|
|
77
|
-
tile_after_tensor_data.copy_to(tile_before_tensor_data)
|
|
78
|
-
tile_before_tensor_data.op.outputs = tile_before_tensor_datas
|
|
79
|
-
|
|
80
|
-
@enter_mode(kernel=True)
|
|
81
|
-
def dispatch(self, op: OperatorType):
|
|
82
|
-
op_cls = self._get_op_cls(op)
|
|
83
|
-
tiled = None
|
|
84
|
-
cause = None
|
|
85
|
-
|
|
86
|
-
if op_cls in self._handlers:
|
|
87
|
-
tiled = self._handlers[op_cls](op)
|
|
88
|
-
else:
|
|
89
|
-
try:
|
|
90
|
-
tiled = op_cls.tile(op)
|
|
91
|
-
except NotImplementedError as ex:
|
|
92
|
-
cause = ex
|
|
93
|
-
for super_cls in op_cls.__mro__:
|
|
94
|
-
if super_cls in self._handlers:
|
|
95
|
-
h = self._handlers[op_cls] = self._handlers[super_cls]
|
|
96
|
-
tiled = h(op)
|
|
97
|
-
break
|
|
98
|
-
|
|
99
|
-
if tiled is not None:
|
|
100
|
-
return tiled if isinstance(tiled, list) else [tiled]
|
|
101
|
-
else:
|
|
102
|
-
raise NotImplementedError(f"{type(op)} does not support tile") from cause
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
handler = OperatorTilesHandler()
|
|
106
|
-
register = OperatorTilesHandler.register
|
|
107
|
-
unregister = OperatorTilesHandler.unregister
|
|
108
|
-
|
|
109
|
-
|
|
110
35
|
class _ChunksIndexer:
|
|
111
36
|
__slots__ = ("_tileable",)
|
|
112
37
|
|
|
@@ -231,7 +156,7 @@ entity_view_handler = EntityDataModificationHandler()
|
|
|
231
156
|
|
|
232
157
|
|
|
233
158
|
class TileableData(EntityData, _ExecutableMixin):
|
|
234
|
-
__slots__ = "_cix", "_entities", "_executed_sessions"
|
|
159
|
+
__slots__ = "_chunks", "_cix", "_entities", "_executed_sessions"
|
|
235
160
|
_no_copy_attrs_ = Base._no_copy_attrs_ | {"_cix"}
|
|
236
161
|
|
|
237
162
|
# optional fields
|
|
@@ -245,6 +170,8 @@ class TileableData(EntityData, _ExecutableMixin):
|
|
|
245
170
|
cache = BoolField("cache", default=False)
|
|
246
171
|
|
|
247
172
|
def __init__(self: TileableType, *args, **kwargs):
|
|
173
|
+
if kwargs.get("chunks") is not None:
|
|
174
|
+
self._chunks = kwargs.pop("chunks")
|
|
248
175
|
if kwargs.get("_nsplits", None) is not None:
|
|
249
176
|
kwargs["_nsplits"] = tuple(tuple(s) for s in kwargs["_nsplits"])
|
|
250
177
|
|
|
@@ -270,7 +197,7 @@ class TileableData(EntityData, _ExecutableMixin):
|
|
|
270
197
|
return tuple(map(len, self._nsplits))
|
|
271
198
|
|
|
272
199
|
@property
|
|
273
|
-
def chunks(self) ->
|
|
200
|
+
def chunks(self) -> list:
|
|
274
201
|
return getattr(self, "_chunks", None)
|
|
275
202
|
|
|
276
203
|
@property
|
maxframe/core/graph/__init__.py
CHANGED
|
@@ -12,6 +12,6 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from .builder import
|
|
15
|
+
from .builder import TileableGraphBuilder
|
|
16
16
|
from .core import DAG, DirectedGraph, GraphContainsCycleError
|
|
17
|
-
from .entity import
|
|
17
|
+
from .entity import EntityGraph, GraphSerializer, TileableGraph
|
|
maxframe/core/graph/builder/base.py
CHANGED
|
@@ -14,10 +14,10 @@
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
from abc import ABC, abstractmethod
|
|
17
|
-
from typing import Generator, List, Set
|
|
17
|
+
from typing import Generator, List, Set
|
|
18
18
|
|
|
19
19
|
from ....typing_ import EntityType
|
|
20
|
-
from ..entity import
|
|
20
|
+
from ..entity import EntityGraph
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
def _default_inputs_selector(inputs: List[EntityType]) -> List[EntityType]:
|
|
@@ -43,7 +43,7 @@ class AbstractGraphBuilder(ABC):
|
|
|
43
43
|
|
|
44
44
|
def _add_nodes(
|
|
45
45
|
self,
|
|
46
|
-
graph:
|
|
46
|
+
graph: EntityGraph,
|
|
47
47
|
nodes: List[EntityType],
|
|
48
48
|
visited: Set,
|
|
49
49
|
):
|
|
@@ -75,7 +75,7 @@ class AbstractGraphBuilder(ABC):
|
|
|
75
75
|
nodes.append(out)
|
|
76
76
|
|
|
77
77
|
@abstractmethod
|
|
78
|
-
def build(self) -> Generator[
|
|
78
|
+
def build(self) -> Generator[EntityGraph, None, None]:
|
|
79
79
|
"""
|
|
80
80
|
Build a entity graph.
|
|
81
81
|
|
|
@@ -84,3 +84,4 @@ class AbstractGraphBuilder(ABC):
|
|
|
84
84
|
graph : EntityGraph
|
|
85
85
|
Entity graph.
|
|
86
86
|
"""
|
|
87
|
+
raise NotImplementedError
|