maxframe-1.0.0rc4-cp37-cp37m-win32.whl → maxframe-1.1.0-cp37-cp37m-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of maxframe might be problematic.
Files changed (83)
  1. maxframe/_utils.cp37-win32.pyd +0 -0
  2. maxframe/config/config.py +3 -0
  3. maxframe/conftest.py +9 -2
  4. maxframe/core/graph/core.cp37-win32.pyd +0 -0
  5. maxframe/core/operator/base.py +2 -0
  6. maxframe/dataframe/arithmetic/tests/test_arithmetic.py +17 -16
  7. maxframe/dataframe/core.py +24 -2
  8. maxframe/dataframe/datasource/read_odps_query.py +63 -34
  9. maxframe/dataframe/datasource/tests/test_datasource.py +59 -7
  10. maxframe/dataframe/extensions/__init__.py +5 -0
  11. maxframe/dataframe/extensions/apply_chunk.py +649 -0
  12. maxframe/dataframe/extensions/flatjson.py +131 -0
  13. maxframe/dataframe/extensions/flatmap.py +28 -40
  14. maxframe/dataframe/extensions/reshuffle.py +1 -1
  15. maxframe/dataframe/extensions/tests/test_apply_chunk.py +186 -0
  16. maxframe/dataframe/extensions/tests/test_extensions.py +46 -2
  17. maxframe/dataframe/groupby/__init__.py +1 -0
  18. maxframe/dataframe/groupby/aggregation.py +1 -0
  19. maxframe/dataframe/groupby/apply.py +9 -1
  20. maxframe/dataframe/groupby/core.py +1 -1
  21. maxframe/dataframe/groupby/fill.py +4 -1
  22. maxframe/dataframe/groupby/getitem.py +6 -0
  23. maxframe/dataframe/groupby/tests/test_groupby.py +1 -1
  24. maxframe/dataframe/groupby/transform.py +8 -2
  25. maxframe/dataframe/indexing/loc.py +6 -4
  26. maxframe/dataframe/merge/__init__.py +9 -1
  27. maxframe/dataframe/merge/concat.py +41 -31
  28. maxframe/dataframe/merge/merge.py +1 -1
  29. maxframe/dataframe/merge/tests/test_merge.py +3 -1
  30. maxframe/dataframe/misc/apply.py +3 -0
  31. maxframe/dataframe/misc/drop_duplicates.py +5 -1
  32. maxframe/dataframe/misc/map.py +3 -1
  33. maxframe/dataframe/misc/tests/test_misc.py +24 -2
  34. maxframe/dataframe/misc/transform.py +22 -13
  35. maxframe/dataframe/reduction/__init__.py +3 -0
  36. maxframe/dataframe/reduction/aggregation.py +1 -0
  37. maxframe/dataframe/reduction/median.py +56 -0
  38. maxframe/dataframe/reduction/tests/test_reduction.py +17 -7
  39. maxframe/dataframe/statistics/quantile.py +8 -2
  40. maxframe/dataframe/statistics/tests/test_statistics.py +4 -4
  41. maxframe/dataframe/tests/test_utils.py +60 -0
  42. maxframe/dataframe/utils.py +110 -7
  43. maxframe/dataframe/window/expanding.py +5 -3
  44. maxframe/dataframe/window/tests/test_expanding.py +2 -2
  45. maxframe/io/objects/tests/test_object_io.py +39 -12
  46. maxframe/io/odpsio/arrow.py +30 -2
  47. maxframe/io/odpsio/schema.py +23 -5
  48. maxframe/io/odpsio/tableio.py +26 -110
  49. maxframe/io/odpsio/tests/test_schema.py +40 -0
  50. maxframe/io/odpsio/tests/test_tableio.py +5 -5
  51. maxframe/io/odpsio/tests/test_volumeio.py +35 -11
  52. maxframe/io/odpsio/volumeio.py +27 -3
  53. maxframe/learn/contrib/__init__.py +3 -2
  54. maxframe/learn/contrib/llm/__init__.py +16 -0
  55. maxframe/learn/contrib/llm/core.py +54 -0
  56. maxframe/learn/contrib/llm/models/__init__.py +14 -0
  57. maxframe/learn/contrib/llm/models/dashscope.py +73 -0
  58. maxframe/learn/contrib/llm/multi_modal.py +42 -0
  59. maxframe/learn/contrib/llm/text.py +42 -0
  60. maxframe/lib/mmh3.cp37-win32.pyd +0 -0
  61. maxframe/lib/sparse/tests/test_sparse.py +15 -15
  62. maxframe/opcodes.py +7 -1
  63. maxframe/serialization/core.cp37-win32.pyd +0 -0
  64. maxframe/serialization/core.pyx +13 -1
  65. maxframe/serialization/pandas.py +50 -20
  66. maxframe/serialization/serializables/core.py +24 -5
  67. maxframe/serialization/serializables/field_type.py +4 -1
  68. maxframe/serialization/serializables/tests/test_serializable.py +8 -1
  69. maxframe/serialization/tests/test_serial.py +2 -1
  70. maxframe/tensor/__init__.py +19 -7
  71. maxframe/tests/utils.py +16 -0
  72. maxframe/udf.py +27 -0
  73. maxframe/utils.py +36 -8
  74. {maxframe-1.0.0rc4.dist-info → maxframe-1.1.0.dist-info}/METADATA +2 -2
  75. {maxframe-1.0.0rc4.dist-info → maxframe-1.1.0.dist-info}/RECORD +83 -72
  76. maxframe_client/clients/framedriver.py +4 -1
  77. maxframe_client/fetcher.py +18 -2
  78. maxframe_client/session/odps.py +23 -10
  79. maxframe_client/session/task.py +2 -24
  80. maxframe_client/session/tests/test_task.py +0 -4
  81. maxframe_client/tests/test_session.py +30 -10
  82. {maxframe-1.0.0rc4.dist-info → maxframe-1.1.0.dist-info}/WHEEL +0 -0
  83. {maxframe-1.0.0rc4.dist-info → maxframe-1.1.0.dist-info}/top_level.txt +0 -0
maxframe/io/odpsio/tableio.py

@@ -18,10 +18,8 @@ from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Union
 
-import numpy as np
 import pyarrow as pa
 from odps import ODPS
-from odps import __version__ as pyodps_version
 from odps.apis.storage_api import (
     StorageApiArrowClient,
     TableBatchScanResponse,
@@ -29,6 +27,7 @@ from odps.apis.storage_api import (
 )
 from odps.tunnel import TableTunnel
 from odps.types import OdpsSchema, PartitionSpec, timestamp_ntz
+from odps.utils import call_with_retry
 
 try:
     import pyarrow.compute as pac
@@ -37,20 +36,18 @@ except ImportError:
 
 from ...config import options
 from ...env import ODPS_STORAGE_API_ENDPOINT
-from ...lib.version import Version
 from ...utils import sync_pyodps_options
 from .schema import odps_schema_to_arrow_schema
 
 PartitionsType = Union[List[str], str, None]
 
 _DEFAULT_ROW_BATCH_SIZE = 4096
-_need_patch_batch = Version(pyodps_version) < Version("0.12.0")
 
 
 class ODPSTableIO(ABC):
     def __new__(cls, odps: ODPS):
         if cls is ODPSTableIO:
-            if options.use_common_table:
+            if options.use_common_table or ODPS_STORAGE_API_ENDPOINT in os.environ:
                 return HaloTableIO(odps)
             else:
                 return TunnelTableIO(odps)
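
Note: with this change, ODPSTableIO(...) dispatches to the storage-API backed HaloTableIO when either options.use_common_table is set or the storage-API endpoint environment variable (named by the ODPS_STORAGE_API_ENDPOINT constant imported above) is exported; otherwise it falls back to TunnelTableIO. A minimal usage sketch with a placeholder endpoint value:

import os

from odps import ODPS

from maxframe.config import options
from maxframe.env import ODPS_STORAGE_API_ENDPOINT  # constant holding the env variable name
from maxframe.io.odpsio.tableio import ODPSTableIO

o = ODPS.from_environments()  # credentials from environment, as in the tests below

# Default dispatch: TunnelTableIO unless one of the switches below is active.
table_io = ODPSTableIO(o)

# Either switch routes subsequent ODPSTableIO(...) calls to HaloTableIO:
options.use_common_table = True
os.environ[ODPS_STORAGE_API_ENDPOINT] = "http://example-storage-api-endpoint"  # placeholder
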
@@ -132,7 +129,12 @@ class TunnelMultiPartitionReader:
         self._cur_partition_id = -1
         self._reader_start_pos = 0
 
-        if partitions is None or isinstance(partitions, str):
+        if partitions is None:
+            if not self._table.table_schema.partitions:
+                self._partitions = [None]
+            else:
+                self._partitions = [str(pt) for pt in self._table.partitions]
+        elif isinstance(partitions, str):
             self._partitions = [partitions]
         else:
             self._partitions = partitions
@@ -160,17 +162,14 @@ class TunnelMultiPartitionReader:
             self._cur_partition_id += 1
 
             part_str = self._partitions[self._cur_partition_id]
-
-            # todo make this more formal when PyODPS 0.12.0 is released
-            req_columns = self._columns
-            if not _need_patch_batch:
-                req_columns = self._schema.names
+            req_columns = self._schema.names
             with sync_pyodps_options():
                 self._cur_reader = self._table.open_reader(
                     part_str,
                     columns=req_columns,
                     arrow=True,
                     download_id=self._partition_to_download_ids.get(part_str),
+                    append_partitions=True,
                 )
                 if self._cur_reader.count + self._reader_start_pos > self._start:
                     start = self._start - self._reader_start_pos
@@ -186,35 +185,6 @@ class TunnelMultiPartitionReader:
         else:
             self._cur_reader = None
 
-    def _fill_batch_partition(self, batch: pa.RecordBatch) -> pa.RecordBatch:
-        pt_spec = PartitionSpec(self._partitions[self._cur_partition_id])
-
-        names = list(batch.schema.names)
-        arrays = []
-        for idx in range(batch.num_columns):
-            col = batch.column(idx)
-            if isinstance(col.type, pa.TimestampType):
-                if col.type.tz is not None:
-                    target_type = pa.timestamp(
-                        self._schema.types[idx].unit, col.type.tz
-                    )
-                    arrays.append(col.cast(target_type))
-                else:
-                    target_type = pa.timestamp(
-                        self._schema.types[idx].unit, options.local_timezone
-                    )
-                    pd_col = col.to_pandas().dt.tz_localize(options.local_timezone)
-                    arrays.append(pa.Array.from_pandas(pd_col).cast(target_type))
-            else:
-                arrays.append(batch.column(idx))
-
-        for part_col in self._partition_cols or []:
-            names.append(part_col)
-            col_type = self._schema.field_by_name(part_col).type
-            pt_col = np.repeat([pt_spec[part_col]], batch.num_rows)
-            arrays.append(pa.array(pt_col).cast(col_type))
-        return pa.RecordBatch.from_arrays(arrays, names)
-
     def read(self):
         with sync_pyodps_options():
             if self._cur_reader is None:
@@ -227,10 +197,7 @@ class TunnelMultiPartitionReader:
                 if batch is not None:
                     if self._row_left is not None:
                         self._row_left -= batch.num_rows
-                    if _need_patch_batch:
-                        return self._fill_batch_partition(batch)
-                    else:
-                        return batch
+                    return batch
             except StopIteration:
                 self._open_next_reader()
                 return None
@@ -247,34 +214,6 @@ class TunnelMultiPartitionReader:
         return pa.Table.from_batches(batches)
 
 
-class TunnelWrappedWriter:
-    def __init__(self, nested_writer):
-        self._writer = nested_writer
-
-    def write(self, data: Union[pa.RecordBatch, pa.Table]):
-        if not any(isinstance(tp, pa.TimestampType) for tp in data.schema.types):
-            self._writer.write(data)
-            return
-        pa_type = type(data)
-        arrays = []
-        for idx in range(data.num_columns):
-            name = data.schema.names[idx]
-            col = data.column(idx)
-            if not isinstance(col.type, pa.TimestampType):
-                arrays.append(col)
-                continue
-            if self._writer.schema[name].type == timestamp_ntz:
-                col = HaloTableArrowWriter._localize_timezone(col, "UTC")
-            else:
-                col = HaloTableArrowWriter._localize_timezone(col)
-            arrays.append(col)
-        data = pa_type.from_arrays(arrays, names=data.schema.names)
-        self._writer.write(data)
-
-    def __getattr__(self, item):
-        return getattr(self._writer, item)
-
-
 class TunnelTableIO(ODPSTableIO):
     @contextmanager
     def open_reader(
@@ -360,13 +299,7 @@ class TunnelTableIO(ODPSTableIO):
             create_partition=partition is not None,
             overwrite=overwrite,
         ) as writer:
-            # fixme should yield writer directly once pyodps fixes
-            # related arrow timestamp bug when provided schema and
-            # table schema is identical.
-            if _need_patch_batch:
-                yield TunnelWrappedWriter(writer)
-            else:
-                yield writer
+            yield writer
 
 
 class HaloTableArrowReader:
@@ -422,7 +355,7 @@ class HaloTableArrowReader:
             split_index=self._cur_split_id + 1,
             **read_rows_kw,
         )
-        self._cur_reader = self._client.read_rows_arrow(req)
+        self._cur_reader = call_with_retry(self._client.read_rows_arrow, req)
        self._cur_split_id += 1
 
    def _convert_timezone(self, batch: pa.RecordBatch) -> pa.RecordBatch:
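
Note: the storage-API calls in this file are now routed through call_with_retry from odps.utils. The sketch below only illustrates the retry pattern; the actual PyODPS helper has its own signature, retry policy and exception filtering:

import time


def call_with_retry_sketch(func, *args, retry_times=3, delay=1.0, **kwargs):
    # Simplified stand-in for odps.utils.call_with_retry: invoke the callable
    # and retry a bounded number of times on failure before re-raising.
    for trial in range(retry_times + 1):
        try:
            return func(*args, **kwargs)
        except Exception:
            if trial >= retry_times:
                raise
            time.sleep(delay)
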
@@ -494,8 +427,9 @@ class HaloTableArrowWriter:
     def open(self):
         from odps.apis.storage_api import WriteRowsRequest
 
-        self._writer = self._client.write_rows_arrow(
-            WriteRowsRequest(self._write_info.session_id)
+        self._writer = call_with_retry(
+            self._client.write_rows_arrow,
+            WriteRowsRequest(self._write_info.session_id),
         )
 
     @classmethod
@@ -566,28 +500,6 @@ class HaloTableIO(ODPSTableIO):
             for pt in partitions
         ]
 
-    def get_table_record_count(
-        self, full_table_name: str, partitions: PartitionsType = None
-    ):
-        from odps.apis.storage_api import SplitOptions, TableBatchScanRequest
-
-        table = self._odps.get_table(full_table_name)
-        client = StorageApiArrowClient(
-            self._odps, table, rest_endpoint=self._storage_api_endpoint
-        )
-
-        split_option = SplitOptions.SplitMode.SIZE
-
-        scan_kw = {
-            "required_partitions": self._convert_partitions(partitions),
-            "split_options": SplitOptions.get_default_options(split_option),
-        }
-
-        # todo add more options for partition column handling
-        req = TableBatchScanRequest(**scan_kw)
-        resp = client.create_read_session(req)
-        return resp.record_count
-
     @contextmanager
     def open_reader(
         self,
@@ -631,12 +543,12 @@ class HaloTableIO(ODPSTableIO):
 
         # todo add more options for partition column handling
         req = TableBatchScanRequest(**scan_kw)
-        resp = client.create_read_session(req)
+        resp = call_with_retry(client.create_read_session, req)
 
        session_id = resp.session_id
        status = resp.session_status
        while status == SessionStatus.INIT:
-           resp = client.get_read_session(SessionRequest(session_id))
+           resp = call_with_retry(client.get_read_session, SessionRequest(session_id))
            status = resp.session_status
            time.sleep(1.0)
 
@@ -691,7 +603,7 @@ class HaloTableIO(ODPSTableIO):
        part_strs = self._convert_partitions(partition)
        part_str = part_strs[0] if part_strs else None
        req = TableBatchWriteRequest(partition_spec=part_str, overwrite=overwrite)
-       resp = client.create_write_session(req)
+       resp = call_with_retry(client.create_write_session, req)
 
        session_id = resp.session_id
        writer = HaloTableArrowWriter(client, resp, table.table_schema)
@@ -700,9 +612,13 @@ class HaloTableIO(ODPSTableIO):
        yield writer
 
        commit_msg = writer.close()
-       resp = client.commit_write_session(
-           SessionRequest(session_id=session_id), [commit_msg]
+       resp = call_with_retry(
+           client.commit_write_session,
+           SessionRequest(session_id=session_id),
+           [commit_msg],
        )
        while resp.session_status == SessionStatus.COMMITTING:
-           resp = client.get_write_session(SessionRequest(session_id=session_id))
+           resp = call_with_retry(
+               client.get_write_session, SessionRequest(session_id=session_id)
+           )
        assert resp.session_status == SessionStatus.COMMITTED
maxframe/io/odpsio/tests/test_schema.py

@@ -21,6 +21,7 @@ from odps import types as odps_types
 from .... import dataframe as md
 from .... import tensor as mt
 from ....core import OutputType
+from ....utils import pd_release_version
 from ..schema import (
     arrow_schema_to_odps_schema,
     build_dataframe_table_meta,
@@ -292,3 +293,42 @@ def test_build_table_meta(wrap_obj):
     table_meta = build_dataframe_table_meta(test_df)
     expected_cols = ["a_2", "a_3", "a_0", "a_1_0", "a_1_1", "b", "c"]
     assert table_meta.table_column_names == expected_cols
+
+
+@pytest.mark.skipif(
+    pd_release_version[0] < 2, reason="only run under pandas 2.0 or greater"
+)
+def test_table_meta_with_datetime():
+    raw_df = pd.DataFrame(
+        [
+            [1, "abc", "2024-10-01 11:23:12"],
+            [3, "uvw", "2024-10-02 22:55:13"],
+        ],
+        columns=["col1", "col2", "col3"],
+    )
+    df = md.DataFrame(raw_df).astype({"col3": "datetime64[ms]"})
+    schema, _ = pandas_to_odps_schema(df, unknown_as_string=True)
+    assert schema.columns[3].type == odps_types.datetime
+
+    raw_series = pd.Series(
+        ["2024-10-01 11:23:12", "2024-10-02 22:55:13"], dtype="datetime64[ms]"
+    )
+    s = md.Series(raw_series)
+    schema, _ = pandas_to_odps_schema(s, unknown_as_string=True)
+    assert schema.columns[1].type == odps_types.datetime
+
+    raw_index = pd.Index(
+        ["2024-10-01 11:23:12", "2024-10-02 22:55:13"], dtype="datetime64[ms]"
+    )
+    idx = md.Index(raw_index)
+    schema, _ = pandas_to_odps_schema(idx, unknown_as_string=True)
+    assert schema.columns[0].type == odps_types.datetime
+
+    src_df = pd.DataFrame(
+        [[1, "2024-10-01 11:23:12"], [3, "2024-10-02 22:55:13"]],
+        columns=["A", "B"],
+    ).astype({"B": "datetime64[ms]"})
+    raw_multiindex = pd.MultiIndex.from_frame(src_df)
+    multiidx = md.Index(raw_multiindex)
+    schema, _ = pandas_to_odps_schema(multiidx, unknown_as_string=True)
+    assert schema.columns[1].type == odps_types.datetime
maxframe/io/odpsio/tests/test_tableio.py

@@ -31,7 +31,7 @@ def switch_table_io(request):
     old_use_common_table = options.use_common_table
     try:
         options.use_common_table = request.param
-        yield
+        yield request.param
     finally:
         options.use_common_table = old_use_common_table
 
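Note: the fixture now yields its parameter so the tests below can suffix table names with the active backend flag, keeping the two parametrizations from colliding on the same table. A sketch of the pattern; the params list on the decorator is assumed, since the decorator itself is outside this hunk:

import pytest

from maxframe.config import options


@pytest.fixture(params=[False, True])  # assumed: run once per table-IO backend
def switch_table_io(request):
    old_use_common_table = options.use_common_table
    try:
        options.use_common_table = request.param
        yield request.param
    finally:
        options.use_common_table = old_use_common_table


def test_example(switch_table_io):
    # The yielded flag keeps table names unique across parametrizations,
    # e.g. "test_empty_table_halo_read_true" vs. "..._false".
    table_name = "test_example_table_" + str(switch_table_io).lower()
    assert table_name.endswith(("true", "false"))
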
@@ -45,7 +45,7 @@ def test_empty_table_io(switch_table_io):
     table_io = ODPSTableIO(o)
 
     # test read from empty table
-    empty_table_name = tn("test_empty_table_halo_read")
+    empty_table_name = tn("test_empty_table_halo_read_" + str(switch_table_io).lower())
     o.delete_table(empty_table_name, if_exists=True)
     tb = o.create_table(empty_table_name, "col1 string", lifecycle=1)
@@ -65,7 +65,7 @@ def test_table_io_without_parts(switch_table_io):
     table_io = ODPSTableIO(o)
 
     # test read and write tables without partition
-    no_part_table_name = tn("test_no_part_halo_write")
+    no_part_table_name = tn("test_no_part_halo_write_" + str(switch_table_io).lower())
     o.delete_table(no_part_table_name, if_exists=True)
     col_desc = ",".join(f"{c} double" for c in "abcde") + ", f datetime"
     tb = o.create_table(no_part_table_name, col_desc, lifecycle=1)
@@ -99,7 +99,7 @@ def test_table_io_with_range_reader(switch_table_io):
     table_io = ODPSTableIO(o)
 
     # test read and write tables without partition
-    no_part_table_name = tn("test_no_part_halo_write")
+    no_part_table_name = tn("test_halo_write_range_" + str(switch_table_io).lower())
     o.delete_table(no_part_table_name, if_exists=True)
     tb = o.create_table(
         no_part_table_name, ",".join(f"{c} double" for c in "abcde"), lifecycle=1
@@ -139,7 +139,7 @@ def test_table_io_with_parts(switch_table_io):
     table_io = ODPSTableIO(o)
 
     # test read and write tables with partition
-    parted_table_name = tn("test_parted_halo_write")
+    parted_table_name = tn("test_parted_halo_write_" + str(switch_table_io).lower())
     o.delete_table(parted_table_name, if_exists=True)
     tb = o.create_table(
         parted_table_name,
maxframe/io/odpsio/tests/test_volumeio.py

@@ -42,15 +42,33 @@ def create_volume(request, oss_config):
         oss_bucket_name,
         oss_endpoint,
     ) = oss_config.oss_config
-    test_location = "oss://%s:%s@%s/%s/%s" % (
-        oss_access_id,
-        oss_secret_access_key,
-        oss_endpoint,
-        oss_bucket_name,
-        oss_test_dir_name,
-    )
+
+    if "test" in oss_endpoint:
+        # offline config
+        test_location = "oss://%s:%s@%s/%s/%s" % (
+            oss_access_id,
+            oss_secret_access_key,
+            oss_endpoint,
+            oss_bucket_name,
+            oss_test_dir_name,
+        )
+        rolearn = None
+    else:
+        # online config
+        endpoint_parts = oss_endpoint.split(".", 1)
+        if "-internal" not in endpoint_parts[0]:
+            endpoint_parts[0] += "-internal"
+        test_location = "oss://%s/%s/%s" % (
+            ".".join(endpoint_parts),
+            oss_bucket_name,
+            oss_test_dir_name,
+        )
+        rolearn = oss_config.oss_rolearn
+
     oss_config.oss_bucket.put_object(oss_test_dir_name + "/", b"")
-    odps_entry.create_external_volume(test_vol_name, location=test_location)
+    odps_entry.create_external_volume(
+        test_vol_name, location=test_location, rolearn=rolearn
+    )
     try:
         yield test_vol_name
     finally:
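
Note: the online branch above rewrites the OSS endpoint to its internal form before building the external volume location. The helper below only restates that transformation; the endpoint values are examples:

def to_internal_oss_endpoint(endpoint: str) -> str:
    # Restates the rewrite in the fixture above: append "-internal" to the
    # first host label unless it is already there.
    parts = endpoint.split(".", 1)
    if "-internal" not in parts[0]:
        parts[0] += "-internal"
    return ".".join(parts)


# Example values only:
assert (
    to_internal_oss_endpoint("oss-cn-hangzhou.aliyuncs.com")
    == "oss-cn-hangzhou-internal.aliyuncs.com"
)
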
@@ -75,13 +93,19 @@ def test_read_write_volume(create_volume):
 
     odps_entry = ODPS.from_environments()
 
-    writer = ODPSVolumeWriter(odps_entry, create_volume, test_vol_dir)
+    writer = ODPSVolumeWriter(
+        odps_entry, create_volume, test_vol_dir, replace_internal_host=True
+    )
 
-    writer = ODPSVolumeWriter(odps_entry, create_volume, test_vol_dir)
+    writer = ODPSVolumeWriter(
+        odps_entry, create_volume, test_vol_dir, replace_internal_host=True
+    )
     writer.write_file("file1", b"content1")
     writer.write_file("file2", b"content2")
 
-    reader = ODPSVolumeReader(odps_entry, create_volume, test_vol_dir)
+    reader = ODPSVolumeReader(
+        odps_entry, create_volume, test_vol_dir, replace_internal_host=True
+    )
     assert reader.read_file("file1") == b"content1"
     assert reader.read_file("file2") == b"content2"
 
maxframe/io/odpsio/volumeio.py

@@ -16,13 +16,25 @@ import inspect
 from typing import Iterator, List, Optional, Union
 
 from odps import ODPS
+from odps import __version__ as pyodps_version
+
+from ...lib.version import Version
+
+_has_replace_internal_host = Version(pyodps_version) >= Version("0.12.0")
 
 
 class ODPSVolumeReader:
-    def __init__(self, odps_entry: ODPS, volume_name: str, volume_dir: str):
+    def __init__(
+        self,
+        odps_entry: ODPS,
+        volume_name: str,
+        volume_dir: str,
+        replace_internal_host: bool = False,
+    ):
         self._odps_entry = odps_entry
         self._volume = odps_entry.get_volume(volume_name)
         self._volume_dir = volume_dir
+        self._replace_internal_host = replace_internal_host
 
     def list_files(self) -> List[str]:
         def _get_file_name(vol_file):
@@ -38,7 +50,12 @@ class ODPSVolumeReader:
         ]
 
     def read_file(self, file_name: str) -> bytes:
-        with self._volume.open_reader(self._volume_dir + "/" + file_name) as reader:
+        kw = {}
+        if _has_replace_internal_host and self._replace_internal_host:
+            kw = {"replace_internal_host": self._replace_internal_host}
+        with self._volume.open_reader(
+            self._volume_dir + "/" + file_name, **kw
+        ) as reader:
             return reader.read()
 
 
@@ -49,13 +66,20 @@ class ODPSVolumeWriter:
         volume_name: str,
         volume_dir: str,
         schema_name: Optional[str] = None,
+        replace_internal_host: bool = False,
     ):
         self._odps_entry = odps_entry
         self._volume = odps_entry.get_volume(volume_name, schema=schema_name)
         self._volume_dir = volume_dir
+        self._replace_internal_host = replace_internal_host
 
     def write_file(self, file_name: str, data: Union[bytes, Iterator[bytes]]):
-        with self._volume.open_writer(self._volume_dir + "/" + file_name) as writer:
+        kw = {}
+        if _has_replace_internal_host and self._replace_internal_host:
+            kw = {"replace_internal_host": self._replace_internal_host}
+        with self._volume.open_writer(
+            self._volume_dir + "/" + file_name, **kw
+        ) as writer:
             if not inspect.isgenerator(data):
                 writer.write(data)
             else:
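
Note: both volume classes accept the new replace_internal_host flag, which is only forwarded to PyODPS when the installed PyODPS is 0.12.0 or newer. A usage sketch mirroring the test above, with hypothetical volume and directory names:

from odps import ODPS

from maxframe.io.odpsio.volumeio import ODPSVolumeReader, ODPSVolumeWriter

o = ODPS.from_environments()

# Hypothetical volume and directory names; replace_internal_host is ignored
# on PyODPS releases older than 0.12.0.
writer = ODPSVolumeWriter(o, "my_external_volume", "some_dir", replace_internal_host=True)
writer.write_file("file1", b"content1")

reader = ODPSVolumeReader(o, "my_external_volume", "some_dir", replace_internal_host=True)
assert reader.read_file("file1") == b"content1"
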
maxframe/learn/contrib/__init__.py

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import graph, pytorch
+from . import graph, llm, pytorch
 
-del pytorch
 del graph
+del llm
+del pytorch
maxframe/learn/contrib/llm/__init__.py (new file)

@@ -0,0 +1,16 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from . import models, multi_modal, text
+
+del models
maxframe/learn/contrib/llm/core.py (new file)

@@ -0,0 +1,54 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict
+
+import numpy as np
+import pandas as pd
+
+from ....core.entity.output_types import OutputType
+from ....core.operator.base import Operator
+from ....core.operator.core import TileableOperatorMixin
+from ....dataframe.utils import parse_index
+from ....serialization.serializables.core import Serializable
+from ....serialization.serializables.field import AnyField, DictField, StringField
+
+
+class LLM(Serializable):
+    name = StringField("name", default=None)
+
+    def validate_params(self, params: Dict[str, Any]):
+        pass
+
+
+class LLMOperator(Operator, TileableOperatorMixin):
+    model = AnyField("model", default=None)
+    prompt_template = AnyField("prompt_template", default=None)
+    params = DictField("params", default=None)
+
+    def __init__(self, output_types=None, **kw):
+        if output_types is None:
+            output_types = [OutputType.dataframe]
+        super().__init__(_output_types=output_types, **kw)
+
+    def __call__(self, data):
+        col_names = ["response", "success"]
+        columns = parse_index(pd.Index(col_names), store_data=True)
+        out_dtypes = pd.Series([np.dtype("O"), np.dtype("bool")], index=col_names)
+        return self.new_tileable(
+            inputs=[data],
+            dtypes=out_dtypes,
+            shape=(data.shape[0], len(col_names)),
+            index_value=data.index_value,
+            columns_value=columns,
+        )
maxframe/learn/contrib/llm/models/__init__.py (new file)

@@ -0,0 +1,14 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .dashscope import DashScopeMultiModalLLM, DashScopeTextLLM
maxframe/learn/contrib/llm/models/dashscope.py (new file)

@@ -0,0 +1,73 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict
+
+from ..... import opcodes
+from .....serialization.serializables.core import Serializable
+from .....serialization.serializables.field import StringField
+from ..core import LLMOperator
+from ..multi_modal import MultiModalLLM
+from ..text import TextLLM
+
+
+class DashScopeLLMMixin(Serializable):
+    __slots__ = ()
+
+    _not_supported_params = {"stream", "incremental_output"}
+
+    def validate_params(self, params: Dict[str, Any]):
+        for k in params.keys():
+            if k in self._not_supported_params:
+                raise ValueError(f"{k} is not supported")
+
+
+class DashScopeTextLLM(TextLLM, DashScopeLLMMixin):
+    api_key_resource = StringField("api_key_resource", default=None)
+
+    def generate(
+        self,
+        data,
+        prompt_template: Dict[str, Any],
+        params: Dict[str, Any] = None,
+    ):
+        return DashScopeTextGenerationOperator(
+            model=self,
+            prompt_template=prompt_template,
+            params=params,
+        )(data)
+
+
+class DashScopeMultiModalLLM(MultiModalLLM, DashScopeLLMMixin):
+    api_key_resource = StringField("api_key_resource", default=None)
+
+    def generate(
+        self,
+        data,
+        prompt_template: Dict[str, Any],
+        params: Dict[str, Any] = None,
+    ):
+        # TODO add precheck here
+        return DashScopeMultiModalGenerationOperator(
+            model=self,
+            prompt_template=prompt_template,
+            params=params,
+        )(data)
+
+
+class DashScopeTextGenerationOperator(LLMOperator):
+    _op_type_ = opcodes.DASHSCOPE_TEXT_GENERATION
+
+
+class DashScopeMultiModalGenerationOperator(LLMOperator):
+    _op_type_ = opcodes.DASHSCOPE_MULTI_MODAL_GENERATION
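
Note: a hypothetical end-to-end sketch of the new LLM extension. The model name, API-key resource, prompt template keys and input table are illustrative; only the generate() signature and the response/success output columns come from the code above:

import maxframe.dataframe as md
from maxframe.learn.contrib.llm.models import DashScopeTextLLM

# Illustrative values; the exact prompt_template structure expected by the
# operator is not shown in this diff.
llm = DashScopeTextLLM(name="qwen-plus", api_key_resource="my_api_key_resource")

df = md.read_odps_table("question_table")  # assumed source table with a "question" column
result = llm.generate(
    df,
    prompt_template={"user": "Answer the question: {question}"},
    params={"temperature": 0.7},
)
# result is a DataFrame with "response" and "success" columns, as declared
# by LLMOperator.__call__ in core.py above.
print(result.execute())
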