aster-cli 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,430 @@
1
+ """
2
+ aster_cli.shell.invoker -- Dynamic RPC invocation from the shell.
3
+
4
+ Invokes service methods by name, handling:
5
+ - Argument building from key=value pairs or interactive prompting
6
+ - Unary, server-stream, client-stream, and bidi-stream patterns
7
+ - Rich result formatting with timing
8
+ - Error display
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import asyncio
14
+ import json
15
+ import time
16
+ from typing import Any, AsyncIterator
17
+
18
+ from aster_cli.shell.hooks import FieldSchema, MethodSchema, get_hook_registry
19
+ from aster_cli.shell.plugin import CommandContext
20
+ from aster_cli.shell.vfs import NodeKind, resolve_path
21
+
22
+
23
def _build_method_schema(
    service_name: str, method_name: str, method_meta: dict[str, Any]
) -> MethodSchema:
    """Translate a method's VFS metadata into the MethodSchema handed to hooks.

    Args:
        service_name: Name of the service that owns the method.
        method_name: Name of the method.
        method_meta: Metadata dict read from the method's VFS node.

    Returns:
        A MethodSchema whose ``request_fields`` is ``None`` (not an empty
        list) when the metadata carries no ``fields`` entries.
    """
    raw_fields = method_meta.get("fields", [])

    request_fields = None
    if raw_fields:
        request_fields = []
        for spec in raw_fields:
            # A default of kind "value" is stored under ``default_value``;
            # any other kind falls back to the plain ``default`` key.
            if spec.get("default_kind") == "value":
                default = spec.get("default_value")
            else:
                default = spec.get("default")

            request_fields.append(
                FieldSchema(
                    name=spec.get("name", ""),
                    # ``kind`` wins when present and non-empty, else ``type``.
                    type_name=spec.get("kind", "") or spec.get("type", "str"),
                    required=spec.get("required", True),
                    default=default,
                    description=spec.get("description", ""),
                )
            )

    return MethodSchema(
        service_name=service_name,
        method_name=method_name,
        pattern=method_meta.get("pattern", "unary"),
        request_type=method_meta.get("request_type", ""),
        response_type=method_meta.get("response_type", ""),
        request_fields=request_fields,
        timeout=method_meta.get("timeout"),
    )
48
+
49
+
50
async def _render_result(
    hooks: Any,
    schema: MethodSchema,
    display: Any,
    result: Any,
    elapsed: float,
) -> None:
    """Show one RPC result, giving the output-renderer hook first refusal.

    Falls back to ``display.rpc_result`` when no renderer is registered or
    when the renderer returns a falsy value.
    """
    renderer = hooks.output_renderer
    rendered = False
    if renderer:
        rendered = await renderer.render_response(
            schema, _to_serializable(result), display
        )
    if not rendered:
        # Serialize again instead of reusing the renderer's copy, so a
        # renderer that mutates its argument cannot alter the fallback output.
        display.rpc_result(_to_serializable(result), elapsed_ms=elapsed)


def _rpc_error_message(error: Exception, service_name: str, method_name: str) -> str:
    """Build the user-facing text for a failed RPC, adding a hint when the error text is recognized."""
    msg = str(error)
    if "Expected" in msg and "got dict" in msg:
        return (
            "RPC failed: the server expected a typed object but received a dict.\n"
            " This usually means the shell is sending JSON to a Fory-only service.\n"
            f" Try: aster call <addr> {service_name}.{method_name} '<json>'"
        )
    if "FAILED_PRECONDITION" in msg or "scope mismatch" in msg.lower():
        return (
            f"RPC failed: '{service_name}' is session-scoped.\n"
            f" Try: cd /services && session {service_name}"
        )
    if "PERMISSION_DENIED" in msg:
        return (
            f"RPC failed: permission denied for {service_name}.{method_name}.\n"
            " Check your credential has the required role."
        )
    if "DEADLINE_EXCEEDED" in msg:
        return f"RPC failed: request timed out ({service_name}.{method_name})."
    if "UNAVAILABLE" in msg:
        return (
            "RPC failed: service unavailable. The connection may have dropped.\n"
            " Try: refresh"
        )
    return f"RPC failed: {error}"


async def invoke_method(
    ctx: CommandContext,
    service_name: str,
    method_name: str,
    payload: dict[str, Any],
) -> None:
    """Invoke a service method and display the result.

    The call pattern (unary, server_stream, client_stream, bidi_stream) is
    read from the method's VFS metadata and defaults to ``"unary"`` when the
    method node is not found. Failures are reported through ``ctx.display``
    rather than raised.

    Args:
        ctx: Command context with connection and display.
        service_name: Name of the service.
        method_name: Name of the method.
        payload: Arguments dict (from _parse_call_args).
    """
    display = ctx.display
    hooks = get_hook_registry()

    # Look up method metadata from VFS
    method_meta: dict[str, Any] = {}
    services_node = ctx.vfs_root.child("services")
    if services_node:
        svc_node = services_node.child(service_name)
        if svc_node:
            m_node = svc_node.child(method_name)
            if m_node:
                method_meta = m_node.metadata

    pattern = method_meta.get("pattern", "unary")
    schema = _build_method_schema(service_name, method_name, method_meta)

    # ── Input building (hook point) ───────────────────────────────────────
    # If no payload and method has parameters, use input builder hook
    if not payload and ctx.interactive and method_meta.get("fields"):
        builder = hooks.input_builder
        if builder:
            async def _ask(prompt: str) -> str | None:
                """Prompt on stdin off the event loop; None means cancelled."""
                loop = asyncio.get_running_loop()
                try:
                    return await loop.run_in_executor(None, input, prompt)
                except (KeyboardInterrupt, EOFError):
                    return None

            payload = await builder.build_payload(schema, payload, _ask)
        else:
            payload = await _prompt_for_args(ctx, method_meta)

    if payload is None:
        return  # user cancelled

    display.info(f"-> {service_name}.{method_name}({_summarize_args(payload)})")

    # Fire guide event
    guide = getattr(ctx, "guide", None)
    if guide:
        guide.fire("invoke")

    try:
        t0 = time.monotonic()

        # When the user is inside a `session <Service>` subshell, route
        # every call through the persistent SessionProxyClient instead of
        # opening a new bidi stream per invocation. This is the only way
        # to call methods on session-scoped services from the shell --
        # the per-call ``connection.invoke`` path explicitly rejects them.
        active_session = getattr(ctx, "session", None)
        if active_session is not None:
            result = await active_session.call(method_name, payload)
            elapsed = (time.monotonic() - t0) * 1000
            await _render_result(hooks, schema, display, result, elapsed)
            return

        if pattern == "unary":
            result = await ctx.connection.invoke(service_name, method_name, payload)
            elapsed = (time.monotonic() - t0) * 1000
            await _render_result(hooks, schema, display, result, elapsed)

        elif pattern == "server_stream":
            stream = await ctx.connection.server_stream(service_name, method_name, payload)
            await _display_stream(ctx, stream, t0)

        elif pattern == "client_stream":
            # Client stream: read values from user until empty line
            display.info("Enter values (JSON, one per line). Empty line to send:")
            values = await _read_stream_input(ctx)
            # Restart the clock here: the time the user spent typing values
            # is not RPC latency and must not be reported as such.
            t0 = time.monotonic()
            result = await ctx.connection.client_stream(service_name, method_name, values)
            elapsed = (time.monotonic() - t0) * 1000
            await _render_result(hooks, schema, display, result, elapsed)

        elif pattern == "bidi_stream":
            await _handle_bidi(ctx, service_name, method_name, payload, t0)

        else:
            display.error(f"unknown pattern: {pattern}")

    except KeyboardInterrupt:
        display.info("(cancelled)")
    except Exception as e:
        # Provide actionable hints for common errors
        display.error(_rpc_error_message(e, service_name, method_name))
199
+
200
+
201
async def _display_stream(
    ctx: CommandContext,
    stream: AsyncIterator[Any],
    t0: float,
) -> None:
    """Print every value of a server-streaming response, then a summary line.

    Args:
        ctx: Command context; only ``ctx.display`` is used.
        stream: Async iterator of response values.
        t0: ``time.monotonic()`` reading taken when the call started.
    """
    shown = 0
    try:
        async for item in stream:
            payload = _to_serializable(item)
            ctx.display.streaming_value(shown, payload)
            shown = shown + 1
    except KeyboardInterrupt:
        # An interrupt only stops consumption; the summary is still printed.
        pass

    elapsed_ms = (time.monotonic() - t0) * 1000
    summary = f"({shown} items, {elapsed_ms:.0f}ms)"
    ctx.display.info(summary)
217
+
218
+
219
async def _handle_bidi(
    ctx: CommandContext,
    service_name: str,
    method_name: str,
    initial_payload: dict[str, Any],
    t0: float,
) -> None:
    """Handle a bidirectional streaming call.

    Runs two concurrent tasks:
    - Reader: displays incoming values from the server
    - Writer: reads user input and sends to the server

    Type a JSON value and press Enter to send. Ctrl+D or empty line to stop sending.

    Args:
        ctx: Command context with connection and display.
        service_name: Name of the service.
        method_name: Name of the method.
        initial_payload: First value to send; skipped when empty.
        t0: ``time.monotonic()`` reading taken when the call started.
    """
    display = ctx.display

    display.info("Bidi stream open. Type JSON values to send, Ctrl+D or empty line to close input.")
    display.info("─" * 40)

    send_queue: asyncio.Queue[Any] = asyncio.Queue()

    # Seed with initial payload if non-empty
    if initial_payload:
        await send_queue.put(initial_payload)

    # Convert each outgoing dict to a typed Fory dataclass when the
    # transport uses Fory (Python servers). For TS servers + JSON
    # codec, this is a no-op. The bidi case can't pre-convert like
    # the other streaming patterns because values arrive lazily on a
    # separate task; we convert them at the queue boundary instead.
    def _maybe_typed(value: Any) -> Any:
        builder = getattr(ctx.connection, "build_typed_request_for_bidi", None)
        if builder is None:
            return value
        return builder(service_name, method_name, value)

    async def input_producer() -> AsyncIterator[Any]:
        """Yield values from the send queue until sentinel."""
        while True:
            value = await send_queue.get()
            if value is _SENTINEL:
                return
            yield _maybe_typed(value)

    async def read_user_input() -> None:
        """Read lines from stdin and push to send queue."""
        loop = asyncio.get_running_loop()
        try:
            while True:
                line = await loop.run_in_executor(None, _read_line_sync)
                if line is None or not line.strip():
                    break
                try:
                    value = json.loads(line)
                except json.JSONDecodeError:
                    # Fall back to a single key=value pair. Both sides are
                    # stripped so "name = bob" sends {"name": "bob"} rather
                    # than a key with a trailing space.
                    key, sep, raw = line.partition("=")
                    if not sep:
                        display.error("invalid JSON -- enter a JSON value or key=value")
                        continue
                    value = {key.strip(): raw.strip()}
                await send_queue.put(value)
        finally:
            # Always close the outgoing side, even if reading failed.
            await send_queue.put(_SENTINEL)

    async def display_responses(stream: AsyncIterator[Any]) -> None:
        """Display incoming stream values."""
        count = 0
        async for value in stream:
            display.streaming_value(count, _to_serializable(value))
            count += 1

    try:
        stream = ctx.connection.bidi_stream(
            service_name, method_name, input_producer()
        )
        await asyncio.gather(
            read_user_input(),
            display_responses(stream),
        )
    except KeyboardInterrupt:
        pass

    elapsed = (time.monotonic() - t0) * 1000
    display.info(f"(bidi stream closed, {elapsed:.0f}ms)")
306
+
307
+
308
+ async def _prompt_for_args(
309
+ ctx: CommandContext,
310
+ method_meta: dict[str, Any],
311
+ ) -> dict[str, Any] | None:
312
+ """Interactively prompt for method arguments.
313
+
314
+ Uses method metadata (fields list) to prompt for each parameter.
315
+ """
316
+ fields = method_meta.get("fields", [])
317
+ if not fields:
318
+ return {}
319
+
320
+ result: dict[str, Any] = {}
321
+ ctx.display.print("[dim]Enter arguments (Ctrl+C to cancel):[/dim]")
322
+
323
+ try:
324
+ loop = asyncio.get_event_loop()
325
+ for f in fields:
326
+ name = f.get("name", f"arg{len(result)}")
327
+ ftype = f.get("type", "str")
328
+ required = f.get("required", True)
329
+ default = f.get("default")
330
+
331
+ prompt = f" ▸ {name}"
332
+ if ftype:
333
+ prompt += f" ({ftype})"
334
+ if default is not None:
335
+ prompt += f" [{default}]"
336
+ prompt += ": "
337
+
338
+ value = await loop.run_in_executor(None, lambda p=prompt: input(p))
339
+ value = value.strip()
340
+
341
+ if not value and default is not None:
342
+ result[name] = default
343
+ continue
344
+ elif not value and not required:
345
+ continue
346
+ elif not value:
347
+ ctx.display.error(f"{name} is required")
348
+ return None
349
+
350
+ # Try to parse as JSON
351
+ try:
352
+ result[name] = json.loads(value)
353
+ except (json.JSONDecodeError, ValueError):
354
+ result[name] = value
355
+
356
+ except (KeyboardInterrupt, EOFError):
357
+ ctx.display.info("(cancelled)")
358
+ return None
359
+
360
+ return result
361
+
362
+
363
+ async def _read_stream_input(ctx: CommandContext) -> list[Any]:
364
+ """Read a sequence of JSON values from stdin for client streaming."""
365
+ values: list[Any] = []
366
+ loop = asyncio.get_event_loop()
367
+
368
+ try:
369
+ while True:
370
+ line = await loop.run_in_executor(None, lambda: input(" > "))
371
+ line = line.strip()
372
+ if not line:
373
+ break
374
+ try:
375
+ values.append(json.loads(line))
376
+ except json.JSONDecodeError:
377
+ ctx.display.error(f"invalid JSON: {line}")
378
+ except (KeyboardInterrupt, EOFError):
379
+ pass
380
+
381
+ return values
382
+
383
+
384
+ def _read_line_sync() -> str | None:
385
+ """Read a line synchronously, returning None on EOF."""
386
+ try:
387
+ return input(" ⇢ ")
388
+ except EOFError:
389
+ return None
390
+
391
+
392
+ class _SentinelType:
393
+ pass
394
+
395
+
396
+ _SENTINEL = _SentinelType()
397
+
398
+
399
+ def _to_serializable(obj: Any) -> Any:
400
+ """Convert an object to a JSON-serializable form."""
401
+ if obj is None or isinstance(obj, (str, int, float, bool)):
402
+ return obj
403
+ if isinstance(obj, bytes):
404
+ try:
405
+ return obj.decode("utf-8")
406
+ except UnicodeDecodeError:
407
+ return obj.hex()
408
+ if isinstance(obj, dict):
409
+ return {str(k): _to_serializable(v) for k, v in obj.items()}
410
+ if isinstance(obj, (list, tuple)):
411
+ return [_to_serializable(v) for v in obj]
412
+ if hasattr(obj, "__dict__"):
413
+ return {k: _to_serializable(v) for k, v in obj.__dict__.items()
414
+ if not k.startswith("_")}
415
+ return str(obj)
416
+
417
+
418
+ def _summarize_args(payload: dict[str, Any]) -> str:
419
+ """Short summary of call arguments for display."""
420
+ if not payload:
421
+ return ""
422
+ items = []
423
+ for k, v in list(payload.items())[:3]:
424
+ if isinstance(v, str) and len(v) > 20:
425
+ v = v[:17] + "…"
426
+ items.append(f"{k}={v!r}")
427
+ s = ", ".join(items)
428
+ if len(payload) > 3:
429
+ s += f", … +{len(payload) - 3} more"
430
+ return s
@@ -0,0 +1,185 @@
1
+ """
2
+ aster_cli.shell.plugin -- Plugin system for shell commands.
3
+
4
+ Every interactive shell command is a plugin that can also be invoked
5
+ as a CLI subcommand (e.g., ``ls`` at /blobs ↔ ``aster blob ls``).
6
+
7
+ Plugins self-register via the @register decorator, declaring:
8
+ - name: the command name (e.g., "ls", "cat", "describe")
9
+ - context: glob pattern for where the command is valid (e.g., "/blobs", "/services/*")
10
+ - cli_noun_verb: optional (noun, verb) tuple for CLI mapping (e.g., ("blob", "ls"))
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import fnmatch
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass, field
19
+ from typing import TYPE_CHECKING, Any
20
+
21
+ if TYPE_CHECKING:
22
+ from aster_cli.shell.vfs import VfsNode
23
+
24
+
25
@dataclass
class Argument:
    """One argument of a shell command, shared by autocomplete and argparse setup."""

    # Argument name; used as the positional name or as ``--<name>``.
    name: str
    # Help text.
    description: str = ""
    # Whether the argument must be supplied.
    required: bool = False
    # True for a positional argument, False for a ``--name`` option.
    positional: bool = False
    # Allowed values, or None for no restriction.
    choices: list[str] | None = None
    # Value used when the argument is omitted.
    default: Any = None
    # Declared value type.
    type: type = str
