mplang-nightly 0.1.dev164__py3-none-any.whl → 0.1.dev165__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,16 +14,14 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- from typing import Any
18
-
19
17
  from mplang.core.pfunc import PFunction
20
18
  from mplang.kernels.base import kernel_def
19
+ from mplang.kernels.value import TableValue
21
20
 
22
21
 
23
22
  @kernel_def("duckdb.run_sql")
24
- def _duckdb_sql(pfunc: PFunction, *args: Any) -> Any:
23
+ def _duckdb_sql(pfunc: PFunction, *args: TableValue) -> TableValue:
25
24
  import duckdb
26
- import pandas as pd
27
25
 
28
26
  # TODO: maybe we could translate the sql to duckdb dialect
29
27
  # instead of raising an exception
@@ -36,12 +34,9 @@ def _duckdb_sql(pfunc: PFunction, *args: Any) -> Any:
36
34
  if in_names is None:
37
35
  raise ValueError("duckdb sql missing in_names attr")
38
36
  for arg, name in zip(args, in_names, strict=True):
39
- if isinstance(arg, pd.DataFrame):
40
- df = arg
41
- elif isinstance(arg, list): # const list-of-dict for tests
42
- df = pd.DataFrame.from_records(arg)
43
- else:
44
- raise ValueError(f"unsupported duckdb input type {type(arg)}")
45
- conn.register(name, df)
46
- res_df = conn.execute(pfunc.fn_text).fetchdf()
47
- return res_df
37
+ # Use Arrow directly for zero-copy data transfer
38
+ arrow_table = arg.to_arrow()
39
+ conn.register(name, arrow_table)
40
+ # Fetch result as Arrow table for consistency
41
+ res_arrow = conn.execute(pfunc.fn_text).fetch_arrow_table()
42
+ return TableValue(res_arrow)
@@ -18,11 +18,13 @@ from typing import Any
18
18
 
19
19
  import jax
20
20
  import jax.numpy as jnp
21
+ import numpy as np
21
22
  from jax._src import xla_bridge
22
23
  from jax.lib import xla_client as xc
23
24
 
24
25
  from mplang.core.pfunc import PFunction
25
26
  from mplang.kernels.base import cur_kctx, kernel_def
27
+ from mplang.kernels.value import TensorValue
26
28
 
27
29
 
28
30
  @kernel_def("mlir.stablehlo")
@@ -61,13 +63,17 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
61
63
  # Filter out arguments that were eliminated by JAX during compilation
62
64
  runtime_args = tuple(args[i] for i in keep_indices)
63
65
 
64
- jax_args = []
65
- for arg in runtime_args:
66
- if hasattr(arg, "numpy"):
67
- jax_arg = jnp.array(arg.numpy()) # type: ignore
68
- else:
69
- jax_arg = jnp.array(arg)
70
- jax_args.append(jax.device_put(jax_arg))
66
+ tensor_args: list[TensorValue] = []
67
+ for idx, arg in enumerate(runtime_args):
68
+ if not isinstance(arg, TensorValue):
69
+ raise TypeError(
70
+ f"StableHLO kernel expects TensorValue inputs, got {type(arg).__name__} at position {idx}"
71
+ )
72
+ tensor_args.append(arg)
73
+
74
+ jax_args = [
75
+ jax.device_put(jnp.asarray(tensor.to_numpy())) for tensor in tensor_args
76
+ ]
71
77
 
72
78
  try:
73
79
  result = compiled.execute_sharded(jax_args)
@@ -75,9 +81,9 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
75
81
  flat: list[Any] = []
76
82
  for lst in arrays:
77
83
  if isinstance(lst, list) and len(lst) == 1:
78
- flat.append(jnp.array(lst[0]))
84
+ flat.append(TensorValue(np.asarray(lst[0])))
79
85
  else:
80
- flat.extend([jnp.array(a) for a in lst])
86
+ flat.extend(TensorValue(np.asarray(a)) for a in lst)
81
87
  return tuple(flat)
82
88
  except Exception as e: # pragma: no cover
83
89
  raise RuntimeError(f"StableHLO execute failed: {e}") from e