remoteRF-server-testing 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. remoteRF_server/__init__.py +0 -0
  2. remoteRF_server/common/__init__.py +0 -0
  3. remoteRF_server/common/grpc/__init__.py +1 -0
  4. remoteRF_server/common/grpc/grpc_host_pb2.py +63 -0
  5. remoteRF_server/common/grpc/grpc_host_pb2_grpc.py +97 -0
  6. remoteRF_server/common/grpc/grpc_pb2.py +59 -0
  7. remoteRF_server/common/grpc/grpc_pb2_grpc.py +97 -0
  8. remoteRF_server/common/idl/__init__.py +1 -0
  9. remoteRF_server/common/idl/device_schema.py +39 -0
  10. remoteRF_server/common/idl/pluto_schema.py +174 -0
  11. remoteRF_server/common/idl/schema.py +358 -0
  12. remoteRF_server/common/utils/__init__.py +6 -0
  13. remoteRF_server/common/utils/ansi_codes.py +120 -0
  14. remoteRF_server/common/utils/api_token.py +21 -0
  15. remoteRF_server/common/utils/db_connection.py +35 -0
  16. remoteRF_server/common/utils/db_location.py +24 -0
  17. remoteRF_server/common/utils/list_string.py +5 -0
  18. remoteRF_server/common/utils/process_arg.py +80 -0
  19. remoteRF_server/drivers/__init__.py +0 -0
  20. remoteRF_server/drivers/adalm_pluto/__init__.py +0 -0
  21. remoteRF_server/drivers/adalm_pluto/pluto_remote_server.py +105 -0
  22. remoteRF_server/host/__init__.py +0 -0
  23. remoteRF_server/host/host_auth_token.py +292 -0
  24. remoteRF_server/host/host_directory_store.py +142 -0
  25. remoteRF_server/host/host_tunnel_server.py +1388 -0
  26. remoteRF_server/server/__init__.py +0 -0
  27. remoteRF_server/server/acc_perms.py +317 -0
  28. remoteRF_server/server/cert_provider.py +184 -0
  29. remoteRF_server/server/device_manager.py +688 -0
  30. remoteRF_server/server/grpc_server.py +1023 -0
  31. remoteRF_server/server/reservation.py +811 -0
  32. remoteRF_server/server/rpc_manager.py +104 -0
  33. remoteRF_server/server/user_group_cli.py +723 -0
  34. remoteRF_server/server/user_group_handler.py +1120 -0
  35. remoteRF_server/serverrf_cli.py +1377 -0
  36. remoteRF_server/tools/__init__.py +191 -0
  37. remoteRF_server/tools/gen_certs.py +274 -0
  38. remoteRF_server/tools/gist_status.py +139 -0
  39. remoteRF_server/tools/gist_status_testing.py +67 -0
  40. remoterf_server_testing-0.0.0.dist-info/METADATA +612 -0
  41. remoterf_server_testing-0.0.0.dist-info/RECORD +44 -0
  42. remoterf_server_testing-0.0.0.dist-info/WHEEL +5 -0
  43. remoterf_server_testing-0.0.0.dist-info/entry_points.txt +2 -0
  44. remoterf_server_testing-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,358 @@
1
+ # src/remoteRF_server/common/idl/schema.py
2
+ """
3
+ IDL Schema framework.
4
+
5
+ Admin writes one .py per device TYPE, drops it in ~/.config/remoterf/drivers/,
6
+ and the server auto-discovers it on startup via load_drivers().
7
+
8
+ ~/.config/remoterf/
9
+ devices.env ← "I have 3 plutos at these serials"
10
+ drivers/
11
+ pluto_schema.py ← "here's what a pluto IS"
12
+ hackrf_schema.py
13
+ webcam_schema.py
14
+
15
+ The admin's schema file:
16
+ 1. Imports from this module
17
+ 2. Subclasses DeviceSchema
18
+ 3. Decorates with @idl_register("pluto")
19
+ 4. Implements make_device(**kwargs) — factory from manifest config
20
+ 5. Decorates exposed methods with @idl_expose(kind="get"|"set"|"call")
21
+
22
+ That's it. The server never needs to be modified.
23
+ """
24
+ from __future__ import annotations
25
+
26
+ import hashlib
27
+ import importlib.util
28
+ import json
29
+ import sys
30
+ from pathlib import Path
31
+ from typing import Any, Dict, Optional
32
+
33
+ from ..utils import *
34
+
35
+ # ───────────────────────────────────────────────────────────────────
36
+ # @idl_expose decorator
37
+ # ───────────────────────────────────────────────────────────────────
38
+
39
# Attribute name under which @idl_expose stores its metadata dict on a function.
_EXPOSE_KEY: str = "__idl__"