36
+
37
+
38
class ShellCommand(ABC):
    """Abstract base for shell command plugins.

    Concrete commands subclass this, fill in the metadata attributes below,
    implement ``execute`` and apply ``@register`` to become available in the
    shell and, optionally, as a CLI subcommand.
    """

    # Command metadata -- set by subclasses
    name: str = ""
    description: str = ""
    # Glob patterns for valid VFS paths; empty means valid everywhere.
    contexts: list[str] = []
    # (noun, verb) pair for the CLI mapping, e.g. ("blob", "ls") → aster blob ls
    cli_noun_verb: tuple[str, str] | None = None
    # Hide from help/autocomplete.
    hidden: bool = False

    @abstractmethod
    async def execute(
        self,
        args: list[str],
        ctx: CommandContext,
    ) -> None:
        """Run the command.

        Args:
            args: Parsed argument tokens after the command name.
            ctx: Execution context with VFS, display, connection.
        """

    def get_arguments(self) -> list[Argument]:
        """Argument definitions for autocomplete and CLI registration (none by default)."""
        return list()

    def get_completions(self, ctx: CommandContext, partial: str) -> list[str]:
        """Completion suggestions for the current partial input (none by default).

        Override for dynamic completions (e.g., blob hashes, method params).
        """
        return list()

    def is_valid_at(self, path: str) -> bool:
        """Report whether this command may be used at the given VFS path."""
        patterns = self.contexts
        if not patterns:
            # No context restriction: the command is global.
            return True
        for pattern in patterns:
            if fnmatch.fnmatch(path, pattern):
                return True
        return False
