mplang-nightly 0.1.dev252__py3-none-any.whl → 0.1.dev254__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mplang/v2/__init__.py CHANGED
@@ -52,15 +52,19 @@ from mplang.v2.edsl import (
52
52
  TracedFunction,
53
53
  Tracer,
54
54
  Value,
55
+ find_context,
56
+ find_context_with_state,
57
+ find_interpreter,
55
58
  format_graph,
56
59
  get_current_context,
57
60
  get_default_context,
58
- get_root_context,
61
+ is_tracing,
59
62
  jit,
60
63
  pop_context,
61
64
  primitive,
62
65
  push_context,
63
66
  register_default_context_factory,
67
+ set_root_context,
64
68
  trace,
65
69
  )
66
70
  from mplang.v2.edsl.registry import get_profiler
@@ -97,34 +101,6 @@ from mplang.v2.runtime.interpreter import Interpreter
97
101
  # =============================================================================
98
102
 
99
103
 
100
- def set_root_context(context: Interpreter, force: bool = False) -> None:
101
- """Set the global/root execution context.
102
-
103
- This explicitly sets the provided interpreter as the Root Context.
104
- All subsequent operations (compile, evaluate, device resolution) will
105
- use this context as the default environment.
106
-
107
- Args:
108
- context: Interpreter to use as the root context.
109
- force: If True, clears the existing context stack before setting.
110
- If False (default), pushes onto the stack.
111
- """
112
- from mplang.v2.edsl.context import _context_stack, get_current_context
113
-
114
- if force:
115
- _context_stack.clear()
116
- _context_stack.append(context)
117
- return
118
-
119
- if get_current_context() is not None:
120
- raise RuntimeError(
121
- "Cannot set root context: Context stack is not empty. "
122
- "Use force=True to overwrite the existing root context."
123
- )
124
-
125
- push_context(context)
126
-
127
-
128
104
  def _get_context(context: Interpreter | None) -> Interpreter:
129
105
  """Get context from parameter or context stack."""
130
106
  if context is not None:
@@ -204,6 +180,30 @@ def fetch(
204
180
  ) -> Any:
205
181
  """Fetch results from interpreter context to Python.
206
182
 
183
+ This is a meta-function that operates at execution boundaries, not a traced
184
+ dialect operation. It brings data from the distributed/MPC runtime back to
185
+ the Python host.
186
+
187
+ Behavior in different contexts:
188
+ - **Tracing (compile)**: Returns the input unchanged (identity). The graph
189
+ outputs are determined by the function's return statement, not fetch calls.
190
+ - **Execution (evaluate)**: Actually fetches data from workers/parties.
191
+
192
+ Design Note (A vs B tradeoff):
193
+ Two designs were considered for fetch behavior during tracing:
194
+
195
+ - **Design A (chosen)**: fetch = identity during tracing. Graph outputs
196
+ are determined solely by the return statement. This is simpler and
197
+ avoids ambiguity when fetch and return reference different values.
198
+
199
+ - **Design B (alternative)**: fetch marks output points in the graph.
200
+ This would allow fetch(a), fetch(b), return b to output both a and b.
201
+ However, it complicates the semantics and requires tracking fetch
202
+ points separately from return values.
203
+
204
+ Design A was chosen for simplicity. If a value needs to be an output,
205
+ it should be returned. fetch's role is purely for execution-time I/O.
206
+
207
207
  Args:
208
208
  result: Object(s) to fetch. Can be a single InterpObject, DriverVar,
209
209
  or nested structure containing them.
@@ -216,14 +216,22 @@ def fetch(
216
216
  Fetched Python values. For device objects with follow_device=True,
217
217
  returns single value from the device's rank(s). Otherwise returns
218
218
  list of values (one per party) or single value for world_size=1.
219
+ During tracing, returns the input unchanged.
219
220
  """
220
-
221
221
  from jax.tree_util import tree_map
222
222
 
223
223
  from mplang.v2.backends.simp_driver.values import DriverVar
