maxframe 1.0.0rc1-cp37-cp37m-win_amd64.whl → 1.0.0rc3-cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of maxframe might be problematic.

Files changed (138)
  1. maxframe/_utils.cp37-win_amd64.pyd +0 -0
  2. maxframe/codegen.py +3 -6
  3. maxframe/config/config.py +49 -10
  4. maxframe/config/validators.py +42 -11
  5. maxframe/conftest.py +15 -2
  6. maxframe/core/__init__.py +2 -13
  7. maxframe/core/entity/__init__.py +0 -4
  8. maxframe/core/entity/objects.py +46 -3
  9. maxframe/core/entity/output_types.py +0 -3
  10. maxframe/core/entity/tests/test_objects.py +43 -0
  11. maxframe/core/entity/tileables.py +5 -78
  12. maxframe/core/graph/__init__.py +2 -2
  13. maxframe/core/graph/builder/__init__.py +0 -1
  14. maxframe/core/graph/builder/base.py +5 -4
  15. maxframe/core/graph/builder/tileable.py +4 -4
  16. maxframe/core/graph/builder/utils.py +4 -8
  17. maxframe/core/graph/core.cp37-win_amd64.pyd +0 -0
  18. maxframe/core/graph/entity.py +9 -33
  19. maxframe/core/operator/__init__.py +2 -9
  20. maxframe/core/operator/base.py +3 -5
  21. maxframe/core/operator/objects.py +0 -9
  22. maxframe/core/operator/utils.py +55 -0
  23. maxframe/dataframe/__init__.py +1 -1
  24. maxframe/dataframe/arithmetic/around.py +5 -17
  25. maxframe/dataframe/arithmetic/core.py +15 -7
  26. maxframe/dataframe/arithmetic/docstring.py +5 -55
  27. maxframe/dataframe/arithmetic/tests/test_arithmetic.py +22 -0
  28. maxframe/dataframe/core.py +5 -5
  29. maxframe/dataframe/datasource/date_range.py +2 -2
  30. maxframe/dataframe/datasource/read_odps_query.py +7 -1
  31. maxframe/dataframe/datasource/read_odps_table.py +3 -2
  32. maxframe/dataframe/datasource/tests/test_datasource.py +14 -0
  33. maxframe/dataframe/datastore/to_odps.py +1 -1
  34. maxframe/dataframe/groupby/cum.py +0 -1
  35. maxframe/dataframe/groupby/tests/test_groupby.py +4 -0
  36. maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
  37. maxframe/dataframe/indexing/rename.py +3 -37
  38. maxframe/dataframe/indexing/sample.py +0 -1
  39. maxframe/dataframe/indexing/set_index.py +68 -1
  40. maxframe/dataframe/merge/merge.py +236 -2
  41. maxframe/dataframe/merge/tests/test_merge.py +123 -0
  42. maxframe/dataframe/misc/apply.py +3 -10
  43. maxframe/dataframe/misc/case_when.py +1 -1
  44. maxframe/dataframe/misc/describe.py +2 -2
  45. maxframe/dataframe/misc/drop_duplicates.py +4 -25
  46. maxframe/dataframe/misc/eval.py +4 -0
  47. maxframe/dataframe/misc/pct_change.py +1 -83
  48. maxframe/dataframe/misc/transform.py +1 -30
  49. maxframe/dataframe/misc/value_counts.py +4 -17
  50. maxframe/dataframe/missing/dropna.py +1 -1
  51. maxframe/dataframe/missing/fillna.py +5 -5
  52. maxframe/dataframe/operators.py +1 -17
  53. maxframe/dataframe/reduction/core.py +2 -2
  54. maxframe/dataframe/sort/sort_values.py +1 -11
  55. maxframe/dataframe/statistics/quantile.py +5 -17
  56. maxframe/dataframe/utils.py +4 -7
  57. maxframe/io/objects/__init__.py +24 -0
  58. maxframe/io/objects/core.py +140 -0
  59. maxframe/io/objects/tensor.py +76 -0
  60. maxframe/io/objects/tests/__init__.py +13 -0
  61. maxframe/io/objects/tests/test_object_io.py +97 -0
  62. maxframe/{odpsio → io/odpsio}/__init__.py +3 -1
  63. maxframe/{odpsio → io/odpsio}/arrow.py +12 -8
  64. maxframe/{odpsio → io/odpsio}/schema.py +15 -12
  65. maxframe/io/odpsio/tableio.py +702 -0
  66. maxframe/io/odpsio/tests/__init__.py +13 -0
  67. maxframe/{odpsio → io/odpsio}/tests/test_schema.py +19 -18
  68. maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +50 -23
  69. maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
  70. maxframe/io/odpsio/volumeio.py +57 -0
  71. maxframe/learn/contrib/xgboost/classifier.py +26 -2
  72. maxframe/learn/contrib/xgboost/core.py +87 -2
  73. maxframe/learn/contrib/xgboost/dmatrix.py +3 -6
  74. maxframe/learn/contrib/xgboost/predict.py +21 -7
  75. maxframe/learn/contrib/xgboost/regressor.py +3 -10
  76. maxframe/learn/contrib/xgboost/train.py +27 -17
  77. maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
  78. maxframe/lib/mmh3.cp37-win_amd64.pyd +0 -0
  79. maxframe/protocol.py +41 -17
  80. maxframe/remote/core.py +4 -8
  81. maxframe/serialization/__init__.py +1 -0
  82. maxframe/serialization/core.cp37-win_amd64.pyd +0 -0
  83. maxframe/serialization/serializables/core.py +48 -9
  84. maxframe/tensor/__init__.py +69 -2
  85. maxframe/tensor/arithmetic/isclose.py +1 -0
  86. maxframe/tensor/arithmetic/tests/test_arithmetic.py +21 -17
  87. maxframe/tensor/core.py +5 -136
  88. maxframe/tensor/datasource/array.py +3 -0
  89. maxframe/tensor/datasource/full.py +1 -1
  90. maxframe/tensor/datasource/tests/test_datasource.py +1 -1
  91. maxframe/tensor/indexing/flatnonzero.py +1 -1
  92. maxframe/tensor/merge/__init__.py +2 -0
  93. maxframe/tensor/merge/concatenate.py +98 -0
  94. maxframe/tensor/merge/tests/test_merge.py +30 -1
  95. maxframe/tensor/merge/vstack.py +70 -0
  96. maxframe/tensor/{base → misc}/__init__.py +2 -0
  97. maxframe/tensor/{base → misc}/atleast_1d.py +0 -2
  98. maxframe/tensor/misc/atleast_2d.py +70 -0
  99. maxframe/tensor/misc/atleast_3d.py +85 -0
  100. maxframe/tensor/misc/tests/__init__.py +13 -0
  101. maxframe/tensor/{base → misc}/transpose.py +22 -18
  102. maxframe/tensor/{base → misc}/unique.py +2 -2
  103. maxframe/tensor/operators.py +1 -7
  104. maxframe/tensor/random/core.py +1 -1
  105. maxframe/tensor/reduction/count_nonzero.py +1 -0
  106. maxframe/tensor/reduction/mean.py +1 -0
  107. maxframe/tensor/reduction/nanmean.py +1 -0
  108. maxframe/tensor/reduction/nanvar.py +2 -0
  109. maxframe/tensor/reduction/tests/test_reduction.py +12 -1
  110. maxframe/tensor/reduction/var.py +2 -0
  111. maxframe/tensor/statistics/quantile.py +2 -2
  112. maxframe/tensor/utils.py +2 -22
  113. maxframe/tests/utils.py +11 -2
  114. maxframe/typing_.py +4 -1
  115. maxframe/udf.py +8 -9
  116. maxframe/utils.py +32 -70
  117. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/METADATA +2 -2
  118. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/RECORD +133 -123
  119. maxframe_client/fetcher.py +60 -68
  120. maxframe_client/session/graph.py +8 -2
  121. maxframe_client/session/odps.py +58 -22
  122. maxframe_client/tests/test_fetcher.py +21 -3
  123. maxframe_client/tests/test_session.py +27 -4
  124. maxframe/core/entity/chunks.py +0 -68
  125. maxframe/core/entity/fuse.py +0 -73
  126. maxframe/core/graph/builder/chunk.py +0 -430
  127. maxframe/odpsio/tableio.py +0 -322
  128. maxframe/odpsio/volumeio.py +0 -95
  129. /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
  130. /maxframe/{tensor/base/tests → io}/__init__.py +0 -0
  131. /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
  132. /maxframe/tensor/{base → misc}/astype.py +0 -0
  133. /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
  134. /maxframe/tensor/{base → misc}/ravel.py +0 -0
  135. /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
  136. /maxframe/tensor/{base → misc}/where.py +0 -0
  137. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/WHEEL +0 -0
  138. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/top_level.txt +0 -0
maxframe/io/odpsio/tableio.py (new file)
@@ -0,0 +1,702 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import time
+ from abc import ABC, abstractmethod
+ from contextlib import contextmanager
+ from typing import Dict, List, Optional, Union
+
+ import pyarrow as pa
+ from odps import ODPS
+ from odps import __version__ as pyodps_version
+ from odps.apis.storage_api import (
+     StorageApiArrowClient,
+     TableBatchScanResponse,
+     TableBatchWriteResponse,
+ )
+ from odps.config import option_context as pyodps_option_context
+ from odps.tunnel import TableTunnel
+ from odps.types import OdpsSchema, PartitionSpec, timestamp_ntz
+
+ try:
+     import pyarrow.compute as pac
+ except ImportError:
+     pac = None
+
+ from ...config import options
+ from ...env import ODPS_STORAGE_API_ENDPOINT
+ from ...lib.version import Version
+ from .schema import odps_schema_to_arrow_schema
+
+ PartitionsType = Union[List[str], str, None]
+
+ _DEFAULT_ROW_BATCH_SIZE = 4096
+ _need_convert_timezone = Version(pyodps_version) < Version("0.11.7")
+
+
+ @contextmanager
+ def _sync_pyodps_timezone():
+     with pyodps_option_context() as cfg:
+         cfg.local_timezone = options.local_timezone
+         yield
+
+
+ class ODPSTableIO(ABC):
+     def __new__(cls, odps: ODPS):
+         if cls is ODPSTableIO:
+             if options.use_common_table:
+                 return HaloTableIO(odps)
+             else:
+                 return TunnelTableIO(odps)
+         return super().__new__(cls)
+
+     def __init__(self, odps: ODPS):
+         self._odps = odps
+
+     @classmethod
+     def _get_reader_schema(
+         cls,
+         table_schema: OdpsSchema,
+         columns: Optional[List[str]] = None,
+         partition_columns: Union[None, bool, List[str]] = None,
+     ) -> OdpsSchema:
+         final_cols = []
+
+         columns = columns or [col.name for col in table_schema.simple_columns]
+         if partition_columns is True:
+             partition_columns = [c.name for c in table_schema.partitions]
+         else:
+             partition_columns = partition_columns or []
+
+         for col_name in columns + partition_columns:
+             final_cols.append(table_schema[col_name])
+         return OdpsSchema(final_cols)
+
+     @abstractmethod
+     def open_reader(
+         self,
+         full_table_name: str,
+         partitions: PartitionsType = None,
+         columns: Optional[List[str]] = None,
+         partition_columns: Union[None, bool, List[str]] = None,
+         start: Optional[int] = None,
+         stop: Optional[int] = None,
+         reverse_range: bool = False,
+         row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
+     ):
+         raise NotImplementedError
+
+     @abstractmethod
+     def open_writer(
+         self,
+         full_table_name: str,
+         partition: Optional[str] = None,
+         overwrite: bool = True,
+     ):
+         raise NotImplementedError
+
+
+ class TunnelMultiPartitionReader:
+     def __init__(
+         self,
+         odps_entry: ODPS,
+         table_name: str,
+         partitions: PartitionsType,
+         columns: Optional[List[str]] = None,
+         partition_columns: Optional[List[str]] = None,
+         start: Optional[int] = None,
+         count: Optional[int] = None,
+         partition_to_download_ids: Dict[str, str] = None,
+     ):
+         self._odps_entry = odps_entry
+         self._table = odps_entry.get_table(table_name)
+         self._columns = columns
+
+         odps_schema = ODPSTableIO._get_reader_schema(
+             self._table.table_schema, columns, partition_columns
+         )
+         self._schema = odps_schema_to_arrow_schema(odps_schema)
+
+         self._start = start or 0
+         self._count = count
+         self._row_left = count
+
+         self._cur_reader = None
+         self._reader_iter = None
+         self._cur_partition_id = -1
+         self._reader_start_pos = 0
+
+         if partitions is None or isinstance(partitions, str):
+             self._partitions = [partitions]
+         else:
+             self._partitions = partitions
+
+         self._partition_cols = partition_columns
+         self._partition_to_download_ids = partition_to_download_ids or dict()
+
+     @property
+     def count(self) -> Optional[int]:
+         if len(self._partitions) > 1:
+             return None
+         return self._count
+
+     def _open_next_reader(self):
+         if self._cur_reader is not None:
+             self._reader_start_pos += self._cur_reader.count
+
+         if (
+             self._row_left is not None and self._row_left <= 0
+         ) or 1 + self._cur_partition_id >= len(self._partitions):
+             self._cur_reader = None
+             return
+
+         while 1 + self._cur_partition_id < len(self._partitions):
+             self._cur_partition_id += 1
+
+             part_str = self._partitions[self._cur_partition_id]
+             with _sync_pyodps_timezone():
+                 self._cur_reader = self._table.open_reader(
+                     part_str,
+                     columns=self._columns,
+                     arrow=True,
+                     download_id=self._partition_to_download_ids.get(part_str),
+                 )
+             if self._cur_reader.count + self._reader_start_pos > self._start:
+                 start = self._start - self._reader_start_pos
+                 if self._count is None:
+                     count = None
+                 else:
+                     count = min(self._count, self._cur_reader.count - start)
+
+                 with _sync_pyodps_timezone():
+                     self._reader_iter = self._cur_reader.read(start, count)
+                 break
+             self._reader_start_pos += self._cur_reader.count
+         else:
+             self._cur_reader = None
+
+     def _fill_batch_partition(self, batch: pa.RecordBatch) -> pa.RecordBatch:
+         pt_spec = PartitionSpec(self._partitions[self._cur_partition_id])
+
+         names = list(batch.schema.names)
+         arrays = []
+         for idx in range(batch.num_columns):
+             col = batch.column(idx)
+             if _need_convert_timezone and isinstance(col.type, pa.TimestampType):
+                 if col.type.tz is not None:
+                     target_type = pa.timestamp(
+                         self._schema.types[idx].unit, col.type.tz
+                     )
+                     arrays.append(col.cast(target_type))
+                 else:
+                     target_type = pa.timestamp(
+                         self._schema.types[idx].unit, options.local_timezone
+                     )
+                     pd_col = col.to_pandas().dt.tz_localize(options.local_timezone)
+                     arrays.append(pa.Array.from_pandas(pd_col).cast(target_type))
+             else:
+                 arrays.append(batch.column(idx))
+
+         for part_col in self._partition_cols or []:
+             names.append(part_col)
+             col_type = self._schema.field_by_name(part_col).type
+             arrays.append(pa.array([pt_spec[part_col]] * batch.num_rows).cast(col_type))
+         return pa.RecordBatch.from_arrays(arrays, names)
+
+     def read(self):
+         with _sync_pyodps_timezone():
+             if self._cur_reader is None:
+                 self._open_next_reader()
+                 if self._cur_reader is None:
+                     return None
+             while self._cur_reader is not None:
+                 try:
+                     batch = next(self._reader_iter)
+                     if batch is not None:
+                         if self._row_left is not None:
+                             self._row_left -= batch.num_rows
+                         return self._fill_batch_partition(batch)
+                 except StopIteration:
+                     self._open_next_reader()
+             return None
+
+     def read_all(self) -> pa.Table:
+         batches = []
+         while True:
+             batch = self.read()
+             if batch is None:
+                 break
+             batches.append(batch)
+         if not batches:
+             return self._schema.empty_table()
+         return pa.Table.from_batches(batches)
+
+
+ class TunnelWrappedWriter:
+     def __init__(self, nested_writer):
+         self._writer = nested_writer
+
+     def write(self, data: Union[pa.RecordBatch, pa.Table]):
+         if not any(isinstance(tp, pa.TimestampType) for tp in data.schema.types):
+             self._writer.write(data)
+             return
+         pa_type = type(data)
+         arrays = []
+         for idx in range(data.num_columns):
+             name = data.schema.names[idx]
+             col = data.column(idx)
+             if not isinstance(col.type, pa.TimestampType):
+                 arrays.append(col)
+                 continue
+             if self._writer.schema[name].type == timestamp_ntz:
+                 col = HaloTableArrowWriter._localize_timezone(col, "UTC")
+             else:
+                 col = HaloTableArrowWriter._localize_timezone(col)
+             arrays.append(col)
+         data = pa_type.from_arrays(arrays, names=data.schema.names)
+         self._writer.write(data)
+
+     def __getattr__(self, item):
+         return getattr(self._writer, item)
+
+
+ class TunnelTableIO(ODPSTableIO):
+     @contextmanager
+     def open_reader(
+         self,
+         full_table_name: str,
+         partitions: PartitionsType = None,
+         columns: Optional[List[str]] = None,
+         partition_columns: Union[None, bool, List[str]] = None,
+         start: Optional[int] = None,
+         stop: Optional[int] = None,
+         reverse_range: bool = False,
+         row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
+     ):
+         table = self._odps.get_table(full_table_name)
+         if partition_columns is True:
+             partition_columns = [c.name for c in table.table_schema.partitions]
+
+         total_records = None
+         part_to_down_id = None
+         if (
+             (start is not None and start < 0)
+             or (stop is not None and stop < 0)
+             or (reverse_range and start is None)
+         ):
+             table = self._odps.get_table(full_table_name)
+             tunnel = TableTunnel(self._odps)
+             parts = (
+                 [partitions]
+                 if partitions is None or isinstance(partitions, str)
+                 else partitions
+             )
+             part_to_down_id = dict()
+             total_records = 0
+             for part in parts:
+                 down_session = tunnel.create_download_session(
+                     table, async_mode=True, partition_spec=part
+                 )
+                 part_to_down_id[part] = down_session.id
+                 total_records += down_session.count
+
+         count = None
+         if start is not None or stop is not None:
+             if reverse_range:
+                 start = start if start is not None else total_records - 1
+                 stop = stop if stop is not None else -1
+             else:
+                 start = start if start is not None else 0
+                 stop = stop if stop is not None else None
+             start = start if start >= 0 else total_records + start
+             stop = stop if stop is None or stop >= 0 else total_records + stop
+             if reverse_range:
+                 count = start - stop
+                 start = stop + 1
+             else:
+                 count = stop - start if stop is not None and start is not None else None
+
+         yield TunnelMultiPartitionReader(
+             self._odps,
+             full_table_name,
+             partitions=partitions,
+             columns=columns,
+             partition_columns=partition_columns,
+             start=start,
+             count=count,
+             partition_to_download_ids=part_to_down_id,
+         )
+
+     @contextmanager
+     def open_writer(
+         self,
+         full_table_name: str,
+         partition: Optional[str] = None,
+         overwrite: bool = True,
+     ):
+         table = self._odps.get_table(full_table_name)
+         with _sync_pyodps_timezone():
+             with table.open_writer(
+                 partition=partition,
+                 arrow=True,
+                 create_partition=partition is not None,
+                 overwrite=overwrite,
+             ) as writer:
+                 # fixme should yield writer directly once pyodps fixes
+                 # related arrow timestamp bug when provided schema and
+                 # table schema is identical.
+                 if _need_convert_timezone:
+                     yield TunnelWrappedWriter(writer)
+                 else:
+                     yield writer
+
+
+ class HaloTableArrowReader:
+     def __init__(
+         self,
+         client: StorageApiArrowClient,
+         scan_info: TableBatchScanResponse,
+         odps_schema: OdpsSchema,
+         start: Optional[int] = None,
+         count: Optional[int] = None,
+         row_batch_size: Optional[int] = None,
+     ):
+         self._client = client
+         self._scan_info = scan_info
+
+         self._cur_split_id = -1
+         self._cur_reader = None
+
+         self._odps_schema = odps_schema
+         self._arrow_schema = odps_schema_to_arrow_schema(odps_schema)
+
+         self._start = start
+         self._count = count
+         self._cursor = 0
+         self._row_batch_size = row_batch_size
+
+     @property
+     def count(self) -> int:
+         return self._count
+
+     def _open_next_reader(self):
+         from odps.apis.storage_api import ReadRowsRequest
+
+         if 0 <= self._scan_info.split_count <= self._cur_split_id + 1:
+             # scan by split
+             self._cur_reader = None
+             return
+         elif self._count is not None and self._cursor >= self._count:
+             # scan by range
+             self._cur_reader = None
+             return
+
+         read_rows_kw = {}
+         if self._start is not None:
+             read_rows_kw["row_index"] = self._start + self._cursor
+             read_rows_kw["row_count"] = min(
+                 self._row_batch_size, self._count - self._cursor
+             )
+             self._cursor = min(self._count, self._cursor + self._row_batch_size)
+
+         req = ReadRowsRequest(
+             session_id=self._scan_info.session_id,
+             split_index=self._cur_split_id + 1,
+             **read_rows_kw,
+         )
+         self._cur_reader = self._client.read_rows_arrow(req)
+         self._cur_split_id += 1
+
+     def _convert_timezone(self, batch: pa.RecordBatch) -> pa.RecordBatch:
+         timezone = options.local_timezone
+         if not any(isinstance(tp, pa.TimestampType) for tp in batch.schema.types):
+             return batch
+
+         cols = []
+         for idx in range(batch.num_columns):
+             col = batch.column(idx)
+             name = batch.schema.names[idx]
+             if not isinstance(col.type, pa.TimestampType):
+                 cols.append(col)
+                 continue
+             if self._odps_schema[name].type == timestamp_ntz:
+                 col = col.cast(pa.timestamp(col.type.unit))
+                 cols.append(col)
+                 continue
+
+             if hasattr(pac, "local_timestamp"):
+                 col = col.cast(pa.timestamp(col.type.unit, timezone))
+             else:
+                 pd_col = col.to_pandas().dt.tz_convert(timezone)
+                 col = pa.Array.from_pandas(pd_col).cast(
+                     pa.timestamp(col.type.unit, timezone)
+                 )
+             cols.append(col)
+
+         return pa.RecordBatch.from_arrays(cols, names=batch.schema.names)
+
+     def read(self):
+         if self._cur_reader is None:
+             self._open_next_reader()
+             if self._cur_reader is None:
+                 return None
+         while self._cur_reader is not None:
+             batch = self._cur_reader.read()
+             if batch is not None:
+                 return self._convert_timezone(batch)
+             self._open_next_reader()
+         return None
+
+     def read_all(self) -> pa.Table:
+         batches = []
+         while True:
+             batch = self.read()
+             if batch is None:
+                 break
+             batches.append(batch)
+         if not batches:
+             return self._arrow_schema.empty_table()
+         return pa.Table.from_batches(batches)
+
+
+ class HaloTableArrowWriter:
+     def __init__(
+         self,
+         client: StorageApiArrowClient,
+         write_info: TableBatchWriteResponse,
+         odps_schema: OdpsSchema,
+     ):
+         self._client = client
+         self._write_info = write_info
+         self._odps_schema = odps_schema
+         self._arrow_schema = odps_schema_to_arrow_schema(odps_schema)
+
+         self._writer = None
+
+     def open(self):
+         from odps.apis.storage_api import WriteRowsRequest
+
+         self._writer = self._client.write_rows_arrow(
+             WriteRowsRequest(self._write_info.session_id)
+         )
+
+     @classmethod
+     def _localize_timezone(cls, col, tz=None):
+         from odps.lib import tzlocal
+
+         if tz is None:
+             if options.local_timezone is None:
+                 tz = str(tzlocal.get_localzone())
+             else:
+                 tz = str(options.local_timezone)
+
+         if col.type.tz is not None:
+             return col
+         if hasattr(pac, "assume_timezone"):
+             col = pac.assume_timezone(col, tz)
+             return col
+         else:
+             col = col.to_pandas()
+             return pa.Array.from_pandas(col.dt.tz_localize(tz))
+
+     def _convert_schema(self, batch: pa.RecordBatch):
+         if batch.schema == self._arrow_schema and not any(
+             isinstance(tp, pa.TimestampType) for tp in self._arrow_schema.types
+         ):
+             return batch
+         cols = []
+         for idx in range(batch.num_columns):
+             col = batch.column(idx)
+             name = batch.schema.names[idx]
+
+             if isinstance(col.type, pa.TimestampType):
+                 if self._odps_schema[name].type == timestamp_ntz:
+                     col = self._localize_timezone(col, "UTC")
+                 else:
+                     col = self._localize_timezone(col)
+
+             if col.type != self._arrow_schema.types[idx]:
+                 col = col.cast(self._arrow_schema.types[idx])
+             cols.append(col)
+         return pa.RecordBatch.from_arrays(cols, names=batch.schema.names)
+
+     def write(self, batch):
+         if isinstance(batch, pa.Table):
+             for b in batch.to_batches():
+                 self._writer.write(self._convert_schema(b))
+         else:
+             self._writer.write(self._convert_schema(batch))
+
+     def close(self):
+         commit_msg, is_success = self._writer.finish()
+         if not is_success:
+             raise IOError(commit_msg)
+         return commit_msg
+
+
+ class HaloTableIO(ODPSTableIO):
+     _storage_api_endpoint = os.getenv(ODPS_STORAGE_API_ENDPOINT)
+
+     @staticmethod
+     def _convert_partitions(partitions: PartitionsType) -> Optional[List[str]]:
+         if partitions is None:
+             return []
+         elif isinstance(partitions, (str, PartitionSpec)):
+             partitions = [partitions]
+         return [
+             "/".join(f"{k}={v}" for k, v in PartitionSpec(pt).items())
+             for pt in partitions
+         ]
+
+     def get_table_record_count(
+         self, full_table_name: str, partitions: PartitionsType = None
+     ):
+         from odps.apis.storage_api import SplitOptions, TableBatchScanRequest
+
+         table = self._odps.get_table(full_table_name)
+         client = StorageApiArrowClient(
+             self._odps, table, rest_endpoint=self._storage_api_endpoint
+         )
+
+         split_option = SplitOptions.SplitMode.SIZE
+
+         scan_kw = {
+             "required_partitions": self._convert_partitions(partitions),
+             "split_options": SplitOptions.get_default_options(split_option),
+         }
+
+         # todo add more options for partition column handling
+         req = TableBatchScanRequest(**scan_kw)
+         resp = client.create_read_session(req)
+         return resp.record_count
+
+     @contextmanager
+     def open_reader(
+         self,
+         full_table_name: str,
+         partitions: PartitionsType = None,
+         columns: Optional[List[str]] = None,
+         partition_columns: Union[None, bool, List[str]] = None,
+         start: Optional[int] = None,
+         stop: Optional[int] = None,
+         reverse_range: bool = False,
+         row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
+     ):
+         from odps.apis.storage_api import (
+             SessionRequest,
+             SplitOptions,
+             Status,
+             TableBatchScanRequest,
+         )
+
+         table = self._odps.get_table(full_table_name)
+         client = StorageApiArrowClient(
+             self._odps, table, rest_endpoint=self._storage_api_endpoint
+         )
+
+         split_option = SplitOptions.SplitMode.SIZE
+         if start is not None or stop is not None:
+             split_option = SplitOptions.SplitMode.ROW_OFFSET
+
+         scan_kw = {
+             "required_partitions": self._convert_partitions(partitions),
+             "split_options": SplitOptions.get_default_options(split_option),
+         }
+         columns = columns or [c.name for c in table.table_schema.simple_columns]
+         scan_kw["required_data_columns"] = columns
+         if partition_columns is True:
+             scan_kw["required_partition_columns"] = [
+                 c.name for c in table.table_schema.partitions
+             ]
+         else:
+             scan_kw["required_partition_columns"] = partition_columns
+
+         # todo add more options for partition column handling
+         req = TableBatchScanRequest(**scan_kw)
+         resp = client.create_read_session(req)
+
+         session_id = resp.session_id
+         status = resp.status
+         while status == Status.WAIT:
+             resp = client.get_read_session(SessionRequest(session_id))
+             status = resp.status
+             time.sleep(1.0)
+
+         assert status == Status.OK
+
+         count = None
+         if start is not None or stop is not None:
+             if reverse_range:
+                 start = start if start is not None else resp.record_count - 1
+                 stop = stop if stop is not None else -1
+             else:
+                 start = start if start is not None else 0
+                 stop = stop if stop is not None else resp.record_count
+             start = start if start >= 0 else resp.record_count + start
+             stop = stop if stop >= 0 else resp.record_count + stop
+             if reverse_range:
+                 count = start - stop
+                 start = stop + 1
+             else:
+                 count = stop - start
+
+         reader_schema = self._get_reader_schema(
+             table.table_schema, columns, partition_columns
+         )
+         yield HaloTableArrowReader(
+             client,
+             resp,
+             odps_schema=reader_schema,
+             start=start,
+             count=count,
+             row_batch_size=row_batch_size,
+         )
+
+     @contextmanager
+     def open_writer(
+         self,
+         full_table_name: str,
+         partition: Optional[str] = None,
+         overwrite: bool = True,
+     ):
+         from odps.apis.storage_api import (
+             SessionRequest,
+             SessionStatus,
+             TableBatchWriteRequest,
+         )
+
+         table = self._odps.get_table(full_table_name)
+         client = StorageApiArrowClient(
+             self._odps, table, rest_endpoint=self._storage_api_endpoint
+         )
+
+         part_strs = self._convert_partitions(partition)
+         part_str = part_strs[0] if part_strs else None
+         req = TableBatchWriteRequest(partition_spec=part_str, overwrite=overwrite)
+         resp = client.create_write_session(req)
+
+         session_id = resp.session_id
+         writer = HaloTableArrowWriter(client, resp, table.table_schema)
+         writer.open()
+
+         yield writer
+
+         commit_msg = writer.close()
+         resp = client.commit_write_session(
+             SessionRequest(session_id=session_id), [commit_msg]
+         )
+         while resp.session_status == SessionStatus.COMMITTING:
+             resp = client.get_write_session(SessionRequest(session_id=session_id))
+         assert resp.session_status == SessionStatus.COMMITTED
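For orientation, the following is a minimal, hypothetical usage sketch of the ODPSTableIO facade introduced in the new maxframe/io/odpsio/tableio.py above. It is not taken from the package: the ODPS credentials, endpoint and table names are placeholders, and whether TunnelTableIO or HaloTableIO is selected at runtime depends on options.use_common_table.

import os

from odps import ODPS

# import path assumed from the new module location in this release
from maxframe.io.odpsio.tableio import ODPSTableIO

# placeholder account, endpoint and table names, not part of the diff
o = ODPS(
    os.getenv("ODPS_ACCESS_ID"),
    os.getenv("ODPS_SECRET_ACCESS_KEY"),
    project="my_project",
    endpoint="https://service.example.com/api",
)

table_io = ODPSTableIO(o)  # dispatches to TunnelTableIO or HaloTableIO

# read the first 1000 rows of a table as a pyarrow.Table
with table_io.open_reader("my_project.my_table", start=0, stop=1000) as reader:
    arrow_table = reader.read_all()

# write the rows into another table, overwriting its current contents
with table_io.open_writer("my_project.my_table_copy", overwrite=True) as writer:
    writer.write(arrow_table)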
@@ -0,0 +1,13 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.