def idl_expose(_fn=None, *, kind: str = "call", doc: str = ""):
    """
    Mark a method as exposed over IDL.

        @idl_expose(kind="get")
        def get_tx_lo(self):
            return self.device.tx_lo

        @idl_expose(kind="set")
        def set_tx_lo(self, value):
            self.device.tx_lo = value

        @idl_expose(kind="call")
        def call_rx(self):
            return self.device.rx()

        @idl_expose          # bare — defaults to kind="call"
        def some_method(self, x):
            ...

    kind:
        "get"  — property getter (no args beyond self)
        "set"  — property setter (one arg: value)
        "call" — general method (any args)

    The function is returned unchanged apart from a ``{"kind", "doc"}`` dict
    attached under ``_EXPOSE_KEY``. ``doc`` falls back to the function's own
    stripped docstring when not given.
    """
    def _mark(target):
        description = doc if doc else (target.__doc__ or "").strip()
        setattr(target, _EXPOSE_KEY, {"kind": kind, "doc": description})
        return target

    # Bare use (@idl_expose) receives the function directly; the parameterised
    # form (@idl_expose(...)) must hand back the decorator itself.
    return _mark if _fn is None else _mark(_fn)
77
+
78
+
79
+ # ───────────────────────────────────────────────────────────────────
80
+ # Driver registry
81
+ # ───────────────────────────────────────────────────────────────────
82
+
83
# Normalised device-type string -> schema class, filled by @register_driver.
_DRIVER_REGISTRY: Dict[str, type] = {}


def _normalise_type(device_type: str) -> str:
    """Canonical registry key: surrounding whitespace removed, lower-cased."""
    return device_type.strip().lower()


def register_driver(device_type: str):
    """
    Class decorator — registers a schema so the server can find it by type string.

        @register_driver("pluto")
        class PlutoSchema(DeviceSchema):
            ...

    The key is normalised (stripped, lower-cased) when the decorator is
    applied; a later registration under the same key replaces the earlier one.
    """
    def decorator(cls):
        _DRIVER_REGISTRY[_normalise_type(device_type)] = cls
        return cls
    return decorator


# Alias used by driver files (@idl_register("pluto")).
idl_register = register_driver


def get_driver_class(device_type: str) -> type:
    """
    Look up the schema class registered for *device_type* (case/whitespace-insensitive).

    Raises:
        KeyError: no driver is registered under that type; the message lists
            the registered keys.
    """
    found = _DRIVER_REGISTRY.get(_normalise_type(device_type))
    if found is not None:
        return found
    known = list(_DRIVER_REGISTRY) or ['(none)']
    raise KeyError(
        f"no driver for type '{device_type}'. "
        f"registered: {known}"
    )


def list_registered_drivers() -> Dict[str, str]:
    """Return {registry_key: "<device_type> v<driver_version>"} for every registered class."""
    summary: Dict[str, str] = {}
    for key, klass in _DRIVER_REGISTRY.items():
        summary[key] = f"{klass.device_type} v{klass.driver_version}"
    return summary
115
+
116
+ # ───────────────────────────────────────────────────────────────────
117
+ # Plugin loader — auto-discover driver schemas from a folder
118
+ # ───────────────────────────────────────────────────────────────────
119
+
120
def _module_is_from(mod, path: Path) -> bool:
    """Return True if *mod* was loaded from *path* (e.g. by an earlier load_drivers() call)."""
    mod_file = getattr(mod, "__file__", None)
    if not mod_file:
        return False
    try:
        return Path(mod_file).resolve() == path.resolve()
    except OSError:
        return False


