maxframe 1.0.0rc2__cp310-cp310-win_amd64.whl → 1.0.0rc3__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. Click here for more details.

Files changed (106) hide show
  1. maxframe/_utils.cp310-win_amd64.pyd +0 -0
  2. maxframe/codegen.py +3 -2
  3. maxframe/config/config.py +16 -9
  4. maxframe/config/validators.py +42 -12
  5. maxframe/conftest.py +13 -2
  6. maxframe/core/__init__.py +2 -13
  7. maxframe/core/entity/__init__.py +0 -4
  8. maxframe/core/entity/objects.py +45 -2
  9. maxframe/core/entity/output_types.py +0 -3
  10. maxframe/core/entity/tests/test_objects.py +43 -0
  11. maxframe/core/entity/tileables.py +5 -78
  12. maxframe/core/graph/__init__.py +2 -2
  13. maxframe/core/graph/builder/__init__.py +0 -1
  14. maxframe/core/graph/builder/base.py +5 -4
  15. maxframe/core/graph/builder/tileable.py +4 -4
  16. maxframe/core/graph/builder/utils.py +4 -8
  17. maxframe/core/graph/core.cp310-win_amd64.pyd +0 -0
  18. maxframe/core/graph/entity.py +9 -33
  19. maxframe/core/operator/__init__.py +2 -9
  20. maxframe/core/operator/base.py +3 -5
  21. maxframe/core/operator/objects.py +0 -9
  22. maxframe/core/operator/utils.py +55 -0
  23. maxframe/dataframe/datasource/read_odps_query.py +1 -1
  24. maxframe/dataframe/datasource/read_odps_table.py +1 -1
  25. maxframe/dataframe/datastore/to_odps.py +1 -1
  26. maxframe/dataframe/operators.py +1 -17
  27. maxframe/dataframe/reduction/core.py +2 -2
  28. maxframe/io/objects/__init__.py +24 -0
  29. maxframe/io/objects/core.py +140 -0
  30. maxframe/io/objects/tensor.py +76 -0
  31. maxframe/io/objects/tests/__init__.py +13 -0
  32. maxframe/io/objects/tests/test_object_io.py +97 -0
  33. maxframe/{odpsio → io/odpsio}/__init__.py +2 -0
  34. maxframe/{odpsio → io/odpsio}/arrow.py +4 -4
  35. maxframe/{odpsio → io/odpsio}/schema.py +5 -5
  36. maxframe/{odpsio → io/odpsio}/tableio.py +10 -4
  37. maxframe/io/odpsio/tests/__init__.py +13 -0
  38. maxframe/{odpsio → io/odpsio}/tests/test_schema.py +3 -3
  39. maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +3 -3
  40. maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
  41. maxframe/io/odpsio/volumeio.py +57 -0
  42. maxframe/learn/contrib/xgboost/classifier.py +26 -2
  43. maxframe/learn/contrib/xgboost/core.py +87 -2
  44. maxframe/learn/contrib/xgboost/dmatrix.py +1 -4
  45. maxframe/learn/contrib/xgboost/predict.py +19 -5
  46. maxframe/learn/contrib/xgboost/regressor.py +3 -10
  47. maxframe/learn/contrib/xgboost/train.py +25 -15
  48. maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
  49. maxframe/lib/mmh3.cp310-win_amd64.pyd +0 -0
  50. maxframe/protocol.py +1 -15
  51. maxframe/remote/core.py +4 -8
  52. maxframe/serialization/__init__.py +1 -0
  53. maxframe/serialization/core.cp310-win_amd64.pyd +0 -0
  54. maxframe/tensor/__init__.py +10 -2
  55. maxframe/tensor/arithmetic/isclose.py +1 -0
  56. maxframe/tensor/arithmetic/tests/test_arithmetic.py +21 -17
  57. maxframe/tensor/core.py +5 -136
  58. maxframe/tensor/datasource/array.py +3 -0
  59. maxframe/tensor/datasource/full.py +1 -1
  60. maxframe/tensor/datasource/tests/test_datasource.py +1 -1
  61. maxframe/tensor/indexing/flatnonzero.py +1 -1
  62. maxframe/tensor/merge/__init__.py +2 -0
  63. maxframe/tensor/merge/concatenate.py +98 -0
  64. maxframe/tensor/merge/tests/test_merge.py +30 -1
  65. maxframe/tensor/merge/vstack.py +70 -0
  66. maxframe/tensor/{base → misc}/__init__.py +2 -0
  67. maxframe/tensor/{base → misc}/atleast_1d.py +0 -2
  68. maxframe/tensor/misc/atleast_2d.py +70 -0
  69. maxframe/tensor/misc/atleast_3d.py +85 -0
  70. maxframe/tensor/misc/tests/__init__.py +13 -0
  71. maxframe/tensor/{base → misc}/transpose.py +22 -18
  72. maxframe/tensor/operators.py +1 -7
  73. maxframe/tensor/random/core.py +1 -1
  74. maxframe/tensor/reduction/count_nonzero.py +1 -0
  75. maxframe/tensor/reduction/mean.py +1 -0
  76. maxframe/tensor/reduction/nanmean.py +1 -0
  77. maxframe/tensor/reduction/nanvar.py +2 -0
  78. maxframe/tensor/reduction/tests/test_reduction.py +12 -1
  79. maxframe/tensor/reduction/var.py +2 -0
  80. maxframe/tensor/utils.py +2 -22
  81. maxframe/typing_.py +4 -1
  82. maxframe/udf.py +8 -9
  83. maxframe/utils.py +15 -61
  84. maxframe-1.0.0rc3.dist-info/METADATA +104 -0
  85. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc3.dist-info}/RECORD +101 -91
  86. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc3.dist-info}/WHEEL +1 -1
  87. maxframe_client/fetcher.py +23 -42
  88. maxframe_client/session/graph.py +8 -2
  89. maxframe_client/session/odps.py +54 -18
  90. maxframe_client/tests/test_fetcher.py +1 -1
  91. maxframe_client/tests/test_session.py +14 -2
  92. maxframe/core/entity/chunks.py +0 -68
  93. maxframe/core/entity/fuse.py +0 -73
  94. maxframe/core/graph/builder/chunk.py +0 -430
  95. maxframe/odpsio/volumeio.py +0 -95
  96. maxframe-1.0.0rc2.dist-info/METADATA +0 -177
  97. /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
  98. /maxframe/{tensor/base/tests → io}/__init__.py +0 -0
  99. /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
  100. /maxframe/tensor/{base → misc}/astype.py +0 -0
  101. /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
  102. /maxframe/tensor/{base → misc}/ravel.py +0 -0
  103. /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
  104. /maxframe/tensor/{base → misc}/unique.py +0 -0
  105. /maxframe/tensor/{base → misc}/where.py +0 -0
  106. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc3.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ from odps import ODPS
