maxframe 1.0.0rc2__cp311-cp311-win_amd64.whl → 1.0.0rc4__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. Click here for more details.

Files changed (134) hide show
  1. maxframe/_utils.cp311-win_amd64.pyd +0 -0
  2. maxframe/codegen.py +4 -2
  3. maxframe/config/config.py +28 -9
  4. maxframe/config/validators.py +42 -12
  5. maxframe/conftest.py +56 -14
  6. maxframe/core/__init__.py +2 -13
  7. maxframe/core/entity/__init__.py +0 -4
  8. maxframe/core/entity/executable.py +1 -1
  9. maxframe/core/entity/objects.py +45 -2
  10. maxframe/core/entity/output_types.py +0 -3
  11. maxframe/core/entity/tests/test_objects.py +43 -0
  12. maxframe/core/entity/tileables.py +5 -78
  13. maxframe/core/graph/__init__.py +2 -2
  14. maxframe/core/graph/builder/__init__.py +0 -1
  15. maxframe/core/graph/builder/base.py +5 -4
  16. maxframe/core/graph/builder/tileable.py +4 -4
  17. maxframe/core/graph/builder/utils.py +4 -8
  18. maxframe/core/graph/core.cp311-win_amd64.pyd +0 -0
  19. maxframe/core/graph/entity.py +9 -33
  20. maxframe/core/operator/__init__.py +2 -9
  21. maxframe/core/operator/base.py +3 -5
  22. maxframe/core/operator/objects.py +0 -9
  23. maxframe/core/operator/utils.py +55 -0
  24. maxframe/dataframe/arithmetic/docstring.py +26 -2
  25. maxframe/dataframe/arithmetic/equal.py +4 -2
  26. maxframe/dataframe/arithmetic/greater.py +4 -2
  27. maxframe/dataframe/arithmetic/greater_equal.py +4 -2
  28. maxframe/dataframe/arithmetic/less.py +2 -2
  29. maxframe/dataframe/arithmetic/less_equal.py +4 -2
  30. maxframe/dataframe/arithmetic/not_equal.py +4 -2
  31. maxframe/dataframe/core.py +2 -0
  32. maxframe/dataframe/datasource/read_odps_query.py +67 -8
  33. maxframe/dataframe/datasource/read_odps_table.py +4 -2
  34. maxframe/dataframe/datasource/tests/test_datasource.py +35 -6
  35. maxframe/dataframe/datastore/to_odps.py +8 -1
  36. maxframe/dataframe/extensions/__init__.py +3 -0
  37. maxframe/dataframe/extensions/flatmap.py +326 -0
  38. maxframe/dataframe/extensions/tests/test_extensions.py +62 -1
  39. maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
  40. maxframe/dataframe/indexing/rename.py +11 -0
  41. maxframe/dataframe/initializer.py +11 -1
  42. maxframe/dataframe/misc/drop_duplicates.py +18 -1
  43. maxframe/dataframe/operators.py +1 -17
  44. maxframe/dataframe/reduction/core.py +2 -2
  45. maxframe/dataframe/tests/test_initializer.py +33 -2
  46. maxframe/io/objects/__init__.py +24 -0
  47. maxframe/io/objects/core.py +140 -0
  48. maxframe/io/objects/tensor.py +76 -0
  49. maxframe/io/objects/tests/__init__.py +13 -0
  50. maxframe/io/objects/tests/test_object_io.py +97 -0
  51. maxframe/{odpsio → io/odpsio}/__init__.py +2 -0
  52. maxframe/{odpsio → io/odpsio}/arrow.py +4 -4
  53. maxframe/{odpsio → io/odpsio}/schema.py +10 -8
  54. maxframe/{odpsio → io/odpsio}/tableio.py +50 -38
  55. maxframe/io/odpsio/tests/__init__.py +13 -0
  56. maxframe/{odpsio → io/odpsio}/tests/test_schema.py +3 -7
  57. maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +3 -3
  58. maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
  59. maxframe/io/odpsio/volumeio.py +63 -0
  60. maxframe/learn/contrib/__init__.py +2 -1
  61. maxframe/learn/contrib/graph/__init__.py +15 -0
  62. maxframe/learn/contrib/graph/connected_components.py +215 -0
  63. maxframe/learn/contrib/graph/tests/__init__.py +13 -0
  64. maxframe/learn/contrib/graph/tests/test_connected_components.py +53 -0
  65. maxframe/learn/contrib/xgboost/classifier.py +26 -2
  66. maxframe/learn/contrib/xgboost/core.py +87 -2
  67. maxframe/learn/contrib/xgboost/dmatrix.py +1 -4
  68. maxframe/learn/contrib/xgboost/predict.py +27 -44
  69. maxframe/learn/contrib/xgboost/regressor.py +3 -10
  70. maxframe/learn/contrib/xgboost/train.py +27 -16
  71. maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
  72. maxframe/lib/mmh3.cp311-win_amd64.pyd +0 -0
  73. maxframe/opcodes.py +3 -0
  74. maxframe/protocol.py +7 -16
  75. maxframe/remote/core.py +4 -8
  76. maxframe/serialization/__init__.py +1 -0
  77. maxframe/serialization/core.cp311-win_amd64.pyd +0 -0
  78. maxframe/session.py +9 -2
  79. maxframe/tensor/__init__.py +10 -2
  80. maxframe/tensor/arithmetic/isclose.py +1 -0
  81. maxframe/tensor/arithmetic/tests/test_arithmetic.py +21 -17
  82. maxframe/tensor/core.py +5 -136
  83. maxframe/tensor/datasource/array.py +3 -0
  84. maxframe/tensor/datasource/full.py +1 -1
  85. maxframe/tensor/datasource/tests/test_datasource.py +1 -1
  86. maxframe/tensor/indexing/flatnonzero.py +1 -1
  87. maxframe/tensor/indexing/getitem.py +2 -0
  88. maxframe/tensor/merge/__init__.py +2 -0
  89. maxframe/tensor/merge/concatenate.py +101 -0
  90. maxframe/tensor/merge/tests/test_merge.py +30 -1
  91. maxframe/tensor/merge/vstack.py +74 -0
  92. maxframe/tensor/{base → misc}/__init__.py +2 -0
  93. maxframe/tensor/{base → misc}/atleast_1d.py +0 -2
  94. maxframe/tensor/misc/atleast_2d.py +70 -0
  95. maxframe/tensor/misc/atleast_3d.py +85 -0
  96. maxframe/tensor/misc/tests/__init__.py +13 -0
  97. maxframe/tensor/{base → misc}/transpose.py +22 -18
  98. maxframe/tensor/operators.py +1 -7
  99. maxframe/tensor/random/core.py +1 -1
  100. maxframe/tensor/reduction/count_nonzero.py +1 -0
  101. maxframe/tensor/reduction/mean.py +1 -0
  102. maxframe/tensor/reduction/nanmean.py +1 -0
  103. maxframe/tensor/reduction/nanvar.py +2 -0
  104. maxframe/tensor/reduction/tests/test_reduction.py +12 -1
  105. maxframe/tensor/reduction/var.py +2 -0
  106. maxframe/tensor/utils.py +2 -22
  107. maxframe/typing_.py +4 -1
  108. maxframe/udf.py +8 -9
  109. maxframe/utils.py +49 -73
  110. maxframe-1.0.0rc4.dist-info/METADATA +104 -0
  111. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc4.dist-info}/RECORD +129 -114
  112. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc4.dist-info}/WHEEL +1 -1
  113. maxframe_client/fetcher.py +33 -50
  114. maxframe_client/session/consts.py +3 -0
  115. maxframe_client/session/graph.py +8 -2
  116. maxframe_client/session/odps.py +134 -27
  117. maxframe_client/session/task.py +58 -20
  118. maxframe_client/tests/test_fetcher.py +1 -1
  119. maxframe_client/tests/test_session.py +27 -3
  120. maxframe/core/entity/chunks.py +0 -68
  121. maxframe/core/entity/fuse.py +0 -73
  122. maxframe/core/graph/builder/chunk.py +0 -430
  123. maxframe/odpsio/volumeio.py +0 -95
  124. maxframe-1.0.0rc2.dist-info/METADATA +0 -177
  125. /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
  126. /maxframe/{tensor/base/tests → io}/__init__.py +0 -0
  127. /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
  128. /maxframe/tensor/{base → misc}/astype.py +0 -0
  129. /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
  130. /maxframe/tensor/{base → misc}/ravel.py +0 -0
  131. /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
  132. /maxframe/tensor/{base → misc}/unique.py +0 -0
  133. /maxframe/tensor/{base → misc}/where.py +0 -0
  134. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc4.dist-info}/top_level.txt +0 -0