def load_drivers() -> Dict[str, str]:
    """
    Import every .py in the drivers config dir. Each file's @idl_register
    decorator self-registers into the global registry.

    Files whose name starts with "_" are skipped. Each driver is imported as
    a top-level module named after its file stem.

    Returns {filename: status} for logging, where status is "ok",
    "skip: <reason>" or "error: <exception text>". If the drivers dir does not
    exist, returns {"__error__": "<dir> is not a directory"} instead.

    Called once at server startup:
        load_drivers()
    """
    drivers_dir = get_drivers_dir()
    results: Dict[str, str] = {}

    if not drivers_dir.is_dir():
        return {"__error__": f"{drivers_dir} is not a directory"}

    for f in sorted(drivers_dir.glob("*.py")):
        if f.name.startswith("_"):
            continue

        mod_name = f.stem
        previous = sys.modules.get(mod_name)
        # Refuse to shadow an unrelated, already-imported module (e.g. a driver
        # file called json.py would otherwise replace the real `json` for the
        # whole process). Re-loading the same driver file is allowed.
        # NOTE(review): this cannot detect stdlib/third-party modules that have
        # not been imported yet — avoid driver file names that collide with them.
        if previous is not None and not _module_is_from(previous, f):
            results[f.name] = f"skip: module name '{mod_name}' is already in use"
            continue

        try:
            spec = importlib.util.spec_from_file_location(mod_name, str(f))
            if spec is None or spec.loader is None:
                results[f.name] = "skip: not a valid module"
                continue
            mod = importlib.util.module_from_spec(spec)
        except Exception as e:
            results[f.name] = f"error: {e}"
            continue

        # Register before executing, as importlib's own import machinery does,
        # so code that looks the module up via sys.modules[__name__] works and
        # other modules can `import <stem>` it.
        sys.modules[mod_name] = mod
        try:
            spec.loader.exec_module(mod)
        except Exception as e:
            # Don't leave a half-initialised module behind: restore whatever was
            # registered under this name before (an earlier load), or drop it.
            if sys.modules.get(mod_name) is mod:
                if previous is not None:
                    sys.modules[mod_name] = previous
                else:
                    del sys.modules[mod_name]
            results[f.name] = f"error: {e}"
            continue

        results[f.name] = "ok"

    return results