22
22
 
23
23
  import maxframe.dataframe as md
24
24
  from maxframe.config import options
25
- from maxframe.odpsio import ODPSTableIO
25
+ from maxframe.io.odpsio import ODPSTableIO
26
26
  from maxframe.protocol import ODPSTableResultInfo, ResultType
27
27
  from maxframe.tests.utils import tn
28
28
 
@@ -247,7 +247,19 @@ def test_run_and_fetch_series(start_mock_session):
247
247
  )
248
248
 
249
249
 
250
- def test_run_remote_success(start_mock_session):
250
+ def test_execute_with_tensor(oss_config, start_mock_session):
251
+ pd_df = pd.DataFrame(
252
+ {"angles": [0, 3, 4], "degrees": [360, 180, 360]},
253
+ index=["circle", "triangle", "rectangle"],
254
+ )
255
+ df = md.DataFrame(pd_df)
256
+
257
+ result = (df - [1, 2]).execute().fetch()
258
+ expected = pd_df - [1, 2]
259
+ pd.testing.assert_frame_equal(result, expected)
260
+
261
+
262
+ def test_run_remote_success(oss_config, start_mock_session):
251
263
  def func(a, b):
252
264
  return a + b
253
265
 
@@ -258,7 +270,7 @@ def test_run_remote_success(start_mock_session):
258
270
  assert result == 21
259
271
 
260
272
 