Binary file
maxframe/codegen.py CHANGED
@@ -26,9 +26,9 @@ from odps.utils import camel_to_underline
26
26
  from .core import OperatorType, Tileable, TileableGraph
27
27
  from .core.operator import Fetch
28
28
  from .extension import iter_extensions
29
+ from .io.odpsio import build_dataframe_table_meta
30
+ from .io.odpsio.schema import pandas_to_odps_schema
29
31
  from .lib import wrapped_pickle as pickle
30
- from .odpsio import build_dataframe_table_meta
31
- from .odpsio.schema import pandas_to_odps_schema
32
32
  from .protocol import DataFrameTableMeta, ResultInfo
33
33
  from .serialization import PickleContainer
34
34
  from .serialization.serializables import Serializable, StringField
@@ -347,6 +347,7 @@ BUILTIN_ENGINE_SPE = "SPE"
347
347
  BUILTIN_ENGINE_MCSQL = "MCSQL"
348
348
 
349
349
  FAST_RANGE_INDEX_ENABLED = "codegen.fast_range_index_enabled"
350
+ ROW_NUMBER_WINDOW_INDEX_ENABLED = "codegen.row_number_window_index_enabled"
350
351
 
351
352
 
352
353
  class BigDagCodeGenerator(metaclass=abc.ABCMeta):