154
+
155
+
156
+ # ───────────────────────────────────────────────────────────────────
157
+ # DeviceSchema base class
158
+ # ───────────────────────────────────────────────────────────────────
159
+
160
class DeviceSchema:
    """
    Base for all device schemas. Admin subclasses this once per device type.

    Class attributes (admin sets these):
        device_type     str — must match the TYPE in devices.env
        driver_version  str — semver

    Admin implements:
        make_device(**kwargs) — factory: manifest config → live device handle
                                returns the device object, or None on failure

    Admin decorates exposed methods with @idl_expose(kind=...).

    The server calls:
        schema = PlutoSchema()
        dev = schema.make_device(serial="00000001")
        schema.device = dev                 # or schema.bind(dev)
        schema.dispatch("get_tx_lo", {})    # → calls get_tx_lo()
    """

    device_type: str = ""
    driver_version: str = ""

    def __init__(self):
        self.device = None  # set after make_device(), via bind() or direct assignment
        # method name -> _Exposed record for every @idl_expose'd method (see _introspect)
        self._exposed: Dict[str, _Exposed] = {}
        # Filled lazily by get_idl() and never invalidated, so the IDL reflects
        # the exposed set / device_type / driver_version at the first call.
        self._schema_cache: dict | None = None
        self._hash_cache: str | None = None
        self._introspect()

    # ── Factory (admin overrides) ─────────────────────────────────

    @staticmethod
    def make_device(**kwargs):
        """
        Create and return a live device handle using kwargs from the manifest.
        Admin MUST override this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("make_device() must be implemented in the schema")

    # ── Bind device to schema instance ────────────────────────────

    def bind(self, device) -> None:
        """Attach a live device handle so exposed methods can use self.device."""
        self.device = device

    def is_connected(self) -> bool:
        """Return True once a device handle is bound (self.device is not None)."""
        return self.device is not None

    # ── Introspection ─────────────────────────────────────────────

    def _introspect(self):
        """
        Record every public attribute in the class MRO that carries @idl_expose
        metadata. Walking the MRO in order means the most-derived decorated
        definition of a name wins.
        """
        seen: set = set()
        for klass in type(self).__mro__:
            for name, obj in vars(klass).items():
                if name in seen or name.startswith("_"):
                    continue
                meta = getattr(obj, _EXPOSE_KEY, None)
                if meta is None:
                    # NOTE(review): undecorated names are not added to `seen`, so a
                    # subclass override *without* @idl_expose does not hide a
                    # decorated base-class method: the base function is exposed and
                    # dispatch() calls it, bypassing the override. Confirm intended.
                    continue
                seen.add(name)
                self._exposed[name] = _Exposed(
                    name=name, fn=obj, kind=meta["kind"], doc=meta["doc"],
                )

    # ── IDL generation ────────────────────────────────────────────

    def get_idl(self) -> dict:
        """
        Full wire-ready schema dict with content-addressed hash.

        The hash is "sha256:<hex>" over the canonical (sorted-key, compact)
        JSON of the body, so identical schemas hash identically. The returned
        dict is the cached object itself.
        """
        if self._schema_cache is not None:
            return self._schema_cache

        body = self._build_body()
        canon = json.dumps(body, sort_keys=True, separators=(",", ":"))
        h = f"sha256:{hashlib.sha256(canon.encode()).hexdigest()}"

        self._hash_cache = h
        self._schema_cache = {**body, "schema_hash": h}
        return self._schema_cache

    def get_idl_hash(self) -> str:
        """Return the schema hash, building the IDL first if needed."""
        if self._hash_cache is None:
            self.get_idl()
        return self._hash_cache  # type: ignore

    def get_idl_json(self, *, pretty: bool = False) -> str:
        """Serialise get_idl(): compact canonical JSON by default, indented if *pretty*."""
        d = self.get_idl()
        if pretty:
            return json.dumps(d, indent=2, sort_keys=True)
        return json.dumps(d, sort_keys=True, separators=(",", ":"))

    def _build_body(self) -> dict:
        """Group exposed methods by kind; entries carry {"doc": ...} only when a doc exists."""
        getters = {}
        setters = {}
        calls = {}

        for name, ex in sorted(self._exposed.items()):
            entry = {"doc": ex.doc} if ex.doc else {}
            if ex.kind == "get":
                getters[name] = entry
            elif ex.kind == "set":
                setters[name] = entry
            elif ex.kind == "call":
                calls[name] = entry

        return {
            "schema_version": "1.0",
            "device_type": self.device_type,
            "driver_version": self.driver_version,
            "getters": getters,
            "setters": setters,
            "calls": calls,
        }

    # ── Dispatch (server RPC layer calls this) ────────────────────

    def dispatch(self, method_name: str, args: dict) -> Any:
        """
        Called by the RPC layer. Finds the exposed method and calls it.

            dispatch("get_tx_lo", {})                → self.get_tx_lo()
            dispatch("set_tx_lo", {"value": 2.4e9})  → self.set_tx_lo(2.4e9)
            dispatch("call_rx", {})                  → self.call_rx()
            dispatch("call_tx", {"value": samples})  → self.call_tx(samples)

        Args:
            method_name: name of an @idl_expose'd method.
            args: keyword-style arguments from the request; None is treated as {}.

        Raises:
            KeyError: *method_name* is not exposed.
            RuntimeError: no device is bound.
            ValueError: the exposed method has an unknown kind.
        """
        ex = self._exposed.get(method_name)
        if ex is None:
            raise KeyError(
                f"'{method_name}' is not exposed on {self.device_type}. "
                f"Available: {list(self._exposed.keys())}"
            )

        if self.device is None:
            raise RuntimeError(f"no device bound to {self.device_type} schema")

        if args is None:
            args = {}

        if ex.kind == "get":
            return ex.fn(self)

        if ex.kind == "set":
            # Prefer the conventional "value" key; otherwise use the first
            # supplied arg, or None when nothing was supplied.
            if "value" in args:
                value = args["value"]
            else:
                value = next(iter(args.values()), None)
            return ex.fn(self, value)

        if ex.kind == "call":
            return self._dispatch_call(ex, args)

        raise ValueError(f"unknown kind '{ex.kind}' on {method_name}")

    def _dispatch_call(self, ex: "_Exposed", args: dict) -> Any:
        """Invoke a kind="call" method, mapping *args* onto its signature."""
        import inspect

        P = inspect.Parameter
        params = [
            p for p in inspect.signature(ex.fn).parameters.values()
            if p.name != "self"
        ]

        if not params:
            return ex.fn(self)

        if len(params) == 1:
            p = params[0]
            if p.kind is P.VAR_KEYWORD:
                # def f(self, **kw): forward every supplied arg by name.
                return ex.fn(self, **args)
            # Single arg — pull from args dict by param name or "value".
            if p.name in args:
                val = args[p.name]
            elif "value" in args:
                val = args["value"]
            elif p.default is not P.empty:
                # Nothing supplied: let the parameter's own default apply
                # instead of overriding it with None.
                return ex.fn(self)
            else:
                val = None  # legacy behaviour: missing required arg -> None
            if p.kind is P.KEYWORD_ONLY:
                return ex.fn(self, **{p.name: val})
            return ex.fn(self, val)

        # Multi-arg — match by name. *args/**kwargs names are not real
        # keyword targets, so they are never matched directly.
        if any(p.kind is P.VAR_KEYWORD for p in params):
            call_args = {k: v for k, v in args.items() if k != "self"}
        else:
            named = {p.name for p in params if p.kind is not P.VAR_POSITIONAL}
            call_args = {k: v for k, v in args.items() if k in named}
        return ex.fn(self, **call_args)

    # ── Debug ─────────────────────────────────────────────────────

    def list_exposed(self) -> dict:
        """Human-readable summary: exposed names by kind plus connection state."""
        by_kind: Dict[str, list] = {"get": [], "set": [], "call": []}
        for name, ex in sorted(self._exposed.items()):
            # Unknown kinds append to a throwaway list, i.e. are omitted.
            by_kind.get(ex.kind, []).append(name)
        return {
            "device_type": self.device_type,
            "driver_version": self.driver_version,
            "connected": self.is_connected(),
            "getters": by_kind["get"],
            "setters": by_kind["set"],
            "calls": by_kind["call"],
        }

    def __repr__(self) -> str:
        return f"<{type(self).__name__} type={self.device_type!r} connected={self.is_connected()}>"
346
+
347
+
348
+ # ───────────────────────────────────────────────────────────────────
349
+ # Internal
350
+ # ───────────────────────────────────────────────────────────────────
351
+
352
+ class _Exposed:
353
+ __slots__ = ("name", "fn", "kind", "doc")
354
+ def __init__(self, *, name, fn, kind, doc):
355
+ self.name = name
356
+ self.fn = fn
357
+ self.kind = kind
358
+ self.doc = doc
@@ -0,0 +1,6 @@
1
+ from .api_token import validate_token, generate_token, hash_token
2
+ from .process_arg import unmap_arg, map_arg
3
+ from .ansi_codes import printf, stylize, Sty
4
+ from .list_string import list_to_str, str_to_list
5
+ from .db_connection import db_connection
6
+ from .db_location import get_db_dir, get_certs_dir, get_drivers_dir, get_remoterf_root
@@ -0,0 +1,120 @@
1
+ from prompt_toolkit.styles import Style
2
+ from prompt_toolkit.formatted_text import FormattedText
3
+ from prompt_toolkit import print_formatted_text
4
+ from enum import Enum
5
+
6
class Sty(Enum):
    """
    Named style classes accepted by printf() and stylize().

    Each member's value is a class name: printf()/stylize() convert a member
    to its .value and emit it as 'class:<value>', which prompt_toolkit resolves
    against the module-level `style` sheet. Several members can be combined by
    passing them together as a tuple.
    """
    # Basic colors
    RED = 'red'
    GREEN = 'green'
    BLUE = 'blue'
    YELLOW = 'yellow'
    MAGENTA = 'magenta'
    CYAN = 'cyan'
    GRAY = 'gray'

    # Background colors
    BG_RED = 'bg-red'
    BG_GREEN = 'bg-green'
    BG_BLUE = 'bg-blue'

    # Bright versions
    BRIGHT_RED = 'bright-red'
    BRIGHT_GREEN = 'bright-green'
    BRIGHT_BLUE = 'bright-blue'

    # Formatting
    BOLD = 'bold'
    ITALIC = 'italic'
    UNDERLINE = 'underline'
    BLINK = 'blink'
    REVERSE = 'reverse'

    # Combinations (single class names that bundle several attributes)
    ERROR = 'error'
    WARNING = 'warning'
    INFO = 'info'

    # Special
    SELECTED = 'selected'
    DEFAULT = 'default'
41
+
42
# prompt_toolkit style sheet: maps each Sty value (emitted as "class:<value>")
# to concrete colours / text attributes. Every Sty member has an entry here so
# none of them renders unstyled.
style = Style.from_dict({
    # Basic colors
    # NOTE(review): '#110000' and '#003300' are near-black for "red"/"green";
    # confirm these are intentional rather than typos for '#ff0000'/'#00ff00'.
    'red': 'fg:#110000',
    'green': 'fg:#003300',
    'blue': 'fg:#0000ff',
    'yellow': 'fg:#ffff00',
    'magenta': 'fg:#ff00ff',
    'cyan': 'fg:#00ffff',
    'gray': 'fg:#808080',

    # Background colors (Sty.BG_*)
    'bg-red': 'bg:#ff0000',
    'bg-green': 'bg:#00ff00',
    'bg-blue': 'bg:#0000ff',

    # Bright versions
    'bright-red': 'fg:#ff5555',
    'bright-green': 'fg:#00ff00',
    'bright-blue': 'fg:#5555ff',

    # Formatting
    'bold': 'bold',
    'italic': 'italic',
    'underline': 'underline',
    'blink': 'blink',
    'reverse': 'reverse',

    # Combinations
    'error': 'bg:#ff0000 fg:#ffffff bold',
    'warning': 'bg:#ffff00 fg:#000000 bold',
    'info': 'bg:#0000ff fg:#ffffff italic underline',

    # Special
    'selected': 'bg:#ffffff #000000 reverse',
    'default': '',
})
73
+
74
def printf(*args) -> FormattedText:
    """
    Print styled text and return it.

    Args are alternating (message, styles) pairs, exactly as for stylize():
    *styles* is a Sty member / class-name string or a tuple of them.

        printf("ok ", Sty.GREEN, "done", (Sty.BOLD, Sty.BLUE))

    Returns:
        The FormattedText that was printed (the previous `-> str` annotation
        was incorrect).

    Raises:
        ValueError: if an odd number of arguments is given.
    """
    if len(args) % 2 != 0:
        raise ValueError('Arguments must be in pairs of two.')

    # Reuse stylize() so printing and prompt styling build identical text.
    text = stylize(*args)
    print_formatted_text(text, style=style)
    return text
99
+
100
def stylize(*args):
    """
    Create a styled prompt text based on pairs of (text, (Sty, ...), ...).

    Each style element may be a single Sty member / class-name string, or a
    tuple or list of them; Sty members are converted to their string values
    and joined into one 'class:...' entry.

    Returns:
        A prompt_toolkit FormattedText.

    Raises:
        ValueError: if an odd number of arguments is given.
    """
    if len(args) % 2 != 0:
        raise ValueError("Arguments must be in pairs of (text, style_class).")

    styled_parts = []
    for text, styles in zip(args[0::2], args[1::2]):
        # A list is accepted like a tuple; previously it was treated as one
        # style and crashed in str.join.
        if not isinstance(styles, (tuple, list)):
            styles = (styles,)
        style_class = ' '.join(s.value if isinstance(s, Enum) else s for s in styles)
        styled_parts.append(('class:' + style_class, text))

    return FormattedText(styled_parts)
@@ -0,0 +1,21 @@
1
+ import os
2
+ import hashlib
3
+ import base64
4
+ import secrets
5
+ from dotenv import load_dotenv, find_dotenv
6
+
7
def generate_token(length=8) -> tuple[str, str, str]:
    """
    Create a new random API token and its salted hash.

    Args:
        length: number of random bytes in the token (before base64 encoding).
            NOTE(review): the default of 8 bytes is 64 bits of entropy; consider
            passing a larger value for long-lived credentials.

    Returns:
        (salt_hex, sha256_hex, token) where sha256_hex is
        sha256(salt_bytes + token_utf8) and token is unpadded urlsafe base64.
    """
    raw = secrets.token_bytes(length)
    token = base64.urlsafe_b64encode(raw).decode('utf-8').rstrip('=')
    salt_bytes = os.urandom(16)
    salt = salt_bytes.hex()
    digest = hashlib.sha256(salt_bytes + token.encode()).hexdigest()
    return salt, digest, token
13
+
14
def validate_token(salt, hash, token) -> bool:
    """
    Check *token* against a stored (salt, hash) pair.

    Args:
        salt: hex-encoded salt, as returned by generate_token()/hash_token().
        hash: expected hex sha256 of salt_bytes + token_utf8.
        token: the candidate token string.

    Returns:
        True iff the recomputed hash matches. The comparison is constant-time
        (hmac.compare_digest) so response timing does not leak how many
        leading characters of the hash matched. A non-str *hash* never matches.
    """
    import hmac

    new_hashed = hashlib.sha256(bytes.fromhex(salt) + token.encode()).hexdigest()
    if not isinstance(hash, str):
        return False
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(new_hashed.encode(), hash.encode())
17
+
18
def hash_token(token: str) -> tuple[str, str]:
    """
    Hash an existing token with a fresh 16-byte random salt.

    Returns:
        (salt_hex, sha256_hex) where sha256_hex = sha256(salt_bytes + token_utf8).
    """
    salt_bytes = os.urandom(16)
    return salt_bytes.hex(), hashlib.sha256(salt_bytes + token.encode()).hexdigest()
@@ -0,0 +1,35 @@
1
+ import sqlite3
2
+ from functools import wraps
3
+
4
def db_connection(func):
    """
    Method decorator: run *func* inside a fresh SQLite connection.

    While holding ``self.lock`` it opens ``self.filepath`` (5 s timeout,
    busy_timeout=5000, WAL journal), exposes the connection and a cursor as
    ``self.db`` / ``self.cursor``, then commits if *func* returns or rolls back
    and re-raises if it raises. The cursor and connection are always closed and
    the previous ``self.db`` / ``self.cursor`` values are restored afterwards.

    NOTE(review): if ``self.lock`` is a plain threading.Lock, a decorated method
    that calls another decorated method on the same object will deadlock; the
    save/restore of db/cursor suggests nesting is expected — confirm it's an RLock.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            prev_db, prev_cursor = getattr(self, "db", None), getattr(self, "cursor", None)

            conn = sqlite3.connect(str(self.filepath), timeout=5.0, check_same_thread=False)
            try:
                for pragma in ("PRAGMA busy_timeout=5000;", "PRAGMA journal_mode=WAL;"):
                    conn.execute(pragma)

                cur = conn.cursor()
                self.db, self.cursor = conn, cur

                try:
                    # commit stays inside the try so a failed commit is rolled back too
                    outcome = func(self, *args, **kwargs)
                    conn.commit()
                    return outcome
                except Exception:
                    conn.rollback()
                    raise
                finally:
                    try:
                        cur.close()
                    except Exception:
                        pass
            finally:
                try:
                    conn.close()
                except Exception:
                    pass
                self.db, self.cursor = prev_db, prev_cursor
    return wrapper