81
+
82
+
83
@dataclass
class CommandContext:
    """State handed to a command each time it runs."""

    # Current VFS path.
    vfs_cwd: str
    # Root of the VFS tree.
    vfs_root: VfsNode
    # AsterClient or peer connection.
    connection: Any
    # Display instance for rich output.
    display: Any
    # Display name of connected peer.
    peer_name: str = ""
    # False when called from CLI.
    interactive: bool = True
    # True for pipe-friendly JSON output.
    raw_output: bool = False
    # GuideManager instance (optional).
    guide: Any = None
    # Active session (set inside `session <ServiceName>` subshells). When
    # set, method invocations are routed through the persistent session
    # bidi stream instead of opening a new stream per call -- the only
    # way to call methods on session-scoped services from the shell.
    session: Any = None
100
+
101
+
102
+ # ── Plugin registry ───────────────────────────────────────────────────────────
103
+
104
# Command name -> the single instance created for it at registration time.
_registry: dict[str, ShellCommand] = {}


def register(cmd_class: type[ShellCommand]) -> type[ShellCommand]:
    """Class decorator that instantiates a command and files it under its name.

    The class itself is returned unchanged, so it stays usable after
    decoration. A later registration under the same name replaces the
    earlier one.

    Usage::

        @register
        class LsCommand(ShellCommand):
            name = "ls"
            description = "List contents"
            contexts = ["/", "/blobs", "/services", "/services/*"]
    """
    command = cmd_class()
    _registry.update({command.name: command})
    return cmd_class