@@ -506,6 +507,7 @@ class BigDagCodeGenerator(metaclass=abc.ABCMeta):
506
507
  prefer_binary=pack.prefer_binary,
507
508
  pre_release=pack.pre_release,
508
509
  force_rebuild=pack.force_rebuild,
510
+ no_audit_wheel=pack.no_audit_wheel,
509
511
  python_tag=python_tag,
510
512
  is_production=is_production,
511
513
  schedule_id=schedule_id,
maxframe/config/config.py CHANGED
@@ -32,7 +32,6 @@ from ..utils import get_python_tag
32
32
  from .validators import (
33
33
  ValidatorType,
34
34
  all_validator,
35
- any_validator,
36
35
  is_bool,
37
36
  is_dict,
38
37
  is_in,
@@ -41,6 +40,7 @@ from .validators import (
41
40
  is_null,
42
41
  is_numeric,
43
42
  is_string,
43
+ is_valid_cache_path,
44
44
  )
45
45
 
46
46
  _DEFAULT_REDIRECT_WARN = "Option {source} has been replaced by {target} and might be removed in a future release."
@@ -327,7 +327,7 @@ default_options.register_option(
327
327
  default_options.register_option(
328
328
  "local_timezone",
329
329
  _get_legal_local_tz_name(),
330
- validator=any_validator(is_null, is_in(set(available_timezones()))),
330
+ validator=is_null | is_in(set(available_timezones())),
331
331
  remote=True,
332
332
  )
333
333
  default_options.register_option(
@@ -343,11 +343,25 @@ default_options.register_option("sql.enable_mcqa", True, validator=is_bool, remo
343
343
  default_options.register_option(
344
344
  "sql.generate_comments", True, validator=is_bool, remote=True
345
345
  )
346
+ default_options.register_option(
347
+ "sql.auto_use_common_image", True, validator=is_bool, remote=True
348
+ )
346
349
  default_options.register_option("sql.settings", {}, validator=is_dict, remote=True)
347
350
 
348
351
  default_options.register_option("is_production", False, validator=is_bool, remote=True)
349
352
  default_options.register_option("schedule_id", "", validator=is_string, remote=True)
350
353
 
354
+ default_options.register_option(
355
+ "service_role_arn", None, validator=is_null | is_string, remote=True
356
+ )
357
+ default_options.register_option(
358
+ "object_cache_url", None, validator=is_null | is_valid_cache_path, remote=True
359
+ )
360
+
361
+ default_options.register_option(
362
+ "chunk_size", None, validator=is_null | is_integer, remote=True
363
+ )
364
+
351
365
  default_options.register_option(
352
366
  "session.max_alive_seconds",
353
367
  _DEFAULT_MAX_ALIVE_SECONDS,
@@ -360,15 +374,22 @@ default_options.register_option(
360
374
  validator=is_numeric,
361
375
  remote=True,
362
376
  )
377
+ default_options.register_option(
378
+ "session.quota_name", None, validator=is_null | is_string, remote=True
379
+ )
380
+ default_options.register_option(
381
+ "session.enable_schema", None, validator=is_null | is_bool, remote=True
382
+ )
383
+ default_options.register_option(
384
+ "session.default_schema", None, validator=is_null | is_string, remote=True
385
+ )
363
386
  default_options.register_option(
364
387
  "session.upload_batch_size",
365
388
  _DEFAULT_UPLOAD_BATCH_SIZE,
366
389
  validator=is_integer,
367
390
  )
368
391
  default_options.register_option(
369
- "session.table_lifecycle",
370
- None,
371
- validator=any_validator(is_null, is_integer),
392
+ "session.table_lifecycle", None, validator=is_null | is_integer, remote=True
372
393
  )
373
394
  default_options.register_option(
374
395
  "session.temp_table_lifecycle",
@@ -379,7 +400,7 @@ default_options.register_option(
379
400
  default_options.register_option(
380
401
  "session.subinstance_priority",
381
402
  None,
382
- validator=any_validator(is_null, is_integer),
403
+ validator=is_null | is_integer,
383
404
  remote=True,
384
405
  )
385
406
 
@@ -391,9 +412,7 @@ default_options.register_option(
391
412
  default_options.register_option(
392
413
  "optimize.head_optimize_threshold", 1000, validator=is_integer
393
414
  )
394
- default_options.register_option(
395
- "show_progress", "auto", validator=any_validator(is_bool, is_string)
396
- )
415
+ default_options.register_option("show_progress", "auto", validator=is_bool | is_string)
397
416
  default_options.register_option(
398
417
  "dag.settings", value=dict(), validator=is_dict, remote=True
399
418
  )
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from typing import Callable
16
+ from urllib.parse import urlparse
16
17
 
17
18
  ValidatorType = Callable[..., bool]
18
19
 
@@ -32,22 +33,51 @@ def all_validator(*validators: ValidatorType):
32
33
  return validate
33
34
 
34
35
 
35
- is_null = lambda x: x is None
36
- is_bool = lambda x: isinstance(x, bool)
37
- is_float = lambda x: isinstance(x, float)
38
- is_integer = lambda x: isinstance(x, int)
39
- is_numeric = lambda x: isinstance(x, (int, float))
40
- is_string = lambda x: isinstance(x, str)
41
- is_dict = lambda x: isinstance(x, dict)
42
- is_positive_integer = lambda x: is_integer(x) and x > 0
43
- is_non_negative_integer = lambda x: is_integer(x) and x >= 0
36
+ class Validator:
37
+ def __init__(self, func: ValidatorType):
38
+ self._func = func
39
+
40
+ def __call__(self, arg) -> bool:
41
+ return self._func(arg)
42
+
43
+ def __or__(self, other):
44
+ return OrValidator(self, other)
45
+
46
+
47
+ class OrValidator(Validator):
48
+ def __init__(self, lhs: Validator, rhs: Validator):
49
+ super().__init__(lambda x: lhs(x) or rhs(x))
50
+
51
+
52
+ is_null = Validator(lambda x: x is None)
53
+ is_bool = Validator(lambda x: isinstance(x, bool))
54
+ is_float = Validator(lambda x: isinstance(x, float))
55
+ is_integer = Validator(lambda x: isinstance(x, int))
56
+ is_numeric = Validator(lambda x: isinstance(x, (int, float)))
57
+ is_string = Validator(lambda x: isinstance(x, str))
58
+ is_dict = Validator(lambda x: isinstance(x, dict))
59
+ is_positive_integer = Validator(lambda x: is_integer(x) and x > 0)
60
+ is_non_negative_integer = Validator(lambda x: is_integer(x) and x >= 0)
44
61
 
45
62
 
46
63
  def is_in(vals):
47
- def validate(x):
48
- return x in vals
64
+ return Validator(vals.__contains__)
49
65
 
50
- return validate
66
+
67
+ def _is_valid_cache_path(path: str) -> bool:
68
+ """
69
+ path should look like oss://oss_endpoint/oss_bucket/path
70
+ """
71
+ parsed_url = urlparse(path)
72
+ return (
73
+ parsed_url.scheme == "oss"
74
+ and parsed_url.netloc
75
+ and parsed_url.path
76
+ and "/" in parsed_url.path
77
+ )
78
+
79
+
80
+ is_valid_cache_path = Validator(_is_valid_cache_path)
51
81
 
52
82
 
53
83
  _invalid_char_in_yaml_str = {'"', "'", "\n", "\\"}
maxframe/conftest.py CHANGED
@@ -14,10 +14,13 @@
14
14
 
15
15
  import faulthandler
16
16
  import os
17
- from configparser import ConfigParser, NoOptionError
17
+ from configparser import ConfigParser, NoOptionError, NoSectionError
18
18
 
19
19
  import pytest
20
20
  from odps import ODPS
21
+ from odps.accounts import BearerTokenAccount
22
+
23
+ from .config import options
21
24
 
22
25
  faulthandler.enable(all_threads=True)
23
26
  _test_conf_file_name = os.path.join(
@@ -32,12 +35,23 @@ def test_config():
32
35
  return config
33
36
 
34
37
 
35
- @pytest.fixture(scope="session", autouse=True)
36
- def odps_envs(test_config):
37
- access_id = test_config.get("odps", "access_id")
38
- secret_access_key = test_config.get("odps", "secret_access_key")
39
- project = test_config.get("odps", "project")
40
- endpoint = test_config.get("odps", "endpoint")
38
+ def _get_odps_env(test_config: ConfigParser, section_name: str) -> ODPS:
39
+ try:
40
+ access_id = test_config.get(section_name, "access_id")
41
+ except NoOptionError:
42
+ access_id = test_config.get("odps", "access_id")
43
+ try:
44
+ secret_access_key = test_config.get(section_name, "secret_access_key")
45
+ except NoOptionError:
46
+ secret_access_key = test_config.get("odps", "secret_access_key")
47
+ try:
48
+ project = test_config.get(section_name, "project")
49
+ except NoOptionError:
50
+ project = test_config.get("odps", "project")
51
+ try:
52
+ endpoint = test_config.get(section_name, "endpoint")
53
+ except NoOptionError:
54
+ endpoint = test_config.get("odps", "endpoint")
41
55
  try:
42
56
  tunnel_endpoint = test_config.get("odps", "tunnel_endpoint")
43
57
  except NoOptionError:
@@ -53,12 +67,31 @@ def odps_envs(test_config):
53
67
  ],
54
68
  }
55
69
  token = entry.get_project().generate_auth_token(policy, "bearer", 5)
70
+ return ODPS(
71
+ account=BearerTokenAccount(token, 5),
72
+ project=project,
73
+ endpoint=endpoint,
74
+ tunnel_endpoint=tunnel_endpoint,
75
+ )
56
76
 
57
- os.environ["ODPS_BEARER_TOKEN"] = token
58
- os.environ["ODPS_PROJECT_NAME"] = project
59
- os.environ["ODPS_ENDPOINT"] = endpoint
60
- if tunnel_endpoint:
61
- os.environ["ODPS_TUNNEL_ENDPOINT"] = tunnel_endpoint
77
+
78
+ @pytest.fixture(scope="session")
79
+ def odps_with_schema(test_config):
80
+ try:
81
+ return _get_odps_env(test_config, "odps_with_schema")
82
+ except NoSectionError:
83
+ pytest.skip("Need to specify odps_with_schema section in test.conf")
84
+
85
+
86
+ @pytest.fixture(scope="session", autouse=True)
87
+ def odps_envs(test_config):
88
+ entry = _get_odps_env(test_config, "odps")
89
+
90
+ os.environ["ODPS_BEARER_TOKEN"] = entry.account.token
91
+ os.environ["ODPS_PROJECT_NAME"] = entry.project
92
+ os.environ["ODPS_ENDPOINT"] = entry.endpoint
93
+ if entry.tunnel_endpoint:
94
+ os.environ["ODPS_TUNNEL_ENDPOINT"] = entry.tunnel_endpoint
62
95
 
63
96
  try:
64
97
  yield
@@ -77,11 +110,14 @@ def odps_envs(test_config):
77
110
  pass
78
111
 
79
112
 
80
- @pytest.fixture
113
+ @pytest.fixture(scope="session")
81
114
  def oss_config():
82
115
  config = ConfigParser()
83
116
  config.read(_test_conf_file_name)
84
117
 
118
+ old_role_arn = options.service_role_arn
119
+ old_cache_url = options.object_cache_url
120
+
85
121
  try:
86
122
  oss_access_id = config.get("oss", "access_id")
87
123
  oss_secret_access_key = config.get("oss", "secret_access_key")
@@ -89,6 +125,9 @@ def oss_config():
89
125
  oss_endpoint = config.get("oss", "endpoint")
90
126
  oss_rolearn = config.get("oss", "rolearn")
91
127
 
128
+ options.service_role_arn = oss_rolearn
129
+ options.object_cache_url = f"oss://{oss_endpoint}/{oss_bucket_name}"
130
+
92
131
  config.oss_config = (
93
132
  oss_access_id,
94
133
  oss_secret_access_key,
@@ -101,9 +140,12 @@ def oss_config():
101
140
  auth = oss2.Auth(oss_access_id, oss_secret_access_key)
102
141
  config.oss_bucket = oss2.Bucket(auth, oss_endpoint, oss_bucket_name)
103
142
  config.oss_rolearn = oss_rolearn
104
- return config
143
+ yield config
105
144
  except (ConfigParser.NoSectionError, ConfigParser.NoOptionError, ImportError):
106
145
  return None
146
+ finally:
147
+ options.service_role_arn = old_role_arn
148
+ options.object_cache_url = old_cache_url
107
149
 
108
150
 
109
151
  @pytest.fixture(autouse=True)
maxframe/core/__init__.py CHANGED
@@ -14,20 +14,14 @@
14
14
 
15
15
  # noinspection PyUnresolvedReferences
16
16
  from ..typing_ import ChunkType, EntityType, OperatorType, TileableType
17
- from .base import ExecutionError
17
+ from .base import Base, ExecutionError
18
18
  from .entity import (
19
- CHUNK_TYPE,
20
19
  ENTITY_TYPE,
21
- FUSE_CHUNK_TYPE,
22
20
  OBJECT_TYPE,
23
21
  TILEABLE_TYPE,
24
- Chunk,
25
- ChunkData,
26
22
  Entity,
27
23
  EntityData,
28
24
  ExecutableTuple,
29
- FuseChunk,
30
- FuseChunkData,
31
25
  HasShapeTileable,
32
26
  HasShapeTileableData,
33
27
  NotSupportTile,
@@ -40,23 +34,18 @@ from .entity import (
40
34
  get_fetch_class,
41
35
  get_output_types,
42
36
  get_tileable_types,
43
- register,
44
37
  register_fetch_class,
45
38
  register_output_types,
46
- unregister,
47
39
  )
48
40
 
49
41
  # noinspection PyUnresolvedReferences
50
42
  from .graph import (
51
43
  DAG,
52
- ChunkGraph,
53
- ChunkGraphBuilder,
54
44
  DirectedGraph,
55
45
  GraphContainsCycleError,
56
46
  GraphSerializer,
57
47
  TileableGraph,
58
48
  TileableGraphBuilder,
59
- TileContext,
60
- TileStatus,
61
49
  )
62
50
  from .mode import enter_mode, is_build_mode, is_eager_mode, is_kernel_mode
51
+ from .operator import build_fetch
@@ -12,10 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .chunks import CHUNK_TYPE, Chunk, ChunkData
16
15
  from .core import ENTITY_TYPE, Entity, EntityData
17
16
  from .executable import ExecutableTuple, _ExecuteAndFetchMixin
18
- from .fuse import FUSE_CHUNK_TYPE, FuseChunk, FuseChunkData
19
17
  from .objects import OBJECT_TYPE, Object, ObjectData
20
18
  from .output_types import (
21
19
  OutputType,
@@ -32,6 +30,4 @@ from .tileables import (
32
30
  NotSupportTile,
33
31
  Tileable,
34
32
  TileableData,
35
- register,
36
- unregister,
37
33
  )
@@ -46,7 +46,7 @@ class DecrefRunner:
46
46
  break
47
47
 
48
48
  session = session_ref()
49
- if session is None:
49
+ if session is None or session.closed:
50
50
  fut.set_result(None)
51
51
  continue
52
52
  try:
@@ -12,8 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict
15
+ from typing import Any, Dict, Type
16
16
 
17
+ from ...serialization import load_type
18
+ from ...serialization.serializables import StringField
17
19
  from .core import Entity
18
20
  from .executable import _ToObjectMixin
19
21
  from .tileables import TileableData
@@ -25,9 +27,42 @@ class ObjectData(TileableData, _ToObjectMixin):
25
27
  # workaround for removed field since v0.1.0b5
26
28
  # todo remove this when all versions below v1.0.0rc1 is eliminated
27
29
  _legacy_deprecated_non_primitives = ["_chunks"]
30
+ _legacy_new_non_primitives = ["object_class"]
31
+
32
+ object_class = StringField("object_class", default=None)
33
+
34
+ @classmethod
35
+ def get_entity_class(cls) -> Type["Object"]:
36
+ if getattr(cls, "_entity_class", None) is not None:
37
+ return cls._entity_class
38
+ assert cls.__qualname__[-4:] == "Data"
39
+ target_class_name = cls.__module__ + "#" + cls.__qualname__[:-4]
40
+ cls._entity_class = load_type(target_class_name, Object)
41
+ return cls._entity_class
42
+
43
+ def __new__(cls, op=None, nsplits=None, **kw):
44
+ if cls is ObjectData:
45
+ obj_cls = kw.get("object_class")
46
+ if isinstance(obj_cls, str):
47
+ obj_cls = load_type(obj_cls, (Object, ObjectData))
48
+ if isinstance(obj_cls, type) and issubclass(obj_cls, Object):
49
+ obj_cls = obj_cls.get_data_class()
50
+
51
+ if obj_cls is not None and cls is not obj_cls:
52
+ return obj_cls(op=op, nsplits=nsplits, **kw)
53
+ return super().__new__(cls)
28
54
 
29
55
  def __init__(self, op=None, nsplits=None, **kw):
56
+ obj_cls = kw.pop("object_class", None)
57
+ if isinstance(obj_cls, type):
58
+ if isinstance(obj_cls, type) and issubclass(obj_cls, Object):
59
+ obj_cls = obj_cls.get_data_class()
60
+ kw["object_class"] = obj_cls.__module__ + "#" + obj_cls.__qualname__
61
+
30
62
  super().__init__(_op=op, _nsplits=nsplits, **kw)
63
+ if self.object_class is None and type(self) is not ObjectData:
64
+ cls = type(self)
65
+ self.object_class = cls.__module__ + "#" + cls.__qualname__
31
66
 
32
67
  def __repr__(self):
33
68
  return f"Object <op={type(self.op).__name__}, key={self.key}>"
@@ -35,7 +70,7 @@ class ObjectData(TileableData, _ToObjectMixin):
35
70
  @property
36
71
  def params(self):
37
72
  # params return the properties which useful to rebuild a new tileable object
38
- return dict()
73
+ return dict(object_class=self.object_class)
39
74
 
40
75
  @params.setter
41
76
  def params(self, new_params: Dict[str, Any]):
@@ -54,5 +89,13 @@ class Object(Entity, _ToObjectMixin):
54
89
  _allow_data_type_ = (ObjectData,)
55
90
  type_name = "Object"
56
91
 
92
+ @classmethod
93
+ def get_data_class(cls) -> Type[ObjectData]:
94
+ if getattr(cls, "_data_class", None) is not None:
95
+ return cls._data_class
96
+ target_class_name = cls.__module__ + "#" + cls.__qualname__ + "Data"
97
+ cls._data_class = load_type(target_class_name, ObjectData)
98
+ return cls._data_class
99
+
57
100
 
58
101
  OBJECT_TYPE = (Object, ObjectData)
@@ -15,7 +15,6 @@
15
15
  import functools
16
16
  from enum import Enum
17
17
 
18
- from .fuse import FUSE_CHUNK_TYPE
19
18
  from .objects import OBJECT_TYPE
20
19
 
21
20
 
@@ -77,8 +76,6 @@ def get_output_types(*objs, unknown_as=None):
77
76
  for obj in objs:
78
77
  if obj is None:
79
78
  continue
80
- elif isinstance(obj, FUSE_CHUNK_TYPE):
81
- obj = obj.chunk
82
79
 
83
80
  try:
84
81
  output_types.append(_get_output_type_by_cls(type(obj)))
@@ -0,0 +1,43 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..objects import Object, ObjectData
16
+
17
+
18
+ class TestSubObjectData(ObjectData):
19
+ __test__ = False
20
+
21
+
22
+ class TestSubObject(Object):
23
+ __test__ = False
24
+
25
+
26
+ def test_object_init():
27
+ assert TestSubObjectData.get_entity_class() is TestSubObject
28
+
29
+ obj = ObjectData(
30
+ object_class=TestSubObjectData.__module__ + "#" + TestSubObjectData.__name__
31
+ )
32
+ assert isinstance(obj, TestSubObjectData)
33
+
34
+ obj = ObjectData(object_class=TestSubObjectData)
35
+ assert isinstance(obj, TestSubObjectData)
36
+
37
+ obj = ObjectData(
38
+ object_class=TestSubObject.__module__ + "#" + TestSubObject.__name__
39
+ )
40
+ assert isinstance(obj, TestSubObjectData)
41
+
42
+ obj = ObjectData(object_class=TestSubObject)
43
+ assert isinstance(obj, TestSubObjectData)
@@ -15,17 +15,15 @@
15
15
  import builtins
16
16
  import itertools
17
17
  from operator import attrgetter
18
- from typing import Callable, List
19
18
  from weakref import WeakKeyDictionary, WeakSet
20
19
 
21
20
  import numpy as np
22
21
 
23
22
  from ...serialization.serializables import BoolField, FieldTypes, TupleField
24
- from ...typing_ import OperatorType, TileableType
23
+ from ...typing_ import TileableType
25
24
  from ...utils import on_deserialize_shape, on_serialize_nsplits, on_serialize_shape
26
25
  from ..base import Base
27
26
  from ..mode import enter_mode
28
- from .chunks import Chunk
29
27
  from .core import Entity, EntityData
30
28
  from .executable import _ExecutableMixin
31
29
 
@@ -34,79 +32,6 @@ class NotSupportTile(Exception):
34
32
  pass
35
33
 
36
34
 
37
- class OperatorTilesHandler:
38
- _handlers = dict()
39
-
40
- @classmethod
41
- def _get_op_cls(cls, op: OperatorType):
42
- if isinstance(op, type):
43
- return op
44
- return type(op)
45
-
46
- @classmethod
47
- def register(
48
- cls, op: OperatorType, tile_handler: Callable[[OperatorType], TileableType]
49
- ):
50
- cls._handlers[cls._get_op_cls(op)] = tile_handler
51
-
52
- @classmethod
53
- def unregister(cls, op: OperatorType):
54
- del cls._handlers[cls._get_op_cls(op)]
55
-
56
- @classmethod
57
- def get_handler(
58
- cls, op: OperatorType
59
- ) -> Callable[[OperatorType], List[TileableType]]:
60
- op_cls = cls._get_op_cls(op)
61
- return cls._handlers.get(op_cls, op_cls.tile)
62
-
63
- @classmethod
64
- def _assign_to(
65
- cls,
66
- tile_after_tensor_datas: List["TileableData"],
67
- tile_before_tensor_datas: List["TileableData"],
68
- ):
69
- assert len(tile_after_tensor_datas) == len(tile_before_tensor_datas)
70
-
71
- for tile_after_tensor_data, tile_before_tensor_data in zip(
72
- tile_after_tensor_datas, tile_before_tensor_datas
73
- ):
74
- if tile_before_tensor_data is None:
75
- # garbage collected
76
- continue
77
- tile_after_tensor_data.copy_to(tile_before_tensor_data)
78
- tile_before_tensor_data.op.outputs = tile_before_tensor_datas
79
-
80
- @enter_mode(kernel=True)
81
- def dispatch(self, op: OperatorType):
82
- op_cls = self._get_op_cls(op)
83
- tiled = None
84
- cause = None
85
-
86
- if op_cls in self._handlers:
87
- tiled = self._handlers[op_cls](op)
88
- else:
89
- try:
90
- tiled = op_cls.tile(op)
91
- except NotImplementedError as ex:
92
- cause = ex
93
- for super_cls in op_cls.__mro__:
94
- if super_cls in self._handlers:
95
- h = self._handlers[op_cls] = self._handlers[super_cls]
96
- tiled = h(op)
97
- break
98
-
99
- if tiled is not None:
100
- return tiled if isinstance(tiled, list) else [tiled]
101
- else:
102
- raise NotImplementedError(f"{type(op)} does not support tile") from cause
103
-
104
-
105
- handler = OperatorTilesHandler()
106
- register = OperatorTilesHandler.register
107
- unregister = OperatorTilesHandler.unregister
108
-
109
-
110
35
  class _ChunksIndexer:
111
36
  __slots__ = ("_tileable",)
112
37
 
@@ -231,7 +156,7 @@ entity_view_handler = EntityDataModificationHandler()
231
156
 
232
157
 
233
158
  class TileableData(EntityData, _ExecutableMixin):
234
- __slots__ = "_cix", "_entities", "_executed_sessions"
159
+ __slots__ = "_chunks", "_cix", "_entities", "_executed_sessions"
235
160
  _no_copy_attrs_ = Base._no_copy_attrs_ | {"_cix"}
236
161
 
237
162
  # optional fields
@@ -245,6 +170,8 @@ class TileableData(EntityData, _ExecutableMixin):
245
170
  cache = BoolField("cache", default=False)
246
171
 
247
172
  def __init__(self: TileableType, *args, **kwargs):
173
+ if kwargs.get("chunks") is not None:
174
+ self._chunks = kwargs.pop("chunks")
248
175
  if kwargs.get("_nsplits", None) is not None:
249
176
  kwargs["_nsplits"] = tuple(tuple(s) for s in kwargs["_nsplits"])
250
177
 
@@ -270,7 +197,7 @@ class TileableData(EntityData, _ExecutableMixin):
270
197
  return tuple(map(len, self._nsplits))
271
198
 
272
199
  @property
273
- def chunks(self) -> List[Chunk]:
200
+ def chunks(self) -> list:
274
201
  return getattr(self, "_chunks", None)
275
202
 
276
203
  @property
@@ -12,6 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .builder import ChunkGraphBuilder, TileableGraphBuilder, TileContext, TileStatus
15
+ from .builder import TileableGraphBuilder
16
16
  from .core import DAG, DirectedGraph, GraphContainsCycleError
17
- from .entity import ChunkGraph, EntityGraph, GraphSerializer, TileableGraph
17
+ from .entity import EntityGraph, GraphSerializer, TileableGraph
@@ -12,5 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .chunk import ChunkGraphBuilder, TileContext, TileStatus
16
15
  from .tileable import TileableGraphBuilder