@@ -0,0 +1,24 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
APP_NAME = "remoterf"


def _config_home() -> Path:
    """Return the XDG config base directory, defaulting to ~/.config."""
    env = os.environ.get("XDG_CONFIG_HOME", "")
    # XDG Base Directory spec: an unset or empty value means "use the default",
    # and a relative path must be treated as invalid and ignored. Previously
    # both cases produced a config dir relative to the current working directory.
    if env and os.path.isabs(env):
        return Path(env)
    return Path.home() / ".config"


def get_remoterf_root() -> Path:
    """
    Return <XDG config home>/remoterf, creating it and its certs/ and db/
    subfolders if missing. (drivers/ is not created here.)
    """
    # Linux-only: XDG_CONFIG_HOME or ~/.config
    root = _config_home() / APP_NAME
    root.mkdir(parents=True, exist_ok=True)

    # ensure sibling folders exist
    for sub in ("certs", "db"):
        (root / sub).mkdir(parents=True, exist_ok=True)

    return root


def get_certs_dir() -> Path:
    """Path of the certs/ folder under the remoterf root (created by get_remoterf_root)."""
    return get_remoterf_root() / "certs"


def get_db_dir() -> Path:
    """Path of the db/ folder under the remoterf root (created by get_remoterf_root)."""
    return get_remoterf_root() / "db"