224
+ from mplang.v2.edsl.context import is_tracing
224
225
  from mplang.v2.runtime.interpreter import InterpObject
225
226
  from mplang.v2.runtime.value import WrapValue
226
227
 
228
+ # Check if we are in tracing context - if so, return identity
229
+ if is_tracing():
230
+ # Design A: fetch = identity during tracing
231
+ # Graph outputs are determined by return statement, not fetch calls
232
+ return result
233
+
234
+ # Execution context - actually fetch data
227
235
  interp = _get_context(context)
228
236
 
229
237
  def _fetch_single(var: Any) -> Any:
@@ -375,10 +383,14 @@ __all__ = [ # noqa: RUF022
375
383
  "compile",
376
384
  "evaluate",
377
385
  "fetch",
386
+ "find_context",
387
+ "find_context_with_state",
388
+ "find_interpreter",
378
389
  "format_graph",
379
390
  "function",
380
391
  "get_current_context",
381
392
  "get_default_context",
393
+ "is_tracing",
382
394
  "jit",
383
395
  "mplang",
384
396
  "pop_context",
@@ -404,7 +416,6 @@ __all__ = [ # noqa: RUF022
404
416
  # Dialects
405
417
  "dialects",
406
418
  "register_default_context_factory",
407
- "get_root_context",
408
419
  "get_profiler",
409
420
  ]
410
421
 
@@ -154,6 +154,10 @@ class SimpMemDriver(SimpDriver):
154
154
  self._workers = workers
155
155
  self._mesh = mesh
156
156
 
157
def shutdown(self) -> None:
    """Stop the in-memory driver by shutting down its mesh.

    Teardown is delegated entirely to the mesh's own ``shutdown()``.
    """
    mesh = self._mesh
    mesh.shutdown()
160
+
157
161
  @property
158
162
  def world_size(self) -> int:
159
163
  return self._world_size
mplang/v2/cli.py CHANGED
@@ -62,7 +62,7 @@ def run_worker(
62
62
  signal.signal(signal.SIGINT, signal.SIG_DFL)
63
63
  signal.signal(signal.SIGTERM, signal.SIG_DFL)
64
64
 
65
- from mplang.v2.backends.simp_http_worker import create_worker_app
65
+ from mplang.v2.backends.simp_worker.http import create_worker_app
66
66
 
67
67
  app = create_worker_app(rank, world_size, endpoints, spu_endpoints)
68
68
 
@@ -33,12 +33,16 @@ from . import typing as typing
33
33
  # Context management
34
34
  from .context import (
35
35
  Context,
36
+ find_context,
37
+ find_context_with_state,
38
+ find_interpreter,
36
39
  get_current_context,
37
40
  get_default_context,
38
- get_root_context,
41
+ is_tracing,
39
42
  pop_context,
40
43
  push_context,
41
44
  register_default_context_factory,
45
+ set_root_context,
42
46
  )
43
47
 
44
48
  # Graph IR
@@ -77,15 +81,19 @@ __all__ = [
77
81
  "Tracer",
78
82
  "Value",
79
83
  "VectorObject",
84
+ "find_context",
85
+ "find_context_with_state",
86
+ "find_interpreter",
80
87
  "format_graph",
81
88
  "get_current_context",
82
89
  "get_default_context",
83
- "get_root_context",
90
+ "is_tracing",
84
91
  "jit",
85
92
  "pop_context",
86
93
  "primitive",
87
94
  "push_context",
88
95
  "register_default_context_factory",
96
+ "set_root_context",
89
97
  "trace",
90
98
  "typing",
91
99
  ]
mplang/v2/edsl/context.py CHANGED
@@ -27,6 +27,16 @@ Contexts can be used directly with Python's 'with' statement:
27
27
  with tracer:
28
28
  # Operations run under tracer context
29
29
  result = primitive.bind(x, y)
30
+
31
+ State Management:
32
+ Contexts can carry arbitrary named state via set_state/get_state.
33
+ This allows different layers (device, ml, analytics) to attach their
34
+ own state without the EDSL layer knowing about specific state types.
35
+
36
+ State key conventions:
37
+ - "dialect.{name}": Dialect runtime state (e.g., "dialect.simp")
38
+ - "device.cluster": Device/cluster configuration
39
+ - "ml.{component}": ML pipeline components
30
40
  """
31
41
 
32
42
  from __future__ import annotations
@@ -42,14 +52,65 @@ if TYPE_CHECKING:
42
52
 
43
53
 
44
54
  class Context(ABC):
45
- """Base class for EDSL execution contexts.
55
+ """Base class for EDSL execution contexts with extensible state slots.
46
56
 
47
57
  A Context represents an environment where primitives are executed.
48
58
  There are two types of contexts:
49
59
  - Tracer: Records operations to Graph IR (compile-time)
50
60
  - Interpreter: Execution context (executes operations immediately)
61
+
62
+ State Management:
63
+ Contexts can carry arbitrary named state. Different layers can attach
64
+ their own state without the EDSL layer knowing specifics:
65
+
66
+ >>> ctx.set_state("device.cluster", cluster_spec)
67
+ >>> ctx.set_state("dialect.simp", simp_driver)
68
+ >>> cluster = ctx.get_state("device.cluster")
51
69
  """
52
70
 
71
def __init__(self) -> None:
    """Initialise an empty per-context state registry.

    Subclasses that override ``__init__`` should chain to this method so the
    state accessors (``set_state`` / ``get_state`` / ``has_state``) have a
    dictionary to work with.
    """
    # Maps a state key (e.g. "dialect.simp") to an arbitrary value; each
    # context instance gets its own, independent dict.
    self._states: dict[str, Any] = dict()
73
+
74
+ # =========================================================================
75
+ # State Management
76
+ # =========================================================================
77
+
78
def set_state(self, key: str, value: Any) -> None:
    """Attach named state to this context.

    An existing entry under the same key is overwritten.

    The backing dictionary is created on first use if it is missing, so the
    method also works for subclasses that define their own ``__init__``
    without calling ``super().__init__()``.

    Args:
        key: State key (e.g., "dialect.simp", "device.cluster").
        value: State value.
    """
    try:
        states = self._states
    except AttributeError:
        # Subclass skipped Context.__init__; create the registry lazily
        # instead of failing on the first write.
        states = self._states = {}
    states[key] = value
86
+
87
def get_state(self, key: str, default: Any = None) -> Any:
    """Look up state attached to this context.

    Args:
        key: State key.
        default: Value returned when ``key`` has not been set.

    Returns:
        The stored value, or ``default`` if the key is absent. A context
        whose state dictionary was never initialised (a subclass that did
        not call ``super().__init__()``) is treated as having no state
        rather than raising ``AttributeError``.
    """
    states = getattr(self, "_states", None)
    if states is None:
        return default
    return states.get(key, default)
98
+
99
def has_state(self, key: str) -> bool:
    """Report whether state has been attached under ``key``.

    A key that was set to ``None`` still counts as present.

    Args:
        key: State key.

    Returns:
        True if the key has been set on this context. Returns False (rather
        than raising ``AttributeError``) for a context whose state
        dictionary was never initialised, e.g. a subclass that did not call
        ``super().__init__()``; this matters because lookups such as
        ``find_context_with_state`` call this on every context in the stack.
    """
    states = getattr(self, "_states", None)
    return states is not None and key in states
109
+
110
+ # =========================================================================
111
+ # Abstract Methods
112
+ # =========================================================================
113
+
53
114
  @abstractmethod
54
115
  def bind_primitive(
55
116
  self, primitive: Primitive, args: tuple[Any, ...], kwargs: dict[str, Any]
@@ -80,6 +141,10 @@ class Context(ABC):
80
141
  Object in the context's native type (TraceObject or InterpObject)
81
142
  """
82
143
 
144
+ # =========================================================================
145
+ # Context Manager
146
+ # =========================================================================
147
+
83
148
  def __enter__(self) -> Self:
84
149
  """Enter context manager (push context onto stack)."""
85
150
  push_context(self)
@@ -90,9 +155,31 @@ class Context(ABC):
90
155
  pop_context()
91
156
 
92
157
 
93
- # ============================================================================
158
+ # =============================================================================
159
+ # Abstract Interpreter Interface
160
+ # =============================================================================
161
+
162
+
163
class AbstractInterpreter(Context):
    """Interpreter contract that EDSL code can depend on.

    EDSL components such as JIT program against this abstract interface
    instead of the concrete runtime Interpreter, whose implementation may
    pull in heavier dependencies (ObjectStore, backends, ...). Concrete
    interpreters must implement both abstract methods below in addition to
    the abstract methods inherited from ``Context``.
    """

    @abstractmethod
    def lift(self, obj: Any) -> Any:
        """Convert a plain Python object into an interpreter object."""

    @abstractmethod
    def evaluate_graph(self, graph: Graph, inputs: list[Any]) -> Any:
        """Run a Graph IR program on ``inputs`` and return its result."""
178
+
179
+
180
+ # =============================================================================
94
181
  # Global Context Stack Management
95
- # ============================================================================
182
+ # =============================================================================
96
183
 
97
184
  _context_stack: list[Context] = []
98
185
  _default_context: Context | None = None
@@ -100,20 +187,25 @@ _default_context_factory: Callable[[], Context] | None = None
100
187
 
101
188
 
102
189
def get_current_context() -> Context | None:
    """Return the innermost active context.

    The innermost context is the one most recently pushed onto the global
    context stack.

    Returns:
        The top-of-stack context, or ``None`` when the stack is empty.
    """
    if not _context_stack:
        return None
    return _context_stack[-1]
109
195
 
110
196
 
111
- def get_root_context() -> Context | None:
112
- """Get the root context (bottom of the stack).
197
def push_context(context: Context) -> None:
    """Make ``context`` the innermost active context.

    Appends to the global context stack; pair each call with
    :func:`pop_context`.
    """
    stack = _context_stack
    stack.append(context)
200
+
201
+
202
def pop_context() -> Context | None:
    """Remove and return the innermost context.

    Popping an empty stack is a no-op rather than an error.

    Returns:
        The context that was on top of the stack, or ``None`` if the stack
        was already empty.
    """
    if _context_stack:
        return _context_stack.pop()
    return None
117
209
 
118
210
 
119
211
  def find_context(predicate: Callable[[Context], bool]) -> Context | None:
@@ -122,9 +214,6 @@ def find_context(predicate: Callable[[Context], bool]) -> Context | None:
122
214
  Traverses from top (most recent) to bottom of the context stack,
123
215
  returning the first context for which predicate(ctx) returns True.
124
216
 
125
- This is a general-purpose utility for finding contexts with specific
126
- attributes or capabilities without hardcoding business logic here.
127
-
128
217
  Args:
129
218
  predicate: A callable that takes a Context and returns True if it matches.
130
219
 
@@ -133,13 +222,7 @@ def find_context(predicate: Callable[[Context], bool]) -> Context | None:
133
222
 
134
223
  Example:
135
224
  >>> # Find context with simp dialect state
136
- >>> ctx = find_context(
137
- ... lambda c: hasattr(c, "get_dialect_state")
138
- ... and c.get_dialect_state("simp") is not None
139
- ... )
140
- >>>
141
- >>> # Find context with cluster spec
142
- >>> ctx = find_context(lambda c: getattr(c, "_cluster_spec", None) is not None)
225
+ >>> ctx = find_context(lambda c: c.has_state("dialect.simp"))
143
226
  """
144
227
  for ctx in reversed(_context_stack):
145
228
  if predicate(ctx):
@@ -147,15 +230,41 @@ def find_context(predicate: Callable[[Context], bool]) -> Context | None:
147
230
  return None
148
231
 
149
232
 
150
- def push_context(context: Context) -> None:
151
- """Push a context onto the stack (enter context)."""
152
- _context_stack.append(context)
233
def find_context_with_state(key: str) -> Context | None:
    """Locate the innermost context carrying state under ``key``.

    Delegates to :func:`find_context`, so the stack is searched from the
    most recently pushed context downwards.

    Args:
        key: State key to look for (e.g. "dialect.simp").

    Returns:
        The first matching context, or ``None`` if no context has the key.
    """

    def _carries_key(ctx: Context) -> bool:
        return ctx.has_state(key)

    return find_context(_carries_key)
243
+
244
+
245
def find_interpreter() -> AbstractInterpreter | None:
    """Return the innermost interpreter on the context stack.

    The stack is searched from the most recently pushed context downwards
    (the same order as :func:`find_context`), so the most recently pushed
    interpreter wins; non-interpreter contexts such as tracers are skipped.

    Returns:
        The nearest context that is an :class:`AbstractInterpreter`, or
        ``None`` if the stack contains none.
    """
    # Walk the stack directly (rather than via find_context) so the
    # isinstance check narrows the return type for type checkers and
    # callers no longer need to cast the result.
    for ctx in reversed(_context_stack):
        if isinstance(ctx, AbstractInterpreter):
            return ctx
    return None
153
252
 
154
253
 
155
- def pop_context() -> None:
156
- """Pop a context from the stack (exit context)."""
157
- if _context_stack:
158
- _context_stack.pop()
254
def is_tracing() -> bool:
    """Tell whether code is currently being traced.

    Only the innermost context is consulted: a Tracer lower in the stack
    does not count.

    Returns:
        ``True`` when the top of the context stack is a ``Tracer``,
        ``False`` otherwise (including when the stack is empty).
    """
    # Deferred import, presumably to avoid an import cycle between this
    # module and the tracer module -- confirm before hoisting it.
    from mplang.v2.edsl.tracer import Tracer

    top = get_current_context()
    return isinstance(top, Tracer)
263
+
264
+
265
+ # =============================================================================
266
+ # Default Context Management
267
+ # =============================================================================
159
268
 
160
269
 
161
270
  def register_default_context_factory(factory: Callable[[], Context]) -> None:
@@ -177,18 +286,26 @@ def get_default_context() -> Context:
177
286
  return _default_context
178
287
 
179
288
 
180
- class AbstractInterpreter(Context):
181
- """Abstract interface for Interpreters.
289
def set_root_context(context: Context, force: bool = False) -> None:
    """Install ``context`` as the root (bottom) of the context stack.

    Subsequent operations use this context as their default environment.

    Args:
        context: Context to install as root.
        force: When True, discard every context currently on the stack and
            leave ``context`` as its only entry. When False (default), the
            stack must be empty.

    Raises:
        RuntimeError: If ``force`` is False and the stack is not empty.
    """
    if not force:
        if get_current_context() is not None:
            raise RuntimeError(
                "Cannot set root context: Context stack is not empty. "
                "Use force=True to overwrite the existing context."
            )
        push_context(context)
        return

    # Replace the stack contents in place (same list object), so any code
    # holding a reference to _context_stack observes the new root.
    _context_stack[:] = [context]
@@ -732,25 +732,23 @@ def fetch(obj: Object) -> Any:
732
732
  """
733
733
  from mplang.v2.backends.simp_driver.state import SimpDriver
734
734
  from mplang.v2.backends.simp_driver.values import DriverVar
735
- from mplang.v2.backends.table_impl import TableValue
736
- from mplang.v2.backends.tensor_impl import TensorValue
737
735
  from mplang.v2.edsl.context import get_current_context
738
736
  from mplang.v2.runtime.interpreter import InterpObject, Interpreter
737
+ from mplang.v2.runtime.value import WrapValue
739
738
 
740
739
  def _unwrap_value(val: Any) -> Any:
741
- """Unwrap Value types to get the underlying data."""
742
- if isinstance(val, TensorValue):
743
- return val.data
744
- elif isinstance(val, TableValue):
740
+ """Unwrap WrapValue to get the underlying data."""
741
+ if isinstance(val, WrapValue):
745
742
  return val.data
746
743
  return val
747
744
 
748
- # Get device info
745
+ # 1. Ensure is object and is device obj
749
746
  if not is_device_obj(obj):
750
747
  raise DeviceError(
751
748
  "Object does not have device attribute. Use mp.fetch() directly."
752
749
  )
753
750
 
751
+ # 2. Get device information according to device id
754
752
  dev_id = get_dev_attr(obj)
755
753
  cluster = _resolve_cluster()
756
754
  dev_info = cluster.devices[dev_id]
@@ -760,9 +758,10 @@ def fetch(obj: Object) -> Any:
760
758
  if not isinstance(ctx, Interpreter):
761
759
  raise RuntimeError("No interpreter context available for fetch")
762
760
 
763
- simp_state = cast(SimpDriver | None, ctx.get_dialect_state("simp"))
761
+ simp_state = ctx.get_dialect_state("simp")
762
+ assert isinstance(simp_state, SimpDriver), "DriverVar requires simp state"
764
763
 
765
- # Unwrap InterpObject
764
+ # Unwrap InterpObject to get runtime value
766
765
  assert isinstance(obj, InterpObject), f"Expected InterpObject, got {type(obj)}"
767
766
  runtime_obj = obj.runtime_obj
768
767
 
@@ -770,19 +769,26 @@ def fetch(obj: Object) -> Any:
770
769
  """Fetch value from a rank (DriverVar values are always URIs)."""
771
770
  uri = runtime_obj.values[rank]
772
771
  assert isinstance(uri, str) and "://" in uri, f"Expected URI, got {uri}"
773
- assert simp_state is not None, "No simp state for fetch"
774
- return _unwrap_value(simp_state.fetch(rank, uri).result())
772
+ return simp_state.fetch(rank, uri).result()
775
773
 
776
- # Handle DriverVar
774
+ # 3. Match device type and do corresponding fetch action
777
775
  if isinstance(runtime_obj, DriverVar):
778
- # For PPU/TEE: single member
776
+ # 3.1 PPU/TEE: single member, fetch directly
779
777
  if dev_info.kind.upper() in ("PPU", "TEE"):
780
778
  assert len(dev_info.members) == 1
781
- return _fetch_from_rank(dev_info.members[0].rank)
779
+ result = _fetch_from_rank(dev_info.members[0].rank)
780
+ # 4. Unwrap if WrapValue
781
+ return _unwrap_value(result)
782
782
 
783
- # For SPU: fetch from first member (should be revealed first)
783
+ # 3.2 SPU: fetch from all ranks and reconstruct
784
784
  elif dev_info.kind.upper() == "SPU":
785
- return _fetch_from_rank(dev_info.members[0].rank)
786
-
787
- # Direct value
785
+ # Fetch shares from all SPU members
786
+ shares = [_fetch_from_rank(m.rank) for m in dev_info.members]
787
+ # For now, just return the first share (TODO: implement spu.reconstruct)
788
+ # In practice, SPU values should be revealed to a PPU first
789
+ result = shares[0] if shares else None
790
+ # 4. Unwrap if WrapValue
791
+ return _unwrap_value(result)
792
+
793
+ # Direct value (not DriverVar)
788
794
  return _unwrap_value(runtime_obj)
@@ -0,0 +1,23 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Machine Learning algorithms for secure multi-party computation."""
16
+
17
+ from mplang.v2.libs.ml.sgb import SecureBoost, Tree, TreeEnsemble
18
+
19
+ __all__ = [
20
+ "SecureBoost",
21
+ "Tree",
22
+ "TreeEnsemble",
23
+ ]