mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/utils/spu_utils.py
DELETED
|
@@ -1,130 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""SPU-related utility functions for protocol and field type conversion."""
|
|
16
|
-
|
|
17
|
-
import spu.libspu as libspu
|
|
18
|
-
|
|
19
|
-
# Global mappings for SPU protocol and field type conversion
|
|
20
|
-
SPU_PROTOCOL_MAPPING = {
|
|
21
|
-
"REF2K": libspu.ProtocolKind.REF2K,
|
|
22
|
-
"SEMI2K": libspu.ProtocolKind.SEMI2K,
|
|
23
|
-
"ABY3": libspu.ProtocolKind.ABY3,
|
|
24
|
-
"CHEETAH": libspu.ProtocolKind.CHEETAH,
|
|
25
|
-
"SECURENN": libspu.ProtocolKind.SECURENN,
|
|
26
|
-
}
|
|
27
|
-
|
|
28
|
-
SPU_FIELD_MAPPING = {
|
|
29
|
-
"FM32": libspu.FieldType.FM32,
|
|
30
|
-
"FM64": libspu.FieldType.FM64,
|
|
31
|
-
"FM128": libspu.FieldType.FM128,
|
|
32
|
-
}
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def parse_protocol(protocol: str | int) -> libspu.ProtocolKind:
|
|
36
|
-
"""Parse SPU protocol from string or integer to ProtocolKind enum.
|
|
37
|
-
|
|
38
|
-
Args:
|
|
39
|
-
protocol: Protocol specification as string (e.g., "SEMI2K") or integer.
|
|
40
|
-
|
|
41
|
-
Returns:
|
|
42
|
-
libspu.ProtocolKind: The corresponding protocol enum.
|
|
43
|
-
|
|
44
|
-
Raises:
|
|
45
|
-
ValueError: If the protocol is invalid.
|
|
46
|
-
|
|
47
|
-
Examples:
|
|
48
|
-
>>> parse_spu_protocol("SEMI2K")
|
|
49
|
-
ProtocolKind.SEMI2K
|
|
50
|
-
>>> parse_spu_protocol(2)
|
|
51
|
-
ProtocolKind.SEMI2K
|
|
52
|
-
"""
|
|
53
|
-
if isinstance(protocol, str):
|
|
54
|
-
if protocol not in SPU_PROTOCOL_MAPPING:
|
|
55
|
-
raise ValueError(
|
|
56
|
-
f"Invalid SPU protocol: {protocol}. "
|
|
57
|
-
f"Valid protocols are: {list(SPU_PROTOCOL_MAPPING.keys())}"
|
|
58
|
-
)
|
|
59
|
-
return SPU_PROTOCOL_MAPPING[protocol]
|
|
60
|
-
else:
|
|
61
|
-
# Assume it's an integer, validate it
|
|
62
|
-
try:
|
|
63
|
-
spu_protocol = libspu.ProtocolKind(protocol)
|
|
64
|
-
# Check if it's a valid enum value (not ???)
|
|
65
|
-
if spu_protocol.name == "???":
|
|
66
|
-
raise ValueError(f"Invalid SPU protocol value: {protocol}")
|
|
67
|
-
return spu_protocol
|
|
68
|
-
except TypeError as exc:
|
|
69
|
-
raise ValueError(
|
|
70
|
-
f"Invalid SPU protocol: {protocol}. "
|
|
71
|
-
f"Must be a valid protocol string or integer."
|
|
72
|
-
) from exc
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def parse_field(field: str | int) -> libspu.FieldType:
|
|
76
|
-
"""Parse SPU field type from string or integer to FieldType enum.
|
|
77
|
-
|
|
78
|
-
Args:
|
|
79
|
-
field: Field type specification as string (e.g., "FM64") or integer.
|
|
80
|
-
|
|
81
|
-
Returns:
|
|
82
|
-
libspu.FieldType: The corresponding field type enum.
|
|
83
|
-
|
|
84
|
-
Raises:
|
|
85
|
-
ValueError: If the field type is invalid.
|
|
86
|
-
|
|
87
|
-
Examples:
|
|
88
|
-
>>> parse_spu_field("FM64")
|
|
89
|
-
FieldType.FM64
|
|
90
|
-
>>> parse_spu_field(2)
|
|
91
|
-
FieldType.FM64
|
|
92
|
-
"""
|
|
93
|
-
if isinstance(field, str):
|
|
94
|
-
if field not in SPU_FIELD_MAPPING:
|
|
95
|
-
raise ValueError(
|
|
96
|
-
f"Invalid SPU field type: {field}. "
|
|
97
|
-
f"Valid field types are: {list(SPU_FIELD_MAPPING.keys())}"
|
|
98
|
-
)
|
|
99
|
-
return SPU_FIELD_MAPPING[field]
|
|
100
|
-
else:
|
|
101
|
-
# Assume it's an integer, validate it
|
|
102
|
-
try:
|
|
103
|
-
spu_field = libspu.FieldType(field)
|
|
104
|
-
# Check if it's a valid enum value
|
|
105
|
-
if spu_field.name == "???":
|
|
106
|
-
raise ValueError(f"Invalid SPU field type value: {field}")
|
|
107
|
-
return spu_field
|
|
108
|
-
except TypeError as exc:
|
|
109
|
-
raise ValueError(
|
|
110
|
-
f"Invalid SPU field type: {field}. "
|
|
111
|
-
f"Must be a valid field type string or integer."
|
|
112
|
-
) from exc
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
def list_protocols() -> list[str]:
|
|
116
|
-
"""Get list of valid SPU protocol names.
|
|
117
|
-
|
|
118
|
-
Returns:
|
|
119
|
-
List of valid protocol names as strings.
|
|
120
|
-
"""
|
|
121
|
-
return list(SPU_PROTOCOL_MAPPING.keys())
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
def list_fields() -> list[str]:
|
|
125
|
-
"""Get list of valid SPU field type names.
|
|
126
|
-
|
|
127
|
-
Returns:
|
|
128
|
-
List of valid field type names as strings.
|
|
129
|
-
"""
|
|
130
|
-
return list(SPU_FIELD_MAPPING.keys())
|
mplang/v1/utils/table_utils.py
DELETED
|
@@ -1,185 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import io
|
|
18
|
-
from typing import Any
|
|
19
|
-
|
|
20
|
-
import pyarrow as pa
|
|
21
|
-
import pyarrow.csv as pa_csv
|
|
22
|
-
import pyarrow.orc as pa_orc
|
|
23
|
-
import pyarrow.parquet as pa_pq
|
|
24
|
-
|
|
25
|
-
from mplang.v1.core.table import TableLike
|
|
26
|
-
|
|
27
|
-
__all__ = ["decode_table", "encode_table", "read_table", "write_table"]
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def _parse_kwargs(kwargs: dict[str, Any], keys: list[str]) -> dict[str, Any] | None:
|
|
31
|
-
if not kwargs:
|
|
32
|
-
return None
|
|
33
|
-
|
|
34
|
-
return {key: kwargs[key] for key in keys if key in kwargs}
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
_csv_read_option_keys = [
|
|
38
|
-
"skip_rows",
|
|
39
|
-
"skip_rows_after_names",
|
|
40
|
-
"column_names",
|
|
41
|
-
"autogenerate_column_names",
|
|
42
|
-
"encoding",
|
|
43
|
-
]
|
|
44
|
-
_csv_parse_option_keys = [
|
|
45
|
-
"delimiter",
|
|
46
|
-
"quote_char",
|
|
47
|
-
"double_quote",
|
|
48
|
-
"escape_char",
|
|
49
|
-
"newlines_in_values",
|
|
50
|
-
"ignore_empty_lines",
|
|
51
|
-
]
|
|
52
|
-
_csv_convert_option_keys = [
|
|
53
|
-
"check_utf8",
|
|
54
|
-
"column_types",
|
|
55
|
-
"null_values",
|
|
56
|
-
"true_values",
|
|
57
|
-
"false_values",
|
|
58
|
-
"decimal_point",
|
|
59
|
-
"strings_can_be_null",
|
|
60
|
-
"quoted_strings_can_be_null",
|
|
61
|
-
"include_columns",
|
|
62
|
-
"include_missing_columns",
|
|
63
|
-
"auto_dict_encode",
|
|
64
|
-
"auto_dict_max_cardinality",
|
|
65
|
-
"timestamp_parsers",
|
|
66
|
-
]
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
def read_table(
|
|
70
|
-
source: Any,
|
|
71
|
-
format: str = "parquet",
|
|
72
|
-
columns: list[str] | None = None,
|
|
73
|
-
**kwargs: Any,
|
|
74
|
-
) -> pa.Table:
|
|
75
|
-
"""Read data from a file and return a PyArrow table.
|
|
76
|
-
|
|
77
|
-
Args:
|
|
78
|
-
source: The source to read data from (file path, file-like object, etc.)
|
|
79
|
-
format: The format of the data source ("parquet", "csv", or "orc")
|
|
80
|
-
columns: List of column names to read (None means all columns)
|
|
81
|
-
**kwargs: Additional keyword arguments passed to the underlying reader
|
|
82
|
-
|
|
83
|
-
Returns:
|
|
84
|
-
A PyArrow Table containing the data from the source
|
|
85
|
-
|
|
86
|
-
Raises:
|
|
87
|
-
ValueError: If an unsupported format is specified
|
|
88
|
-
"""
|
|
89
|
-
match format:
|
|
90
|
-
case "csv":
|
|
91
|
-
if columns:
|
|
92
|
-
kwargs["include_columns"] = columns
|
|
93
|
-
read_args = _parse_kwargs(kwargs, _csv_read_option_keys)
|
|
94
|
-
parse_args = _parse_kwargs(kwargs, _csv_parse_option_keys)
|
|
95
|
-
convert_args = _parse_kwargs(kwargs, _csv_convert_option_keys)
|
|
96
|
-
|
|
97
|
-
read_opts = pa_csv.ReadOptions(**read_args) if read_args else None
|
|
98
|
-
parse_opts = pa_csv.ParseOptions(**parse_args) if parse_args else None
|
|
99
|
-
conv_opts = pa_csv.ConvertOptions(**convert_args) if convert_args else None
|
|
100
|
-
return pa_csv.read_csv(
|
|
101
|
-
source,
|
|
102
|
-
read_options=read_opts,
|
|
103
|
-
parse_options=parse_opts,
|
|
104
|
-
convert_options=conv_opts,
|
|
105
|
-
)
|
|
106
|
-
case "orc":
|
|
107
|
-
return pa_orc.read_table(source, columns=columns, **kwargs)
|
|
108
|
-
case "parquet":
|
|
109
|
-
return pa_pq.read_table(source, columns=columns, **kwargs)
|
|
110
|
-
case _:
|
|
111
|
-
raise ValueError(f"unsupported data format. {format}")
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
def write_table(
|
|
115
|
-
data: TableLike,
|
|
116
|
-
where: Any,
|
|
117
|
-
format: str = "parquet",
|
|
118
|
-
**kwargs: Any,
|
|
119
|
-
) -> None:
|
|
120
|
-
"""Write a table-like object to a file in the specified format.
|
|
121
|
-
|
|
122
|
-
Args:
|
|
123
|
-
data: The table-like object to write (PyArrow Table or other compatible format)
|
|
124
|
-
where: The destination to write to (file path, file-like object, etc.)
|
|
125
|
-
format: The format to write the data in ("parquet", "csv", or "orc")
|
|
126
|
-
**kwargs: Additional keyword arguments passed to the underlying writer
|
|
127
|
-
|
|
128
|
-
Returns:
|
|
129
|
-
None
|
|
130
|
-
|
|
131
|
-
Raises:
|
|
132
|
-
ValueError: If the table has no columns or an unsupported format is specified
|
|
133
|
-
"""
|
|
134
|
-
# Convert data to PyArrow Table if needed
|
|
135
|
-
table = data if isinstance(data, pa.Table) else pa.table(data)
|
|
136
|
-
if len(table.column_names) == 0:
|
|
137
|
-
raise ValueError("Cannot convert Table with no columns.")
|
|
138
|
-
|
|
139
|
-
match format:
|
|
140
|
-
case "csv":
|
|
141
|
-
options = pa_csv.WriteOptions(**kwargs) if kwargs else None
|
|
142
|
-
pa_csv.write_csv(table, where, write_options=options)
|
|
143
|
-
case "orc":
|
|
144
|
-
pa_orc.write_table(table, where, **kwargs)
|
|
145
|
-
case "parquet":
|
|
146
|
-
pa_pq.write_table(table, where, **kwargs)
|
|
147
|
-
case _:
|
|
148
|
-
raise ValueError(f"unsupported data format. {format}")
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
def decode_table(
|
|
152
|
-
data: bytes,
|
|
153
|
-
format: str = "parquet",
|
|
154
|
-
columns: list[str] | None = None,
|
|
155
|
-
**kwargs: Any,
|
|
156
|
-
) -> pa.Table:
|
|
157
|
-
"""Decode a bytes object into a PyArrow table.
|
|
158
|
-
|
|
159
|
-
Args:
|
|
160
|
-
data: The bytes object containing the encoded table data
|
|
161
|
-
format: The format of the encoded data ("parquet", "csv", or "orc")
|
|
162
|
-
columns: List of column names to decode (None means all columns)
|
|
163
|
-
**kwargs: Additional keyword arguments passed to the underlying reader
|
|
164
|
-
|
|
165
|
-
Returns:
|
|
166
|
-
A PyArrow Table decoded from the bytes data
|
|
167
|
-
"""
|
|
168
|
-
buffer = io.BytesIO(data)
|
|
169
|
-
return read_table(buffer, format=format, columns=columns, **kwargs)
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
def encode_table(data: TableLike, format: str = "parquet", **kwargs: Any) -> bytes:
|
|
173
|
-
"""Encode a table-like object into bytes.
|
|
174
|
-
|
|
175
|
-
Args:
|
|
176
|
-
data: The table-like object to encode (PyArrow Table or other compatible format)
|
|
177
|
-
format: The format to encode the data in ("parquet", "csv", or "orc")
|
|
178
|
-
**kwargs: Additional keyword arguments passed to the underlying writer
|
|
179
|
-
|
|
180
|
-
Returns:
|
|
181
|
-
Bytes object containing the encoded table data
|
|
182
|
-
"""
|
|
183
|
-
buffer = io.BytesIO()
|
|
184
|
-
write_table(data, buffer, format, **kwargs)
|
|
185
|
-
return buffer.getvalue()
|