def get_drivers_dir() -> Path:
    """Path of the drivers/ folder under the remoterf root; it may not exist."""
    return get_remoterf_root() / "drivers"
@@ -0,0 +1,5 @@
1
def list_to_str(list: list) -> str:
    """Join the items as a comma-separated string, applying str() to each."""
    return ','.join(map(str, list))
3
+
4
def str_to_list(s: str) -> list[int]:
    """
    Parse a comma-separated string of integers (inverse of list_to_str).

    An empty / whitespace-only string (or None) yields [] so that
    str_to_list(list_to_str([])) round-trips; previously "" raised ValueError.

    Raises:
        ValueError: if any non-empty item is not an integer.
    """
    if not s or not s.strip():
        return []
    return [int(x) for x in s.split(',')]
@@ -0,0 +1,80 @@
1
+ from ..grpc import grpc_pb2, grpc_pb2_grpc
2
+ import numpy as np
3
+
4
def unmap_arg(arg):
    """
    Convert a grpc ``Argument`` message back into a native value.

    Fields are checked in order: int64_value, float_value, string_value,
    bool_value, real_array (-> float64 ndarray), complex_array (-> complex64
    ndarray). Arrays are reshaped using the message's shape.dim.

    Raises:
        ValueError: if none of those fields is set.
    """
    for scalar_field in ('int64_value', 'float_value', 'string_value', 'bool_value'):
        if arg.HasField(scalar_field):
            return getattr(arg, scalar_field)

    if arg.HasField('real_array'):
        real = arg.real_array
        dims = tuple(real.shape.dim)
        return np.array(real.data, dtype=np.float64).reshape(dims)

    if arg.HasField('complex_array'):
        cplx = arg.complex_array
        dims = tuple(cplx.shape.dim)
        values = [complex(c.real, c.imag) for c in cplx.data]
        return np.array(values, dtype=np.complex64).reshape(dims)

    raise ValueError(f"Unknown argument type during unmapping: {arg}")