261
- def test_run_remote_error(start_mock_session):
273
+ def test_run_remote_error(oss_config, start_mock_session):
262
274
  def func():
263
275
  raise ValueError
264
276
 
@@ -1,68 +0,0 @@
1
- # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from ...serialization.serializables import BoolField, FieldTypes, TupleField
16
- from ...utils import tokenize
17
- from .core import Entity, EntityData
18
-
19
-
20
- class ChunkData(EntityData):
21
- __slots__ = ()
22
-
23
- is_broadcaster = BoolField("is_broadcaster", default=False)
24
- # If the operator is a shuffle mapper, this flag indicates whether the current chunk is mapper chunk when
25
- # the operator produce multiple chunks such as TensorUnique.
26
- is_mapper = BoolField("is_mapper", default=None)
27
- # optional fields
28
- _index = TupleField("index", FieldTypes.uint32)
29
-
30
- def __repr__(self):
31
- if self.op.stage is None:
32
- return (
33
- f"{type(self).__name__} <op={type(self.op).__name__}, "
34
- f"key={self.key}>"
35
- )
36
- else:
37
- return (
38
- f"{type(self).__name__} <op={type(self.op).__name__}, "
39
- f"stage={self.op.stage.name}, key={self.key}>"
40
- )
41
-
42
- @property
43
- def index(self):
44
- return getattr(self, "_index", None)
45
-
46
- @property
47
- def device(self):
48
- return self.op.device
49
-
50
- def _update_key(self):
51
- object.__setattr__(
52
- self,
53
- "_key",
54
- tokenize(
55
- type(self).__name__,
56
- *(getattr(self, k, None) for k in self._keys_ if k != "_index"),
57
- ),
58
- )
59
-
60
-
61
- class Chunk(Entity):
62
- _allow_data_type_ = (ChunkData,)
63
-
64
- def __repr__(self):
65
- return f"{type(self).__name__}({self._data.__repr__()})"
66
-
67
-
68
- CHUNK_TYPE = (Chunk, ChunkData)
@@ -1,73 +0,0 @@
1
- # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import numpy as np
16
-
17
- from ...serialization.serializables import ReferenceField
18
- from .chunks import CHUNK_TYPE, Chunk, ChunkData
19
-
20
-
21
- class FuseChunkData(ChunkData):
22
- __slots__ = ("_inited",)
23
-
24
- _chunk = ReferenceField(
25
- "chunk", CHUNK_TYPE, on_serialize=lambda x: x.data if hasattr(x, "data") else x
26
- )
27
-
28
- def __init__(self, *args, **kwargs):
29
- self._inited = False
30
- super().__init__(*args, **kwargs)
31
- self._extra_params = {}
32
- self._inited = True
33
-
34
- @property
35
- def chunk(self):
36
- return self._chunk
37
-
38
- @property
39
- def composed(self):
40
- # for compatibility, just return the topological ordering,
41
- # once we apply optimization on the subgraph,
42
- # `composed` is not needed any more and should be removed then.
43
- assert getattr(self._op, "fuse_graph", None) is not None
44
- fuse_graph = self._op.fuse_graph
45
- return list(fuse_graph.topological_iter())
46
-
47
- def __getattr__(self, attr):
48
- if not self._inited:
49
- return object.__getattribute__(self, attr)
50
- if attr in self._extra_params:
51
- return self._extra_params[attr]
52
- try:
53
- return getattr(self._chunk, attr)
54
- except AttributeError:
55
- return object.__getattribute__(self, attr)
56
-
57
- def __setattr__(self, attr, value):
58
- if attr == "params":
59
- self._chunk.params = value
60
- else:
61
- super().__setattr__(attr, value)
62
-
63
- @property
64
- def nbytes(self):
65
- return np.prod(self.shape) * self.dtype.itemsize
66
-
67
-
68
- class FuseChunk(Chunk):
69
- __slots__ = ()
70
- _allow_data_type_ = (FuseChunkData,)
71
-
72
-
73
- FUSE_CHUNK_TYPE = (FuseChunkData, FuseChunk)
@@ -1,430 +0,0 @@
1
- # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import dataclasses
16
- import functools
17
- from typing import Callable, Dict, Generator, Iterable, List, Optional, Set, Type, Union
18
-
19
- from ....core import CHUNK_TYPE, FUSE_CHUNK_TYPE, TILEABLE_TYPE
20
- from ....typing_ import ChunkType, EntityType, TileableType
21
- from ....utils import build_fetch, copy_tileables
22
- from ...entity.tileables import handler
23
- from ...mode import enter_mode
24
- from ..entity import ChunkGraph, TileableGraph
25
- from .base import AbstractGraphBuilder
26
-
27
- tile_gen_type = Generator[List[ChunkType], List[ChunkType], List[TileableType]]
28
- DEFAULT_UPDATED_PROGRESS = 0.4
29
-
30
-
31
- @dataclasses.dataclass
32
- class _TileableHandler:
33
- tileable: TileableType
34
- handler: tile_gen_type
35
- last_need_processes: List[EntityType] = None
36
-
37
-
38
- @dataclasses.dataclass
39
- class _TileableTileInfo:
40
- curr_iter: int
41
- # incremental progress for this iteration
42
- tile_progress: float
43
- # newly generated chunks by a tileable in this iteration
44
- generated_chunks: List[ChunkType] = dataclasses.field(default_factory=list)
45
-
46
-
47
- class TileContext(Dict[TileableType, TileableType]):
48
- _tileables = Set[TileableType]
49
- _tileable_to_progress: Dict[TileableType, float]
50
- _tileable_to_tile_infos: Dict[TileableType, List[_TileableTileInfo]]
51
-
52
- def __init__(self, *args, **kw):
53
- super().__init__(*args, **kw)
54
- self._tileables = None
55
- self._tileable_to_progress = dict()
56
- self._tileable_to_tile_infos = dict()
57
-
58
- def set_tileables(self, tileables: Set[TileableType]):
59
- self._tileables = tileables
60
-
61
- def __setitem__(self, key, value):
62
- self._tileable_to_progress.pop(key, None)
63
- return super().__setitem__(key, value)
64
-
65
- def set_progress(self, tileable: TileableType, progress: float):
66
- assert 0.0 <= progress <= 1.0
67
- last_progress = self._tileable_to_progress.get(tileable, 0.0)
68
- self._tileable_to_progress[tileable] = max(progress, last_progress)
69
-
70
- def get_progress(self, tileable: TileableType) -> float:
71
- if tileable in self:
72
- return 1.0
73
- else:
74
- return self._tileable_to_progress.get(tileable, 0.0)
75
-
76
- def get_all_progress(self) -> float:
77
- return sum(self.get_progress(t) for t in self._tileables) / len(self._tileables)
78
-
79
- def record_tileable_tile_info(
80
- self, tileable: TileableType, curr_iter: int, generated_chunks: List[ChunkType]
81
- ):
82
- if tileable not in self._tileable_to_tile_infos:
83
- self._tileable_to_tile_infos[tileable] = []
84
- prev_progress = sum(
85
- info.tile_progress for info in self._tileable_to_tile_infos[tileable]
86
- )
87
- curr_progress = self.get_progress(tileable)
88
- infos = self._tileable_to_tile_infos[tileable]
89
- infos.append(
90
- _TileableTileInfo(
91
- curr_iter=curr_iter,
92
- tile_progress=curr_progress - prev_progress,
93
- generated_chunks=generated_chunks,
94
- )
95
- )
96
-
97
- def get_tileable_tile_infos(self) -> Dict[TileableType, List[_TileableTileInfo]]:
98
- return {t: self._tileable_to_tile_infos.get(t, list()) for t in self._tileables}
99
-
100
-
101
- @dataclasses.dataclass
102
- class TileStatus:
103
- entities: List[EntityType] = None
104
- progress: float = None
105
-
106
-
107
- class Tiler:
108
- _cur_iter: int
109
- _cur_chunk_graph: Optional[ChunkGraph]
110
- _tileable_handlers: Iterable[_TileableHandler]
111
-
112
- def __init__(
113
- self,
114
- tileable_graph: TileableGraph,
115
- tile_context: TileContext,
116
- processed_chunks: Set[str],
117
- chunk_to_fetch: Dict[ChunkType, ChunkType],
118
- add_nodes: Callable,
119
- ):
120
- self._tileable_graph = tileable_graph
121
- self._tile_context = tile_context
122
- self._processed_chunks = processed_chunks
123
- self._chunk_to_fetch = chunk_to_fetch
124
- self._add_nodes = self._wrap_add_nodes(add_nodes)
125
- self._curr_iter = 0
126
- self._cur_chunk_graph = None
127
- self._tileable_handlers = (
128
- _TileableHandler(tileable, self._tile_handler(tileable))
129
- for tileable in tileable_graph.topological_iter()
130
- )
131
-
132
- def _wrap_add_nodes(self, add_nodes: Callable):
133
- @functools.wraps(add_nodes)
134
- def inner(
135
- chunk_graph: ChunkGraph,
136
- chunks: List[ChunkType],
137
- visited: Set[ChunkType],
138
- tileable: TileableType,
139
- ):
140
- prev_chunks = set(chunk_graph)
141
- add_nodes(chunk_graph, chunks, visited)
142
- new_chunks = set(chunk_graph)
143
- self._tile_context.record_tileable_tile_info(
144
- tileable, self._curr_iter, list(new_chunks - prev_chunks)
145
- )
146
-
147
- return inner
148
-
149
- @staticmethod
150
- def _get_data(entity: EntityType):
151
- return entity.data if hasattr(entity, "data") else entity
152
-
153
- def _tile_handler(self, tileable: TileableType) -> tile_gen_type:
154
- from ....core.operator import Fetch
155
-
156
- tileable = self._get_data(tileable)
157
-
158
- if isinstance(tileable.op, Fetch) and not tileable.is_coarse():
159
- return [tileable]
160
-
161
- assert tileable.is_coarse()
162
-
163
- # copy tileable
164
- tiled_tileables = copy_tileables(
165
- tileable.op.outputs,
166
- inputs=[self._tile_context[inp] for inp in tileable.inputs],
167
- copy_key=True,
168
- copy_id=False,
169
- )
170
- tiled_tileables = [self._get_data(t) for t in tiled_tileables]
171
- # start to tile
172
- tiled_tileables = yield from handler.tile(tiled_tileables)
173
- return tiled_tileables
174
-
175
- def _gen_tileable_handlers(self, next_tileable_handlers: List[_TileableHandler]):
176
- for tile_handler in self._tileable_handlers:
177
- tileable, handler = tile_handler.tileable, tile_handler.handler
178
- if tileable in self._tile_context:
179
- continue
180
- if any(
181
- inp not in self._tile_context
182
- for inp in self._tileable_graph.predecessors(tileable)
183
- ):
184
- # predecessors not finished yet
185
- next_tileable_handlers.append(_TileableHandler(tileable, handler))
186
- continue
187
-
188
- yield _TileableHandler(tileable, handler)
189
-
190
- def _tile(
191
- self,
192
- chunk_graph: ChunkGraph,
193
- tileable: TileableType,
194
- tile_handler: tile_gen_type,
195
- next_tileable_handlers: List[_TileableHandler],
196
- to_update_tileables: List[TileableType],
197
- visited: Set[EntityType],
198
- ):
199
- try:
200
- need_process = next(tile_handler)
201
-
202
- if isinstance(need_process, TileStatus):
203
- # process tile that returns progress
204
- self._tile_context.set_progress(tileable, need_process.progress)
205
- need_process = need_process.entities
206
- else:
207
- # if progress not specified, we just update 0.4 * rest progress
208
- progress = self._tile_context.get_progress(tileable)
209
- new_progress = progress + (1.0 - progress) * DEFAULT_UPDATED_PROGRESS
210
- self._tile_context.set_progress(tileable, new_progress)
211
-
212
- chunks = []
213
- if need_process is not None:
214
- for t in need_process:
215
- if isinstance(t, CHUNK_TYPE):
216
- chunks.append(self._get_data(t))
217
- elif isinstance(t, TILEABLE_TYPE):
218
- to_update_tileables.append(self._get_data(t))
219
- # not finished yet
220
- self._add_nodes(chunk_graph, chunks.copy(), visited, tileable)
221
- next_tileable_handlers.append(
222
- _TileableHandler(tileable, tile_handler, need_process)
223
- )
224
- # add intermediate chunks into result chunks
225
- # to prevent them being pruned
226
- chunk_graph.result_chunks.extend(c for c in chunks if c in chunk_graph)
227
- except StopIteration as e:
228
- # tile done
229
- tiled_tileables = e.value
230
- for out, tiled_tileable in zip(tileable.op.outputs, tiled_tileables):
231
- out = self._get_data(out)
232
- tiled_tileable = self._get_data(tiled_tileable)
233
-
234
- chunks = tiled_tileable.chunks
235
- if chunks is None: # pragma: no cover
236
- raise ValueError(f"tileable({out}) is still coarse after tile")
237
- chunks = [self._get_data(c) for c in chunks]
238
- self._tile_context[out] = tiled_tileable
239
- self._add_nodes(chunk_graph, chunks, visited, tileable)
240
-
241
- def _gen_result_chunks(
242
- self,
243
- chunk_graph: ChunkGraph,
244
- next_tileable_handlers: List[_TileableHandler],
245
- ):
246
- result_chunks = chunk_graph.result_chunks
247
- tileable_graph = self._tileable_graph
248
- result_chunk_set = set(result_chunks)
249
-
250
- def _add_result_chunk(c):
251
- if c not in result_chunk_set:
252
- result_chunks.append(c)
253
- result_chunk_set.add(c)
254
-
255
- if next_tileable_handlers:
256
- for tileable_handler in next_tileable_handlers:
257
- tileable = tileable_handler.tileable
258
- # tileable that tile not completed, scan their inputs
259
- for inp_tileable in tileable_graph.iter_predecessors(tileable):
260
- if (
261
- tileable_handler.last_need_processes is None
262
- or tileable_graph.count_successors(inp_tileable) > 1
263
- ):
264
- # if nothing yielded inside its tile,
265
- # or the input has more than 1 successors,
266
- # make sure their chunks in result,
267
- # so that they will not be executed repeatedly
268
- if inp_tileable in self._tile_context:
269
- for chunk in self._tile_context[inp_tileable].chunks:
270
- chunk = self._get_data(chunk)
271
- if chunk in chunk_graph:
272
- _add_result_chunk(chunk)
273
- for tileable in tileable_graph.result_tileables:
274
- if tileable in self._tile_context:
275
- for chunk in self._tile_context[tileable].chunks:
276
- chunk = self._get_data(chunk)
277
- if chunk in chunk_graph:
278
- _add_result_chunk(chunk)
279
- if (
280
- chunk in self._chunk_to_fetch
281
- and self._chunk_to_fetch[chunk] in chunk_graph
282
- ):
283
- _add_result_chunk(self._chunk_to_fetch[chunk])
284
-
285
- def _iter(self):
286
- chunk_graph = self._cur_chunk_graph
287
-
288
- to_update_tileables = []
289
- visited = set()
290
-
291
- if chunk_graph is not None:
292
- # last tiled chunks, add them to processed
293
- # so that fetch chunk can be generated.
294
- # Use chunk key as the key to make sure the copied chunk can be build to a fetch.
295
- processed_chunks = (
296
- c.chunk.key if isinstance(c, FUSE_CHUNK_TYPE) else c.key
297
- for c in chunk_graph.result_chunks
298
- )
299
- self._processed_chunks.update(processed_chunks)
300
-
301
- result_chunks = []
302
- chunk_graph = self._cur_chunk_graph = ChunkGraph(result_chunks)
303
-
304
- next_tileable_handlers = []
305
- # tile
306
- for tile_handler in self._gen_tileable_handlers(next_tileable_handlers):
307
- self._tile(
308
- chunk_graph,
309
- tile_handler.tileable,
310
- tile_handler.handler,
311
- next_tileable_handlers,
312
- to_update_tileables,
313
- visited,
314
- )
315
- self._tileable_handlers = next_tileable_handlers
316
- # gen result chunks
317
- self._gen_result_chunks(chunk_graph, next_tileable_handlers)
318
- # prune unused chunks
319
- prune_chunk_graph(chunk_graph)
320
-
321
- self._curr_iter += 1
322
-
323
- return to_update_tileables
324
-
325
- def __iter__(self):
326
- while self._tileable_handlers:
327
- to_update_tileables = self._iter()
328
- yield self._cur_chunk_graph
329
- for t in to_update_tileables:
330
- t.refresh_params()
331
-
332
-
333
- def prune_chunk_graph(chunk_graph: ChunkGraph):
334
- from ....core.operator import Fetch, ShuffleProxy, VirtualOperator
335
-
336
- result_set = set(chunk_graph.result_chunks)
337
- stack = list(chunk_graph.result_chunks)
338
- used = set()
339
- while stack:
340
- n = stack.pop()
341
- if n in used:
342
- continue
343
- used.add(n)
344
- stack.extend(chunk_graph.predecessors(n))
345
- if isinstance(n.op, ShuffleProxy):
346
- stack.extend(
347
- succ for succ in chunk_graph.iter_successors(n) if succ not in used
348
- )
349
-
350
- unused = {n for n in chunk_graph if n not in used}
351
- for n in unused:
352
- # for pruned chunks, we assume we will use them later,
353
- # so we add the inputs of them into result chunks,
354
- # to prevent from duplicated submission
355
- for inp in chunk_graph.iter_predecessors(n):
356
- if (
357
- inp in used
358
- and inp not in result_set
359
- and not isinstance(inp.op, (Fetch, VirtualOperator))
360
- ):
361
- chunk_graph.result_chunks.append(inp)
362
- result_set.add(inp)
363
- # prune chunk
364
- chunk_graph.remove_node(n)
365
-
366
-
367
- class ChunkGraphBuilder(AbstractGraphBuilder):
368
- _graph: TileableGraph
369
-
370
- def __init__(
371
- self,
372
- graph: TileableGraph,
373
- fuse_enabled: bool = True,
374
- tile_context: TileContext = None,
375
- tiler_cls: Union[Type[Tiler], Callable] = None,
376
- ):
377
- super().__init__(graph)
378
- self.fuse_enabled = fuse_enabled
379
- self.tile_context = TileContext() if tile_context is None else tile_context
380
- self.tile_context.set_tileables(set(graph))
381
-
382
- self._processed_chunks: Set[str] = set()
383
- self._chunk_to_fetch: Dict[ChunkType, ChunkType] = dict()
384
-
385
- tiler_cls = Tiler if tiler_cls is None else tiler_cls
386
- self.tiler = tiler_cls(
387
- self._graph,
388
- self.tile_context,
389
- self._processed_chunks,
390
- self._chunk_to_fetch,
391
- self._add_nodes,
392
- )
393
-
394
- def _process_node(self, entity: EntityType):
395
- if entity.key in self._processed_chunks:
396
- if entity not in self._chunk_to_fetch:
397
- # gen fetch
398
- fetch_chunk = build_fetch(entity).data
399
- self._chunk_to_fetch[entity] = fetch_chunk
400
- return self._chunk_to_fetch[entity]
401
- return entity
402
-
403
- def _select_inputs(self, inputs: List[ChunkType]):
404
- new_inputs = []
405
- for inp in inputs:
406
- if inp.key in self._processed_chunks:
407
- # gen fetch
408
- if inp not in self._chunk_to_fetch:
409
- fetch_chunk = build_fetch(inp).data
410
- self._chunk_to_fetch[inp] = fetch_chunk
411
- new_inputs.append(self._chunk_to_fetch[inp])
412
- else:
413
- new_inputs.append(inp)
414
- return new_inputs
415
-
416
- def _if_add_node(self, node: EntityType, visited: Set):
417
- return node not in visited and node.key not in self._processed_chunks
418
-
419
- def _build(self) -> Iterable[Union[TileableGraph, ChunkGraph]]:
420
- tile_iterator = iter(self.tiler)
421
- while True:
422
- try:
423
- with enter_mode(build=True, kernel=True):
424
- graph = next(tile_iterator)
425
- yield graph
426
- except StopIteration:
427
- break
428
-
429
- def build(self) -> Generator[Union[TileableGraph, ChunkGraph], None, None]:
430
- yield from self._build()