maxframe-0.1.0b5-cp39-cp39-win32.whl → maxframe-1.0.0-cp39-cp39-win32.whl

This diff shows the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

This version of maxframe has been flagged as potentially problematic.

Files changed (203)
  1. maxframe/_utils.cp39-win32.pyd +0 -0
  2. maxframe/codegen.py +10 -4
  3. maxframe/config/config.py +68 -10
  4. maxframe/config/validators.py +42 -11
  5. maxframe/conftest.py +58 -14
  6. maxframe/core/__init__.py +2 -16
  7. maxframe/core/entity/__init__.py +1 -12
  8. maxframe/core/entity/executable.py +1 -1
  9. maxframe/core/entity/objects.py +46 -45
  10. maxframe/core/entity/output_types.py +0 -3
  11. maxframe/core/entity/tests/test_objects.py +43 -0
  12. maxframe/core/entity/tileables.py +5 -78
  13. maxframe/core/graph/__init__.py +2 -2
  14. maxframe/core/graph/builder/__init__.py +0 -1
  15. maxframe/core/graph/builder/base.py +5 -4
  16. maxframe/core/graph/builder/tileable.py +4 -4
  17. maxframe/core/graph/builder/utils.py +4 -8
  18. maxframe/core/graph/core.cp39-win32.pyd +0 -0
  19. maxframe/core/graph/core.pyx +4 -4
  20. maxframe/core/graph/entity.py +9 -33
  21. maxframe/core/operator/__init__.py +2 -9
  22. maxframe/core/operator/base.py +3 -5
  23. maxframe/core/operator/objects.py +0 -9
  24. maxframe/core/operator/utils.py +55 -0
  25. maxframe/dataframe/__init__.py +1 -1
  26. maxframe/dataframe/arithmetic/around.py +5 -17
  27. maxframe/dataframe/arithmetic/core.py +15 -7
  28. maxframe/dataframe/arithmetic/docstring.py +7 -33
  29. maxframe/dataframe/arithmetic/equal.py +4 -2
  30. maxframe/dataframe/arithmetic/greater.py +4 -2
  31. maxframe/dataframe/arithmetic/greater_equal.py +4 -2
  32. maxframe/dataframe/arithmetic/less.py +2 -2
  33. maxframe/dataframe/arithmetic/less_equal.py +4 -2
  34. maxframe/dataframe/arithmetic/not_equal.py +4 -2
  35. maxframe/dataframe/arithmetic/tests/test_arithmetic.py +39 -16
  36. maxframe/dataframe/core.py +31 -7
  37. maxframe/dataframe/datasource/date_range.py +2 -2
  38. maxframe/dataframe/datasource/read_odps_query.py +117 -23
  39. maxframe/dataframe/datasource/read_odps_table.py +6 -3
  40. maxframe/dataframe/datasource/tests/test_datasource.py +103 -8
  41. maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
  42. maxframe/dataframe/datastore/to_odps.py +28 -0
  43. maxframe/dataframe/extensions/__init__.py +5 -0
  44. maxframe/dataframe/extensions/flatjson.py +131 -0
  45. maxframe/dataframe/extensions/flatmap.py +317 -0
  46. maxframe/dataframe/extensions/reshuffle.py +1 -1
  47. maxframe/dataframe/extensions/tests/test_extensions.py +108 -3
  48. maxframe/dataframe/groupby/core.py +1 -1
  49. maxframe/dataframe/groupby/cum.py +0 -1
  50. maxframe/dataframe/groupby/fill.py +4 -1
  51. maxframe/dataframe/groupby/getitem.py +6 -0
  52. maxframe/dataframe/groupby/tests/test_groupby.py +5 -1
  53. maxframe/dataframe/groupby/transform.py +5 -1
  54. maxframe/dataframe/indexing/align.py +1 -1
  55. maxframe/dataframe/indexing/loc.py +6 -4
  56. maxframe/dataframe/indexing/rename.py +5 -28
  57. maxframe/dataframe/indexing/sample.py +0 -1
  58. maxframe/dataframe/indexing/set_index.py +68 -1
  59. maxframe/dataframe/initializer.py +11 -1
  60. maxframe/dataframe/merge/__init__.py +9 -1
  61. maxframe/dataframe/merge/concat.py +41 -31
  62. maxframe/dataframe/merge/merge.py +237 -3
  63. maxframe/dataframe/merge/tests/test_merge.py +126 -1
  64. maxframe/dataframe/misc/apply.py +5 -10
  65. maxframe/dataframe/misc/case_when.py +1 -1
  66. maxframe/dataframe/misc/describe.py +2 -2
  67. maxframe/dataframe/misc/drop_duplicates.py +8 -8
  68. maxframe/dataframe/misc/eval.py +4 -0
  69. maxframe/dataframe/misc/memory_usage.py +2 -2
  70. maxframe/dataframe/misc/pct_change.py +1 -83
  71. maxframe/dataframe/misc/tests/test_misc.py +33 -2
  72. maxframe/dataframe/misc/transform.py +1 -30
  73. maxframe/dataframe/misc/value_counts.py +4 -17
  74. maxframe/dataframe/missing/dropna.py +1 -1
  75. maxframe/dataframe/missing/fillna.py +5 -5
  76. maxframe/dataframe/operators.py +1 -17
  77. maxframe/dataframe/reduction/core.py +2 -2
  78. maxframe/dataframe/reduction/tests/test_reduction.py +2 -4
  79. maxframe/dataframe/sort/sort_values.py +1 -11
  80. maxframe/dataframe/statistics/corr.py +3 -3
  81. maxframe/dataframe/statistics/quantile.py +13 -19
  82. maxframe/dataframe/statistics/tests/test_statistics.py +4 -4
  83. maxframe/dataframe/tests/test_initializer.py +33 -2
  84. maxframe/dataframe/utils.py +26 -11
  85. maxframe/dataframe/window/expanding.py +5 -3
  86. maxframe/dataframe/window/tests/test_expanding.py +2 -2
  87. maxframe/errors.py +13 -0
  88. maxframe/extension.py +12 -0
  89. maxframe/io/__init__.py +13 -0
  90. maxframe/io/objects/__init__.py +24 -0
  91. maxframe/io/objects/core.py +140 -0
  92. maxframe/io/objects/tensor.py +76 -0
  93. maxframe/io/objects/tests/__init__.py +13 -0
  94. maxframe/io/objects/tests/test_object_io.py +97 -0
  95. maxframe/{odpsio → io/odpsio}/__init__.py +3 -1
  96. maxframe/{odpsio → io/odpsio}/arrow.py +42 -10
  97. maxframe/{odpsio → io/odpsio}/schema.py +38 -16
  98. maxframe/io/odpsio/tableio.py +719 -0
  99. maxframe/io/odpsio/tests/__init__.py +13 -0
  100. maxframe/{odpsio → io/odpsio}/tests/test_schema.py +59 -22
  101. maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +50 -23
  102. maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
  103. maxframe/io/odpsio/volumeio.py +63 -0
  104. maxframe/learn/contrib/__init__.py +3 -1
  105. maxframe/learn/contrib/graph/__init__.py +15 -0
  106. maxframe/learn/contrib/graph/connected_components.py +215 -0
  107. maxframe/learn/contrib/graph/tests/__init__.py +13 -0
  108. maxframe/learn/contrib/graph/tests/test_connected_components.py +53 -0
  109. maxframe/learn/contrib/llm/__init__.py +16 -0
  110. maxframe/learn/contrib/llm/core.py +54 -0
  111. maxframe/learn/contrib/llm/models/__init__.py +14 -0
  112. maxframe/learn/contrib/llm/models/dashscope.py +73 -0
  113. maxframe/learn/contrib/llm/multi_modal.py +42 -0
  114. maxframe/learn/contrib/llm/text.py +42 -0
  115. maxframe/learn/contrib/xgboost/classifier.py +26 -2
  116. maxframe/learn/contrib/xgboost/core.py +87 -2
  117. maxframe/learn/contrib/xgboost/dmatrix.py +3 -6
  118. maxframe/learn/contrib/xgboost/predict.py +29 -46
  119. maxframe/learn/contrib/xgboost/regressor.py +3 -10
  120. maxframe/learn/contrib/xgboost/train.py +29 -18
  121. maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
  122. maxframe/lib/mmh3.cp39-win32.pyd +0 -0
  123. maxframe/lib/mmh3.pyi +43 -0
  124. maxframe/lib/sparse/tests/test_sparse.py +15 -15
  125. maxframe/lib/wrapped_pickle.py +2 -1
  126. maxframe/opcodes.py +8 -0
  127. maxframe/protocol.py +154 -27
  128. maxframe/remote/core.py +4 -8
  129. maxframe/serialization/__init__.py +1 -0
  130. maxframe/serialization/core.cp39-win32.pyd +0 -0
  131. maxframe/serialization/core.pxd +3 -0
  132. maxframe/serialization/core.pyi +3 -0
  133. maxframe/serialization/core.pyx +67 -26
  134. maxframe/serialization/exception.py +1 -1
  135. maxframe/serialization/pandas.py +52 -17
  136. maxframe/serialization/serializables/core.py +180 -15
  137. maxframe/serialization/serializables/field_type.py +4 -1
  138. maxframe/serialization/serializables/tests/test_serializable.py +54 -5
  139. maxframe/serialization/tests/test_serial.py +2 -1
  140. maxframe/session.py +9 -2
  141. maxframe/tensor/__init__.py +81 -2
  142. maxframe/tensor/arithmetic/isclose.py +1 -0
  143. maxframe/tensor/arithmetic/tests/test_arithmetic.py +22 -18
  144. maxframe/tensor/core.py +5 -136
  145. maxframe/tensor/datasource/array.py +3 -0
  146. maxframe/tensor/datasource/full.py +1 -1
  147. maxframe/tensor/datasource/tests/test_datasource.py +1 -1
  148. maxframe/tensor/indexing/flatnonzero.py +1 -1
  149. maxframe/tensor/indexing/getitem.py +2 -0
  150. maxframe/tensor/merge/__init__.py +2 -0
  151. maxframe/tensor/merge/concatenate.py +101 -0
  152. maxframe/tensor/merge/tests/test_merge.py +30 -1
  153. maxframe/tensor/merge/vstack.py +74 -0
  154. maxframe/tensor/{base → misc}/__init__.py +2 -0
  155. maxframe/tensor/{base → misc}/atleast_1d.py +1 -3
  156. maxframe/tensor/misc/atleast_2d.py +70 -0
  157. maxframe/tensor/misc/atleast_3d.py +85 -0
  158. maxframe/tensor/misc/tests/__init__.py +13 -0
  159. maxframe/tensor/{base → misc}/transpose.py +22 -18
  160. maxframe/tensor/{base → misc}/unique.py +3 -3
  161. maxframe/tensor/operators.py +1 -7
  162. maxframe/tensor/random/core.py +1 -1
  163. maxframe/tensor/reduction/count_nonzero.py +2 -1
  164. maxframe/tensor/reduction/mean.py +1 -0
  165. maxframe/tensor/reduction/nanmean.py +1 -0
  166. maxframe/tensor/reduction/nanvar.py +2 -0
  167. maxframe/tensor/reduction/tests/test_reduction.py +12 -1
  168. maxframe/tensor/reduction/var.py +2 -0
  169. maxframe/tensor/statistics/quantile.py +2 -2
  170. maxframe/tensor/utils.py +2 -22
  171. maxframe/tests/test_protocol.py +34 -0
  172. maxframe/tests/test_utils.py +0 -12
  173. maxframe/tests/utils.py +17 -2
  174. maxframe/typing_.py +4 -1
  175. maxframe/udf.py +8 -9
  176. maxframe/utils.py +106 -86
  177. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0.dist-info}/METADATA +25 -25
  178. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0.dist-info}/RECORD +197 -173
  179. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0.dist-info}/WHEEL +1 -1
  180. maxframe_client/__init__.py +0 -1
  181. maxframe_client/clients/framedriver.py +4 -1
  182. maxframe_client/fetcher.py +81 -74
  183. maxframe_client/session/consts.py +3 -0
  184. maxframe_client/session/graph.py +8 -2
  185. maxframe_client/session/odps.py +194 -40
  186. maxframe_client/session/task.py +94 -39
  187. maxframe_client/tests/test_fetcher.py +21 -3
  188. maxframe_client/tests/test_session.py +109 -8
  189. maxframe/core/entity/chunks.py +0 -68
  190. maxframe/core/entity/fuse.py +0 -73
  191. maxframe/core/graph/builder/chunk.py +0 -430
  192. maxframe/odpsio/tableio.py +0 -322
  193. maxframe/odpsio/volumeio.py +0 -95
  194. maxframe_client/clients/spe.py +0 -104
  195. /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
  196. /maxframe/{tensor/base → dataframe/datastore}/tests/__init__.py +0 -0
  197. /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
  198. /maxframe/tensor/{base → misc}/astype.py +0 -0
  199. /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
  200. /maxframe/tensor/{base → misc}/ravel.py +0 -0
  201. /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
  202. /maxframe/tensor/{base → misc}/where.py +0 -0
  203. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0.dist-info}/top_level.txt +0 -0
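The largest single addition in this release is the new table IO layer, maxframe/io/odpsio/tableio.py (file 98 above, +719 lines), reproduced in full below; a usage sketch follows the diff.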
@@ -0,0 +1,719 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import pyarrow as pa
+from odps import ODPS
+from odps import __version__ as pyodps_version
+from odps.apis.storage_api import (
+    StorageApiArrowClient,
+    TableBatchScanResponse,
+    TableBatchWriteResponse,
+)
+from odps.tunnel import TableTunnel
+from odps.types import OdpsSchema, PartitionSpec, timestamp_ntz
+from odps.utils import call_with_retry
+
+try:
+    import pyarrow.compute as pac
+except ImportError:
+    pac = None
+
+from ...config import options
+from ...env import ODPS_STORAGE_API_ENDPOINT
+from ...lib.version import Version
+from ...utils import sync_pyodps_options
+from .schema import odps_schema_to_arrow_schema
+
+PartitionsType = Union[List[str], str, None]
+
+_DEFAULT_ROW_BATCH_SIZE = 4096
+_need_patch_batch = Version(pyodps_version) < Version("0.12.0")
+
+
+class ODPSTableIO(ABC):
+    def __new__(cls, odps: ODPS):
+        if cls is ODPSTableIO:
+            if options.use_common_table or ODPS_STORAGE_API_ENDPOINT in os.environ:
+                return HaloTableIO(odps)
+            else:
+                return TunnelTableIO(odps)
+        return super().__new__(cls)
+
+    def __init__(self, odps: ODPS):
+        self._odps = odps
+
+    @classmethod
+    def _get_reader_schema(
+        cls,
+        table_schema: OdpsSchema,
+        columns: Optional[List[str]] = None,
+        partition_columns: Union[None, bool, List[str]] = None,
+    ) -> OdpsSchema:
+        final_cols = []
+
+        columns = columns or [col.name for col in table_schema.simple_columns]
+        if partition_columns is True:
+            partition_columns = [c.name for c in table_schema.partitions]
+        else:
+            partition_columns = partition_columns or []
+
+        for col_name in columns + partition_columns:
+            final_cols.append(table_schema[col_name])
+        return OdpsSchema(final_cols)
+
+    @abstractmethod
+    def open_reader(
+        self,
+        full_table_name: str,
+        partitions: PartitionsType = None,
+        columns: Optional[List[str]] = None,
+        partition_columns: Union[None, bool, List[str]] = None,
+        start: Optional[int] = None,
+        stop: Optional[int] = None,
+        reverse_range: bool = False,
+        row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def open_writer(
+        self,
+        full_table_name: str,
+        partition: Optional[str] = None,
+        overwrite: bool = True,
+    ):
+        raise NotImplementedError
+
+
+class TunnelMultiPartitionReader:
+    def __init__(
+        self,
+        odps_entry: ODPS,
+        table_name: str,
+        partitions: PartitionsType,
+        columns: Optional[List[str]] = None,
+        partition_columns: Optional[List[str]] = None,
+        start: Optional[int] = None,
+        count: Optional[int] = None,
+        partition_to_download_ids: Dict[str, str] = None,
+    ):
+        self._odps_entry = odps_entry
+        self._table = odps_entry.get_table(table_name)
+        self._columns = columns
+
+        odps_schema = ODPSTableIO._get_reader_schema(
+            self._table.table_schema, columns, partition_columns
+        )
+        self._schema = odps_schema_to_arrow_schema(odps_schema)
+
+        self._start = start or 0
+        self._count = count
+        self._row_left = count
+
+        self._cur_reader = None
+        self._reader_iter = None
+        self._cur_partition_id = -1
+        self._reader_start_pos = 0
+
+        if partitions is None:
+            if not self._table.table_schema.partitions:
+                self._partitions = [None]
+            else:
+                self._partitions = [str(pt) for pt in self._table.partitions]
+        elif isinstance(partitions, str):
+            self._partitions = [partitions]
+        else:
+            self._partitions = partitions
+
+        self._partition_cols = partition_columns
+        self._partition_to_download_ids = partition_to_download_ids or dict()
+
+    @property
+    def count(self) -> Optional[int]:
+        if len(self._partitions) > 1:
+            return None
+        return self._count
+
+    def _open_next_reader(self):
+        if self._cur_reader is not None:
+            self._reader_start_pos += self._cur_reader.count
+
+        if (
+            self._row_left is not None and self._row_left <= 0
+        ) or 1 + self._cur_partition_id >= len(self._partitions):
+            self._cur_reader = None
+            return
+
+        while 1 + self._cur_partition_id < len(self._partitions):
+            self._cur_partition_id += 1
+
+            part_str = self._partitions[self._cur_partition_id]
+
+            # todo make this more formal when PyODPS 0.12.0 is released
+            req_columns = self._columns
+            if not _need_patch_batch:
+                req_columns = self._schema.names
+            with sync_pyodps_options():
+                self._cur_reader = self._table.open_reader(
+                    part_str,
+                    columns=req_columns,
+                    arrow=True,
+                    download_id=self._partition_to_download_ids.get(part_str),
+                )
+            if self._cur_reader.count + self._reader_start_pos > self._start:
+                start = self._start - self._reader_start_pos
+                if self._count is None:
+                    count = None
+                else:
+                    count = min(self._count, self._cur_reader.count - start)
+
+                with sync_pyodps_options():
+                    self._reader_iter = self._cur_reader.read(start, count)
+                break
+            self._reader_start_pos += self._cur_reader.count
+        else:
+            self._cur_reader = None
+
+    def _fill_batch_partition(self, batch: pa.RecordBatch) -> pa.RecordBatch:
+        pt_spec = PartitionSpec(self._partitions[self._cur_partition_id])
+
+        names = list(batch.schema.names)
+        arrays = []
+        for idx in range(batch.num_columns):
+            col = batch.column(idx)
+            if isinstance(col.type, pa.TimestampType):
+                if col.type.tz is not None:
+                    target_type = pa.timestamp(
+                        self._schema.types[idx].unit, col.type.tz
+                    )
+                    arrays.append(col.cast(target_type))
+                else:
+                    target_type = pa.timestamp(
+                        self._schema.types[idx].unit, options.local_timezone
+                    )
+                    pd_col = col.to_pandas().dt.tz_localize(options.local_timezone)
+                    arrays.append(pa.Array.from_pandas(pd_col).cast(target_type))
+            else:
+                arrays.append(batch.column(idx))
+
+        for part_col in self._partition_cols or []:
+            names.append(part_col)
+            col_type = self._schema.field_by_name(part_col).type
+            pt_col = np.repeat([pt_spec[part_col]], batch.num_rows)
+            arrays.append(pa.array(pt_col).cast(col_type))
+        return pa.RecordBatch.from_arrays(arrays, names)
+
+    def read(self):
+        with sync_pyodps_options():
+            if self._cur_reader is None:
+                self._open_next_reader()
+                if self._cur_reader is None:
+                    return None
+            while self._cur_reader is not None:
+                try:
+                    batch = next(self._reader_iter)
+                    if batch is not None:
+                        if self._row_left is not None:
+                            self._row_left -= batch.num_rows
+                        if _need_patch_batch:
+                            return self._fill_batch_partition(batch)
+                        else:
+                            return batch
+                except StopIteration:
+                    self._open_next_reader()
+            return None
+
+    def read_all(self) -> pa.Table:
+        batches = []
+        while True:
+            batch = self.read()
+            if batch is None:
+                break
+            batches.append(batch)
+        if not batches:
+            return self._schema.empty_table()
+        return pa.Table.from_batches(batches)
+
+
+class TunnelWrappedWriter:
+    def __init__(self, nested_writer):
+        self._writer = nested_writer
+
+    def write(self, data: Union[pa.RecordBatch, pa.Table]):
+        if not any(isinstance(tp, pa.TimestampType) for tp in data.schema.types):
+            self._writer.write(data)
+            return
+        pa_type = type(data)
+        arrays = []
+        for idx in range(data.num_columns):
+            name = data.schema.names[idx]
+            col = data.column(idx)
+            if not isinstance(col.type, pa.TimestampType):
+                arrays.append(col)
+                continue
+            if self._writer.schema[name].type == timestamp_ntz:
+                col = HaloTableArrowWriter._localize_timezone(col, "UTC")
+            else:
+                col = HaloTableArrowWriter._localize_timezone(col)
+            arrays.append(col)
+        data = pa_type.from_arrays(arrays, names=data.schema.names)
+        self._writer.write(data)
+
+    def __getattr__(self, item):
+        return getattr(self._writer, item)
+
+
+class TunnelTableIO(ODPSTableIO):
+    @contextmanager
+    def open_reader(
+        self,
+        full_table_name: str,
+        partitions: PartitionsType = None,
+        columns: Optional[List[str]] = None,
+        partition_columns: Union[None, bool, List[str]] = None,
+        start: Optional[int] = None,
+        stop: Optional[int] = None,
+        reverse_range: bool = False,
+        row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
+    ):
+        with sync_pyodps_options():
+            table = self._odps.get_table(full_table_name)
+
+        if partition_columns is True:
+            partition_columns = [c.name for c in table.table_schema.partitions]
+
+        total_records = None
+        part_to_down_id = None
+        if (
+            (start is not None and start < 0)
+            or (stop is not None and stop < 0)
+            or (reverse_range and start is None)
+        ):
+            with sync_pyodps_options():
+                table = self._odps.get_table(full_table_name)
+                tunnel = TableTunnel(self._odps)
+                parts = (
+                    [partitions]
+                    if partitions is None or isinstance(partitions, str)
+                    else partitions
+                )
+                part_to_down_id = dict()
+                total_records = 0
+                for part in parts:
+                    down_session = tunnel.create_download_session(
+                        table, async_mode=True, partition_spec=part
+                    )
+                    part_to_down_id[part] = down_session.id
+                    total_records += down_session.count
+
+        count = None
+        if start is not None or stop is not None:
+            if reverse_range:
+                start = start if start is not None else total_records - 1
+                stop = stop if stop is not None else -1
+            else:
+                start = start if start is not None else 0
+                stop = stop if stop is not None else None
+            start = start if start >= 0 else total_records + start
+            stop = stop if stop is None or stop >= 0 else total_records + stop
+            if reverse_range:
+                count = start - stop
+                start = stop + 1
+            else:
+                count = stop - start if stop is not None and start is not None else None
+
+        yield TunnelMultiPartitionReader(
+            self._odps,
+            full_table_name,
+            partitions=partitions,
+            columns=columns,
+            partition_columns=partition_columns,
+            start=start,
+            count=count,
+            partition_to_download_ids=part_to_down_id,
+        )
+
+    @contextmanager
+    def open_writer(
+        self,
+        full_table_name: str,
+        partition: Optional[str] = None,
+        overwrite: bool = True,
+    ):
+        table = self._odps.get_table(full_table_name)
+        with sync_pyodps_options():
+            with table.open_writer(
+                partition=partition,
+                arrow=True,
+                create_partition=partition is not None,
+                overwrite=overwrite,
+            ) as writer:
+                # fixme should yield writer directly once pyodps fixes
+                #  related arrow timestamp bug when provided schema and
+                #  table schema is identical.
+                if _need_patch_batch:
+                    yield TunnelWrappedWriter(writer)
+                else:
+                    yield writer
+
+
+class HaloTableArrowReader:
+    def __init__(
+        self,
+        client: StorageApiArrowClient,
+        scan_info: TableBatchScanResponse,
+        odps_schema: OdpsSchema,
+        start: Optional[int] = None,
+        count: Optional[int] = None,
+        row_batch_size: Optional[int] = None,
+    ):
+        self._client = client
+        self._scan_info = scan_info
+
+        self._cur_split_id = -1
+        self._cur_reader = None
+
+        self._odps_schema = odps_schema
+        self._arrow_schema = odps_schema_to_arrow_schema(odps_schema)
+
+        self._start = start
+        self._count = count
+        self._cursor = 0
+        self._row_batch_size = row_batch_size
+
+    @property
+    def count(self) -> int:
+        return self._count
+
+    def _open_next_reader(self):
+        from odps.apis.storage_api import ReadRowsRequest
+
+        if 0 <= self._scan_info.split_count <= self._cur_split_id + 1:
+            # scan by split
+            self._cur_reader = None
+            return
+        elif self._count is not None and self._cursor >= self._count:
+            # scan by range
+            self._cur_reader = None
+            return
+
+        read_rows_kw = {}
+        if self._start is not None:
+            read_rows_kw["row_index"] = self._start + self._cursor
+            read_rows_kw["row_count"] = min(
+                self._row_batch_size, self._count - self._cursor
+            )
+            self._cursor = min(self._count, self._cursor + self._row_batch_size)
+
+        req = ReadRowsRequest(
+            session_id=self._scan_info.session_id,
+            split_index=self._cur_split_id + 1,
+            **read_rows_kw,
+        )
+        self._cur_reader = call_with_retry(self._client.read_rows_arrow, req)
+        self._cur_split_id += 1
+
+    def _convert_timezone(self, batch: pa.RecordBatch) -> pa.RecordBatch:
+        timezone = options.local_timezone
+        if not any(isinstance(tp, pa.TimestampType) for tp in batch.schema.types):
+            return batch
+
+        cols = []
+        for idx in range(batch.num_columns):
+            col = batch.column(idx)
+            name = batch.schema.names[idx]
+            if not isinstance(col.type, pa.TimestampType):
+                cols.append(col)
+                continue
+            if self._odps_schema[name].type == timestamp_ntz:
+                col = col.cast(pa.timestamp(col.type.unit))
+                cols.append(col)
+                continue
+
+            if hasattr(pac, "local_timestamp"):
+                col = col.cast(pa.timestamp(col.type.unit, timezone))
+            else:
+                pd_col = col.to_pandas().dt.tz_convert(timezone)
+                col = pa.Array.from_pandas(pd_col).cast(
+                    pa.timestamp(col.type.unit, timezone)
+                )
+            cols.append(col)
+
+        return pa.RecordBatch.from_arrays(cols, names=batch.schema.names)
+
+    def read(self):
+        if self._cur_reader is None:
+            self._open_next_reader()
+            if self._cur_reader is None:
+                return None
+        while self._cur_reader is not None:
+            batch = self._cur_reader.read()
+            if batch is not None:
+                return self._convert_timezone(batch)
+            self._open_next_reader()
+        return None
+
+    def read_all(self) -> pa.Table:
+        batches = []
+        while True:
+            batch = self.read()
+            if batch is None:
+                break
+            batches.append(batch)
+        if not batches:
+            return self._arrow_schema.empty_table()
+        return pa.Table.from_batches(batches)
+
+
+class HaloTableArrowWriter:
+    def __init__(
+        self,
+        client: StorageApiArrowClient,
+        write_info: TableBatchWriteResponse,
+        odps_schema: OdpsSchema,
+    ):
+        self._client = client
+        self._write_info = write_info
+        self._odps_schema = odps_schema
+        self._arrow_schema = odps_schema_to_arrow_schema(odps_schema)
+
+        self._writer = None
+
+    def open(self):
+        from odps.apis.storage_api import WriteRowsRequest
+
+        self._writer = call_with_retry(
+            self._client.write_rows_arrow,
+            WriteRowsRequest(self._write_info.session_id),
+        )
+
+    @classmethod
+    def _localize_timezone(cls, col, tz=None):
+        from odps.lib import tzlocal
+
+        if tz is None:
+            if options.local_timezone is None:
+                tz = str(tzlocal.get_localzone())
+            else:
+                tz = str(options.local_timezone)
+
+        if col.type.tz is not None:
+            return col
+        if hasattr(pac, "assume_timezone"):
+            col = pac.assume_timezone(col, tz)
+            return col
+        else:
+            col = col.to_pandas()
+            return pa.Array.from_pandas(col.dt.tz_localize(tz))
+
+    def _convert_schema(self, batch: pa.RecordBatch):
+        if batch.schema == self._arrow_schema and not any(
+            isinstance(tp, pa.TimestampType) for tp in self._arrow_schema.types
+        ):
+            return batch
+        cols = []
+        for idx in range(batch.num_columns):
+            col = batch.column(idx)
+            name = batch.schema.names[idx]
+
+            if isinstance(col.type, pa.TimestampType):
+                if self._odps_schema[name].type == timestamp_ntz:
+                    col = self._localize_timezone(col, "UTC")
+                else:
+                    col = self._localize_timezone(col)
+
+            if col.type != self._arrow_schema.types[idx]:
+                col = col.cast(self._arrow_schema.types[idx])
+            cols.append(col)
+        return pa.RecordBatch.from_arrays(cols, names=batch.schema.names)
+
+    def write(self, batch):
+        if isinstance(batch, pa.Table):
+            for b in batch.to_batches():
+                self._writer.write(self._convert_schema(b))
+        else:
+            self._writer.write(self._convert_schema(batch))
+
+    def close(self):
+        commit_msg, is_success = self._writer.finish()
+        if not is_success:
+            raise IOError(commit_msg)
+        return commit_msg
+
+
+class HaloTableIO(ODPSTableIO):
+    _storage_api_endpoint = os.getenv(ODPS_STORAGE_API_ENDPOINT)
+
+    @staticmethod
+    def _convert_partitions(partitions: PartitionsType) -> Optional[List[str]]:
+        if partitions is None:
+            return []
+        elif isinstance(partitions, (str, PartitionSpec)):
+            partitions = [partitions]
+        return [
+            "/".join(f"{k}={v}" for k, v in PartitionSpec(pt).items())
+            for pt in partitions
+        ]
+
+    def get_table_record_count(
+        self, full_table_name: str, partitions: PartitionsType = None
+    ):
+        from odps.apis.storage_api import SplitOptions, TableBatchScanRequest
+
+        table = self._odps.get_table(full_table_name)
+        client = StorageApiArrowClient(
+            self._odps, table, rest_endpoint=self._storage_api_endpoint
+        )
+
+        split_option = SplitOptions.SplitMode.SIZE
+
+        scan_kw = {
+            "required_partitions": self._convert_partitions(partitions),
+            "split_options": SplitOptions.get_default_options(split_option),
+        }
+
+        # todo add more options for partition column handling
+        req = TableBatchScanRequest(**scan_kw)
+        resp = client.create_read_session(req)
+        return resp.record_count
+
+    @contextmanager
+    def open_reader(
+        self,
+        full_table_name: str,
+        partitions: PartitionsType = None,
+        columns: Optional[List[str]] = None,
+        partition_columns: Union[None, bool, List[str]] = None,
+        start: Optional[int] = None,
+        stop: Optional[int] = None,
+        reverse_range: bool = False,
+        row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
+    ):
+        from odps.apis.storage_api import (
+            SessionRequest,
+            SessionStatus,
+            SplitOptions,
+            TableBatchScanRequest,
+        )
+
+        table = self._odps.get_table(full_table_name)
+        client = StorageApiArrowClient(
+            self._odps, table, rest_endpoint=self._storage_api_endpoint
+        )
+
+        split_option = SplitOptions.SplitMode.SIZE
+        if start is not None or stop is not None:
+            split_option = SplitOptions.SplitMode.ROW_OFFSET
+
+        scan_kw = {
+            "required_partitions": self._convert_partitions(partitions),
+            "split_options": SplitOptions.get_default_options(split_option),
+        }
+        columns = columns or [c.name for c in table.table_schema.simple_columns]
+        scan_kw["required_data_columns"] = columns
+        if partition_columns is True:
+            scan_kw["required_partition_columns"] = [
+                c.name for c in table.table_schema.partitions
+            ]
+        else:
+            scan_kw["required_partition_columns"] = partition_columns
+
+        # todo add more options for partition column handling
+        req = TableBatchScanRequest(**scan_kw)
+        resp = call_with_retry(client.create_read_session, req)
+
+        session_id = resp.session_id
+        status = resp.session_status
+        while status == SessionStatus.INIT:
+            resp = call_with_retry(client.get_read_session, SessionRequest(session_id))
+            status = resp.session_status
+            time.sleep(1.0)
+
+        assert status == SessionStatus.NORMAL
+
+        count = None
+        if start is not None or stop is not None:
+            if reverse_range:
+                start = start if start is not None else resp.record_count - 1
+                stop = stop if stop is not None else -1
+            else:
+                start = start if start is not None else 0
+                stop = stop if stop is not None else resp.record_count
+            start = start if start >= 0 else resp.record_count + start
+            stop = stop if stop >= 0 else resp.record_count + stop
+            if reverse_range:
+                count = start - stop
+                start = stop + 1
+            else:
+                count = stop - start
+
+        reader_schema = self._get_reader_schema(
+            table.table_schema, columns, partition_columns
+        )
+        yield HaloTableArrowReader(
+            client,
+            resp,
+            odps_schema=reader_schema,
+            start=start,
+            count=count,
+            row_batch_size=row_batch_size,
+        )
+
+    @contextmanager
+    def open_writer(
+        self,
+        full_table_name: str,
+        partition: Optional[str] = None,
+        overwrite: bool = True,
+    ):
+        from odps.apis.storage_api import (
+            SessionRequest,
+            SessionStatus,
+            TableBatchWriteRequest,
+        )
+
+        table = self._odps.get_table(full_table_name)
+        client = StorageApiArrowClient(
+            self._odps, table, rest_endpoint=self._storage_api_endpoint
+        )
+
+        part_strs = self._convert_partitions(partition)
+        part_str = part_strs[0] if part_strs else None
+        req = TableBatchWriteRequest(partition_spec=part_str, overwrite=overwrite)
+        resp = call_with_retry(client.create_write_session, req)
+
+        session_id = resp.session_id
+        writer = HaloTableArrowWriter(client, resp, table.table_schema)
+        writer.open()
+
+        yield writer
+
+        commit_msg = writer.close()
+        resp = call_with_retry(
+            client.commit_write_session,
+            SessionRequest(session_id=session_id),
+            [commit_msg],
+        )
+        while resp.session_status == SessionStatus.COMMITTING:
+            resp = call_with_retry(
+                client.get_write_session, SessionRequest(session_id=session_id)
+            )
+        assert resp.session_status == SessionStatus.COMMITTED
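
For orientation, here is a minimal usage sketch of the new table IO layer. It is illustrative only: the import path assumes ODPSTableIO is re-exported from maxframe.io.odpsio (whose __init__.py also changed in this release), and the credentials, endpoint, and table names are placeholders. As the diff shows, ODPSTableIO.__new__ dispatches to HaloTableIO when options.use_common_table is set or the ODPS_STORAGE_API_ENDPOINT environment variable is present, and to TunnelTableIO otherwise.

import pyarrow as pa
from odps import ODPS

# Assumed re-export; the class itself lives in maxframe/io/odpsio/tableio.py.
from maxframe.io.odpsio import ODPSTableIO

# Placeholder credentials, project and endpoint for illustration only.
o = ODPS("<access_id>", "<secret_key>", project="my_project", endpoint="<endpoint>")

# __new__ picks HaloTableIO (Storage API) or TunnelTableIO (tunnel) automatically.
table_io = ODPSTableIO(o)

# Write a RecordBatch; both writer paths accept pa.RecordBatch or pa.Table and
# localize naive timestamp columns before handing data to ODPS.
batch = pa.record_batch(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]
)
with table_io.open_writer("my_project.test_table", overwrite=True) as writer:
    writer.write(batch)

# Read rows [0, 100) back; read_all() concatenates batches into a pa.Table.
with table_io.open_reader("my_project.test_table", start=0, stop=100) as reader:
    tbl = reader.read_all()
print(tbl.num_rows)

Note that both open_reader and open_writer are context managers: the Halo writer in particular commits its write session only when the with block exits.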