22
+
23
def map_arg(value):
    """
    Convert a native Python / NumPy value into a grpc ``Argument`` message.

    Supported inputs:
        bool / np.bool_             -> bool_value
        int / np.integer            -> int64_value
        float / np.floating         -> float_value
        str                         -> string_value
        np.ndarray (complex dtype)  -> complex_array (shape + per-element real/imag)
        np.ndarray (any other dtype)-> real_array (shape + flattened float64 data)

    Raises:
        ValueError: for any other type.
    """
    arg = grpc_pb2.Argument()

    # bool must be tested before int: bool is a subclass of int, so the old
    # order sent True/False as int64_value and the bool branch was unreachable.
    if isinstance(value, (bool, np.bool_)):
        arg.bool_value = bool(value)
    elif isinstance(value, (int, np.integer)):
        arg.int64_value = int(value)
    elif isinstance(value, (float, np.floating)):
        arg.float_value = float(value)
    elif isinstance(value, str):
        arg.string_value = value
    elif isinstance(value, np.ndarray):
        if np.iscomplexobj(value):
            complex_array = arg.complex_array
            complex_array.shape.dim.extend(value.shape)
            for num in value.ravel():
                complex_num = complex_array.data.add()
                complex_num.real = float(num.real)
                complex_num.imag = float(num.imag)
        else:
            real_array = arg.real_array
            real_array.shape.dim.extend(value.shape)
            # Plain Python floats (float64), matching unmap_arg's float64
            # reconstruction and avoiding numpy-scalar type issues in protobuf.
            real_array.data.extend(np.asarray(value, dtype=np.float64).ravel().tolist())
    else:
        raise ValueError(f"Unknown argument type during mapping: {value}")
    return arg