121
+
122
+
123
def get_command(name: str) -> ShellCommand | None:
    """Return the command registered under *name*, or ``None`` if there is none."""
    try:
        return _registry[name]
    except KeyError:
        return None
126
+
127
+
128
def get_commands_for_path(path: str) -> list[ShellCommand]:
    """List the non-hidden registered commands that are valid at *path*."""
    usable: list[ShellCommand] = []
    for candidate in _registry.values():
        if not candidate.is_valid_at(path):
            continue
        if candidate.hidden:
            continue
        usable.append(candidate)
    return usable
131
+
132
+
133
def get_all_commands() -> dict[str, ShellCommand]:
    """Return a shallow copy of the registry, so callers cannot mutate it."""
    return _registry.copy()
136
+
137
+
138
def register_cli_subcommands(subparsers: argparse._SubParsersAction) -> None:
    """Register all plugin CLI noun-verb subcommands with argparse.

    This creates subcommands like ``aster blob ls``, ``aster service describe``, etc.
    Every verb parser takes a positional ``peer``, ``--rcan`` and ``--json``
    (stored as ``raw_json``), followed by the arguments the command declares
    through ``get_arguments()``.

    Args:
        subparsers: The sub-parsers action of the top-level parser.
    """
    # Group commands by CLI noun
    noun_groups: dict[str, list[ShellCommand]] = {}
    for cmd in _registry.values():
        if cmd.cli_noun_verb:
            noun, _verb = cmd.cli_noun_verb
            noun_groups.setdefault(noun, []).append(cmd)

    for noun, commands in sorted(noun_groups.items()):
        noun_parser = subparsers.add_parser(noun, help=f"{noun.title()} commands")
        noun_subs = noun_parser.add_subparsers(dest=f"{noun}_command")

        for cmd in commands:
            _noun, verb = cmd.cli_noun_verb  # type: ignore[misc]
            verb_parser = noun_subs.add_parser(verb, help=cmd.description)

            # Add common args
            verb_parser.add_argument("peer", help="Peer address to connect to")
            verb_parser.add_argument(
                "--rcan", default=None, help="Path to RCAN credential"
            )
            verb_parser.add_argument(
                "--json",
                action="store_true",
                dest="raw_json",
                help="Output raw JSON (for scripting)",
            )

            # Add command-specific args. ``choices`` is forwarded so the CLI
            # rejects values outside the declared set (None means no
            # restriction). ``arg.type`` is deliberately not forwarded:
            # parsed values stay strings.
            for arg in cmd.get_arguments():
                if arg.positional:
                    verb_parser.add_argument(
                        arg.name,
                        nargs=None if arg.required else "?",
                        default=arg.default,
                        choices=arg.choices,
                        help=arg.description,
                    )
                else:
                    verb_parser.add_argument(
                        f"--{arg.name}",
                        required=arg.required,
                        default=arg.default,
                        choices=arg.choices,
                        help=arg.description,
                    )