49
+
50
def map_array_proto(np_array):
    """
    Pack a NumPy array into a grpc ``Argument`` (flattened, no shape info).

    Complex arrays go into ``complex_array``; everything else into ``float_array``.
    NOTE(review): map_arg/unmap_arg use ``real_array`` with a shape, while this
    helper uses ``float_array`` / ``FloatArray``; confirm these fields still
    exist in the current grpc Argument definition.
    """
    arg = grpc_pb2.Argument()

    if not np.iscomplexobj(np_array):
        # Handle as a regular float array
        payload = grpc_pb2.FloatArray()
        payload.data.extend(np_array.flat)
        arg.float_array.CopyFrom(payload)
        return arg

    payload = grpc_pb2.ComplexArray()
    for element in np_array.flat:
        entry = payload.data.add()
        entry.real, entry.imag = element.real, element.imag
    arg.complex_array.CopyFrom(payload)
    return arg
68
+
69
def unmap_array_proto(arg):
    """
    Inverse of map_array_proto: return a flat complex64 or float32 ndarray.

    Raises:
        ValueError: if neither ``complex_array`` nor ``float_array`` is set.
    """
    if arg.HasField('complex_array'):
        values = [complex(cn.real, cn.imag) for cn in arg.complex_array.data]
        return np.array(values, dtype=np.complex64)
    if arg.HasField('float_array'):
        return np.array(arg.float_array.data, dtype=np.float32)
    raise ValueError("Argument does not contain a recognizable array.")
80
+
File without changes
File without changes