remoterf 0.0.7.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of remoterf might be problematic. Click here for more details.
- remoteRF/__init__.py +0 -0
- remoteRF/common/__init__.py +2 -0
- remoteRF/common/grpc/__init__.py +1 -0
- remoteRF/common/grpc/grpc_pb2.py +59 -0
- remoteRF/common/grpc/grpc_pb2_grpc.py +97 -0
- remoteRF/common/utils/__init__.py +4 -0
- remoteRF/common/utils/ansi_codes.py +120 -0
- remoteRF/common/utils/api_token.py +31 -0
- remoteRF/common/utils/list_string.py +5 -0
- remoteRF/common/utils/process_arg.py +80 -0
- remoteRF/core/__init__.py +2 -0
- remoteRF/core/acc_login.py +4 -0
- remoteRF/core/app.py +508 -0
- remoteRF/core/cert_fetcher.py +140 -0
- remoteRF/core/certs/__init__.py +0 -0
- remoteRF/core/certs/ca.crt +32 -0
- remoteRF/core/certs/ca.key +52 -0
- remoteRF/core/certs/cert.pem +19 -0
- remoteRF/core/certs/key.pem +28 -0
- remoteRF/core/certs/server.crt +19 -0
- remoteRF/core/certs/server.key +28 -0
- remoteRF/core/config.py +143 -0
- remoteRF/core/grpc_acc.py +52 -0
- remoteRF/core/grpc_client.py +100 -0
- remoteRF/core/version.py +8 -0
- remoteRF/drivers/__init__.py +0 -0
- remoteRF/drivers/adalm_pluto/__init__.py +1 -0
- remoteRF/drivers/adalm_pluto/pluto_remote.py +249 -0
- remoterf-0.0.7.41.dist-info/METADATA +158 -0
- remoterf-0.0.7.41.dist-info/RECORD +33 -0
- remoterf-0.0.7.41.dist-info/WHEEL +5 -0
- remoterf-0.0.7.41.dist-info/entry_points.txt +4 -0
- remoterf-0.0.7.41.dist-info/top_level.txt +1 -0
remoteRF/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from . import grpc_pb2, grpc_pb2_grpc
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
4
|
+
# source: grpc.proto
|
|
5
|
+
# Protobuf Python Version: 5.29.0
|
|
6
|
+
"""Generated protocol buffer code."""
|
|
7
|
+
from google.protobuf import descriptor as _descriptor
|
|
8
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
10
|
+
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
# NOTE: protoc-generated module; regenerate from grpc.proto instead of editing.
# Fail fast at import time if the installed protobuf runtime cannot load
# gencode produced for protobuf 5.29.0.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    29,
    0,
    '',
    'grpc.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# Serialized FileDescriptorProto for grpc.proto (proto3, package "remote_rf").
# It declares:
#   GenericRPCRequest  { string function_name = 1; map<string, Argument> args = 2; }
#   GenericRPCResponse { map<string, Argument> results = 1; }
#   ArrayShape         { repeated int32 dim = 1; }
#   ComplexNumber      { float real = 1; float imag = 2; }
#   ComplexNumpyArray  { ArrayShape shape = 1; repeated ComplexNumber data = 2; }
#   RealNumpyArray     { ArrayShape shape = 1; repeated float data = 2; }
#   Argument           { oneof value { string string_value = 1; int64 int64_value = 2;
#                        float float_value = 3; bool bool_value = 4;
#                        ComplexNumpyArray complex_array = 5; RealNumpyArray real_array = 6; } }
#   service GenericRPC { rpc Call(GenericRPCRequest) returns (GenericRPCResponse); }
# Note that every floating-point field is proto `float` (32-bit), so Python
# floats sent through Argument lose precision beyond single precision.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ngrpc.proto\x12\tremote_rf\"\xa2\x01\n\x11GenericRPCRequest\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32&.remote_rf.GenericRPCRequest.ArgsEntry\x1a@\n\tArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.remote_rf.Argument:\x02\x38\x01\"\x96\x01\n\x12GenericRPCResponse\x12;\n\x07results\x18\x01 \x03(\x0b\x32*.remote_rf.GenericRPCResponse.ResultsEntry\x1a\x43\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.remote_rf.Argument:\x02\x38\x01\"\x19\n\nArrayShape\x12\x0b\n\x03\x64im\x18\x01 \x03(\x05\"+\n\rComplexNumber\x12\x0c\n\x04real\x18\x01 \x01(\x02\x12\x0c\n\x04imag\x18\x02 \x01(\x02\"a\n\x11\x43omplexNumpyArray\x12$\n\x05shape\x18\x01 \x01(\x0b\x32\x15.remote_rf.ArrayShape\x12&\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32\x18.remote_rf.ComplexNumber\"D\n\x0eRealNumpyArray\x12$\n\x05shape\x18\x01 \x01(\x0b\x32\x15.remote_rf.ArrayShape\x12\x0c\n\x04\x64\x61ta\x18\x02 \x03(\x02\"\xd7\x01\n\x08\x41rgument\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x15\n\x0bint64_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x12\x14\n\nbool_value\x18\x04 \x01(\x08H\x00\x12\x35\n\rcomplex_array\x18\x05 \x01(\x0b\x32\x1c.remote_rf.ComplexNumpyArrayH\x00\x12/\n\nreal_array\x18\x06 \x01(\x0b\x32\x19.remote_rf.RealNumpyArrayH\x00\x42\x07\n\x05value2Q\n\nGenericRPC\x12\x43\n\x04\x43\x61ll\x12\x1c.remote_rf.GenericRPCRequest\x1a\x1d.remote_rf.GenericRPCResponseB\x1e\n\x10\x63om.example.demoB\nDemoProtosb\x06proto3')

_globals = globals()
# Create the message classes (GenericRPCRequest, Argument, ...) and the
# private _<NAME> descriptor globals in this module's namespace.
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_pb2', _globals)
# Only needed by the pure-Python descriptor backend: attach serialized options
# and each definition's byte offsets within the serialized descriptor above.
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  # FileOptions: java_package "com.example.demo", java_outer_classname "DemoProtos".
  _globals['DESCRIPTOR']._serialized_options = b'\n\020com.example.demoB\nDemoProtos'
  _globals['_GENERICRPCREQUEST_ARGSENTRY']._loaded_options = None
  # MessageOptions.map_entry = true (synthetic map<string, Argument> entry).
  _globals['_GENERICRPCREQUEST_ARGSENTRY']._serialized_options = b'8\001'
  _globals['_GENERICRPCRESPONSE_RESULTSENTRY']._loaded_options = None
  _globals['_GENERICRPCRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
  _globals['_GENERICRPCREQUEST']._serialized_start=26
  _globals['_GENERICRPCREQUEST']._serialized_end=188
  _globals['_GENERICRPCREQUEST_ARGSENTRY']._serialized_start=124
  _globals['_GENERICRPCREQUEST_ARGSENTRY']._serialized_end=188
  _globals['_GENERICRPCRESPONSE']._serialized_start=191
  _globals['_GENERICRPCRESPONSE']._serialized_end=341
  _globals['_GENERICRPCRESPONSE_RESULTSENTRY']._serialized_start=274
  _globals['_GENERICRPCRESPONSE_RESULTSENTRY']._serialized_end=341
  _globals['_ARRAYSHAPE']._serialized_start=343
  _globals['_ARRAYSHAPE']._serialized_end=368
  _globals['_COMPLEXNUMBER']._serialized_start=370
  _globals['_COMPLEXNUMBER']._serialized_end=413
  _globals['_COMPLEXNUMPYARRAY']._serialized_start=415
  _globals['_COMPLEXNUMPYARRAY']._serialized_end=512
  _globals['_REALNUMPYARRAY']._serialized_start=514
  _globals['_REALNUMPYARRAY']._serialized_end=582
  _globals['_ARGUMENT']._serialized_start=585
  _globals['_ARGUMENT']._serialized_end=800
  _globals['_GENERICRPC']._serialized_start=802
  _globals['_GENERICRPC']._serialized_end=883
# @@protoc_insertion_point(module_scope)
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
|
2
|
+
"""Client and server classes corresponding to protobuf-defined services."""
|
|
3
|
+
import grpc
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
from . import grpc_pb2 as grpc__pb2
|
|
7
|
+
|
|
8
|
+
# grpcio version that grpcio-tools used to generate this module; the installed
# runtime must be at least this version.
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    # True when the installed grpcio is older than the generator's version.
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # If the comparison helper cannot be imported, treat the runtime as unsupported.
    _version_not_supported = True

# Refuse to import against an incompatible grpcio instead of failing later at call time.
if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in grpc_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class GenericRPCStub(object):
    """Client stub for the ``remote_rf.GenericRPC`` service (protoc-generated)."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # Unary-unary callable for /remote_rf.GenericRPC/Call: serializes a
        # GenericRPCRequest and parses the reply as a GenericRPCResponse.
        self.Call = channel.unary_unary(
                '/remote_rf.GenericRPC/Call',
                request_serializer=grpc__pb2.GenericRPCRequest.SerializeToString,
                response_deserializer=grpc__pb2.GenericRPCResponse.FromString,
                _registered_method=True)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class GenericRPCServicer(object):
    """Server-side base for ``remote_rf.GenericRPC``; ``Call`` must be overridden (protoc-generated)."""

    def Call(self, request, context):
        """Handle one GenericRPCRequest.

        This default implementation marks the RPC as UNIMPLEMENTED and raises
        ``NotImplementedError``.
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def add_GenericRPCServicer_to_server(servicer, server):
    """Expose ``servicer.Call`` on ``server`` as ``/remote_rf.GenericRPC/Call`` (protoc-generated).

    Args:
        servicer: Object providing a ``Call(request, context)`` method.
        server: The gRPC server to register the handlers on.
    """
    # Wire deserialization of the request and serialization of the response
    # around the servicer's Call method.
    rpc_method_handlers = {
            'Call': grpc.unary_unary_rpc_method_handler(
                    servicer.Call,
                    request_deserializer=grpc__pb2.GenericRPCRequest.FromString,
                    response_serializer=grpc__pb2.GenericRPCResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'remote_rf.GenericRPC', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    # Also register the same handlers through the registered-method API.
    server.add_registered_method_handlers('remote_rf.GenericRPC', rpc_method_handlers)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# This class is part of an EXPERIMENTAL API.
class GenericRPC(object):
    """Stub-free helpers for ``remote_rf.GenericRPC`` built on ``grpc.experimental`` (protoc-generated)."""

    @staticmethod
    def Call(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Send ``request`` to ``/remote_rf.GenericRPC/Call`` on ``target``.

        All arguments are forwarded positionally to
        ``grpc.experimental.unary_unary``; the reply is parsed as a
        ``GenericRPCResponse``.
        """
        return grpc.experimental.unary_unary(
            request,
            target,
            '/remote_rf.GenericRPC/Call',
            grpc__pb2.GenericRPCRequest.SerializeToString,
            grpc__pb2.GenericRPCResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from prompt_toolkit.styles import Style
|
|
2
|
+
from prompt_toolkit.formatted_text import FormattedText
|
|
3
|
+
from prompt_toolkit import print_formatted_text
|
|
4
|
+
from enum import Enum
|
|
5
|
+
|
|
6
|
+
class Sty(Enum):
    """Style-class names accepted by ``printf`` and ``stylize``.

    Each member's value is a bare prompt_toolkit class name. Callers pass a
    member (or a tuple of members); the values are joined with spaces and
    prefixed with ``class:`` to form one formatted-text fragment.
    """

    # --- foreground colours ---------------------------------------------
    RED = "red"
    GREEN = "green"
    BLUE = "blue"
    YELLOW = "yellow"
    MAGENTA = "magenta"
    CYAN = "cyan"
    GRAY = "gray"

    # --- background colours ---------------------------------------------
    BG_RED = "bg-red"
    BG_GREEN = "bg-green"
    BG_BLUE = "bg-blue"

    # --- bright foreground variants -------------------------------------
    BRIGHT_RED = "bright-red"
    BRIGHT_GREEN = "bright-green"
    BRIGHT_BLUE = "bright-blue"

    # --- text attributes ------------------------------------------------
    BOLD = "bold"
    ITALIC = "italic"
    UNDERLINE = "underline"
    BLINK = "blink"
    REVERSE = "reverse"

    # --- pre-combined presets -------------------------------------------
    ERROR = "error"
    WARNING = "warning"
    INFO = "info"

    # --- special --------------------------------------------------------
    SELECTED = "selected"
    DEFAULT = "default"
|
|
41
|
+
|
|
42
|
+
# prompt_toolkit style sheet: maps each style-class name used by ``Sty`` (and
# emitted as ``class:<name>`` fragments by printf/stylize) to concrete
# attributes. Every ``Sty`` member needs an entry here; a missing entry means
# that member silently has no visible effect.
style = Style.from_dict({
    # Basic colors
    # NOTE(review): 'red' (#110000) and 'green' (#003300) are very dark and may be
    # nearly invisible on dark terminals; confirm these values are intentional.
    'red': 'fg:#110000',
    'green': 'fg:#003300',
    'blue': 'fg:#0000ff',
    'yellow': 'fg:#ffff00',
    'magenta': 'fg:#ff00ff',
    'cyan': 'fg:#00ffff',
    'gray': 'fg:#808080',

    # Background colors (used by Sty.BG_RED / BG_GREEN / BG_BLUE)
    'bg-red': 'bg:#ff0000',
    'bg-green': 'bg:#00ff00',
    'bg-blue': 'bg:#0000ff',

    # Bright versions
    'bright-red': 'fg:#ff5555',
    'bright-green': 'fg:#00ff00',
    'bright-blue': 'fg:#5555ff',

    # Formatting
    'bold': 'bold',
    'italic': 'italic',
    'underline': 'underline',
    'blink': 'blink',
    'reverse': 'reverse',

    # Combinations
    'error': 'bg:#ff0000 fg:#ffffff bold',
    'warning': 'bg:#ffff00 fg:#000000 bold',
    'info': 'bg:#0000ff fg:#ffffff italic underline',

    # Special
    'selected': 'bg:#ffffff #000000 reverse',
    'default': '',
})
|
|
73
|
+
|
|
74
|
+
def printf(*args) -> "FormattedText":
    """Print styled text to the terminal and return what was printed.

    Arguments are consumed in ``(message, styles)`` pairs. ``styles`` is a
    single style or a tuple of styles; each style is either a ``Sty`` member
    or a raw style-class string looked up in the module-level ``style`` sheet.

    Returns:
        The ``FormattedText`` that was printed (the old ``-> str`` annotation
        was incorrect).

    Raises:
        ValueError: If an odd number of arguments is given.
    """
    if len(args) % 2 != 0:
        raise ValueError('Arguments must be in pairs of two.')

    # Reuse stylize so printed and prompted text resolve styles identically.
    text = stylize(*args)
    print_formatted_text(text, style=style)
    return text
|
|
99
|
+
|
|
100
|
+
def stylize(*args):
    """Build prompt_toolkit formatted text from ``(text, styles)`` pairs.

    ``styles`` may be a single style or a tuple of styles; ``Enum`` members
    (e.g. ``Sty``) contribute their ``.value``, and plain strings are used
    as-is. All styles for one piece of text are space-joined into a single
    ``class:`` fragment.

    Raises:
        ValueError: If an odd number of arguments is given.
    """
    if len(args) % 2:
        raise ValueError("Arguments must be in pairs of (text, style_class).")

    def _class_spec(spec):
        # Normalise to a tuple, then resolve Enum members to their string values.
        members = spec if isinstance(spec, tuple) else (spec,)
        return ' '.join(m.value if isinstance(m, Enum) else m for m in members)

    return FormattedText([
        ('class:' + _class_spec(spec), piece)
        for piece, spec in zip(args[::2], args[1::2])
    ])
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import hashlib
|
|
3
|
+
import base64
|
|
4
|
+
import secrets
|
|
5
|
+
from dotenv import load_dotenv, find_dotenv
|
|
6
|
+
|
|
7
|
+
### API Token Management
|
|
8
|
+
# API Tokens are used to authenticate clients to the device
|
|
9
|
+
# API Tokens are stored locally on the device in its .env file
|
|
10
|
+
|
|
11
|
+
def generate_token(length=8) -> tuple[str, str, str]:
    """Create a fresh API token together with its salted SHA-256 hash.

    Args:
        length: Number of random bytes of entropy in the token.

    Returns:
        ``(salt, hashed, token)``:
        ``salt`` is 16 random bytes as lowercase hex;
        ``hashed`` is ``sha256(salt_bytes + token.encode()).hexdigest()``;
        ``token`` is the URL-safe base64 text (``=`` padding stripped) given to
        the client.
    """
    # token_urlsafe(n) is urlsafe_b64encode(token_bytes(n)) with '=' stripped.
    token = secrets.token_urlsafe(length)
    salt_bytes = os.urandom(16)
    hashed = hashlib.sha256(salt_bytes + token.encode()).hexdigest()
    return salt_bytes.hex(), hashed, token
|
|
17
|
+
|
|
18
|
+
def validate_token(salt, hash, token) -> bool:
    """Check a candidate token against a stored salt and hash.

    Args:
        salt: Hex-encoded salt, as returned by ``generate_token``/``hash_token``.
        hash: Stored hex SHA-256 of ``bytes.fromhex(salt) + token``. (The name
            shadows the builtin but is kept for keyword-argument compatibility.)
        token: Candidate plaintext token supplied by the client.

    Returns:
        True if the recomputed digest equals ``hash``, otherwise False.

    Raises:
        ValueError: If ``salt`` is not valid hex.
    """
    new_hashed = hashlib.sha256(bytes.fromhex(salt) + token.encode()).hexdigest()
    # A non-string stored hash can never match (same result as the old ``==``).
    if not isinstance(hash, str):
        return False
    # Constant-time comparison so response timing does not reveal how much of a
    # guessed digest was correct.
    return secrets.compare_digest(new_hashed.encode(), hash.encode())
|
|
21
|
+
|
|
22
|
+
def hash_token(token: str) -> tuple[str, str]:
    """Hash an existing token with a fresh random salt.

    Args:
        token: Plaintext token to protect.

    Returns:
        ``(salt, hashed)`` where ``salt`` is 16 random bytes as hex and
        ``hashed`` is ``sha256(salt_bytes + token.encode()).hexdigest()``.
    """
    salt_bytes = os.urandom(16)
    digest = hashlib.sha256(salt_bytes + token.encode()).hexdigest()
    return salt_bytes.hex(), digest
|
|
26
|
+
|
|
27
|
+
# Example usage:
# if __name__ == '__main__':
#     salt, hashed, token = generate_token()
#     assert validate_token(salt, hashed, token), "Token validation failed!"
#     assert not validate_token(salt, hashed, token + "s"), "Tampered token was accepted!"
#     print("Token validation succeeded!")
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from ..grpc import grpc_pb2, grpc_pb2_grpc
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
def unmap_arg(arg):
    """Convert a ``grpc_pb2.Argument`` message back into a Python/NumPy value.

    Scalars are returned as the plain field value. ``real_array`` becomes a
    float64 ndarray and ``complex_array`` becomes a complex64 ndarray, each
    reshaped to the message's ``shape.dim``.

    Raises:
        ValueError: If none of the known fields is set.
    """
    # Scalar oneof members, checked in the same order as before.
    for scalar_field in ('int64_value', 'float_value', 'string_value', 'bool_value'):
        if arg.HasField(scalar_field):
            return getattr(arg, scalar_field)

    if arg.HasField('real_array'):
        payload = arg.real_array
        dims = tuple(payload.shape.dim)
        return np.array(payload.data, dtype=np.float64).reshape(dims)

    if arg.HasField('complex_array'):
        payload = arg.complex_array
        dims = tuple(payload.shape.dim)
        values = [complex(entry.real, entry.imag) for entry in payload.data]
        return np.array(values, dtype=np.complex64).reshape(dims)

    raise ValueError(f"Unknown argument type during unmapping: {arg}")
|
|
22
|
+
|
|
23
|
+
def map_arg(value):
    """Wrap a Python/NumPy value in a ``grpc_pb2.Argument`` message.

    Type -> populated oneof field:
      * ``bool`` / ``np.bool_``      -> ``bool_value``
      * ``int`` / ``np.integer``     -> ``int64_value``
      * ``float`` / ``np.floating``  -> ``float_value`` (proto float, 32-bit)
      * ``str``                      -> ``string_value``
      * complex ``np.ndarray``       -> ``complex_array`` (shape + C-order data)
      * other ``np.ndarray``         -> ``real_array`` (shape + C-order data)

    Args:
        value: The value to encode.

    Returns:
        A populated ``grpc_pb2.Argument``.

    Raises:
        ValueError: If ``value`` has an unsupported type.
    """
    arg = grpc_pb2.Argument()

    # bool must be tested before int: bool is a subclass of int, so checking
    # int first sent True/False as int64_value and made bool_value unreachable.
    if isinstance(value, (bool, np.bool_)):
        arg.bool_value = bool(value)
    elif isinstance(value, (int, np.integer)):
        arg.int64_value = int(value)
    elif isinstance(value, (float, np.floating)):
        arg.float_value = float(value)
    elif isinstance(value, str):
        arg.string_value = value
    elif isinstance(value, np.ndarray):
        if np.iscomplexobj(value):
            complex_array = arg.complex_array
            complex_array.shape.dim.extend(value.shape)
            flat = value.ravel()
            # tolist() yields plain Python floats for the protobuf setters.
            for re_part, im_part in zip(flat.real.tolist(), flat.imag.tolist()):
                complex_num = complex_array.data.add()
                complex_num.real = re_part
                complex_num.imag = im_part
        else:
            float_array = arg.real_array
            float_array.shape.dim.extend(value.shape)
            float_array.data.extend(value.ravel().tolist())
    else:
        raise ValueError(f"Unknown argument type during mapping: {value}")
    return arg
|
|
49
|
+
|
|
50
|
+
def map_array_proto(np_array):
    """Encode an array-like as a ``grpc_pb2.Argument`` carrying a numeric array.

    Complex input fills ``complex_array``; anything else fills ``real_array``.
    The array's shape is stored next to the C-order flattened data so the
    receiver can restore it. (The previous version referenced the nonexistent
    ``grpc_pb2.ComplexArray`` / ``grpc_pb2.FloatArray`` types and the
    nonexistent ``float_array`` field, so it always raised AttributeError.)

    Args:
        np_array: A NumPy array or anything ``np.asarray`` accepts.

    Returns:
        A populated ``grpc_pb2.Argument``.
    """
    values = np.asarray(np_array)
    arg = grpc_pb2.Argument()

    if np.iscomplexobj(values):
        target = arg.complex_array
        target.shape.dim.extend(values.shape)
        flat = values.ravel()
        for re_part, im_part in zip(flat.real.tolist(), flat.imag.tolist()):
            complex_number = target.data.add()
            complex_number.real = re_part
            complex_number.imag = im_part
    else:
        # Handle as a real-valued array (proto `float` elements).
        target = arg.real_array
        target.shape.dim.extend(values.shape)
        target.data.extend(values.ravel().tolist())

    return arg
|
|
68
|
+
|
|
69
|
+
def unmap_array_proto(arg):
    """Decode the array carried by a ``grpc_pb2.Argument``.

    ``complex_array`` is returned as complex64 and ``real_array`` as float32.
    If the message carries a non-empty ``shape.dim``, the result is reshaped
    to it; otherwise the flat 1-D data is returned. (The previous version
    checked a nonexistent ``float_array`` field, so real arrays always failed.)

    Raises:
        ValueError: If neither array field is set.
    """
    if arg.HasField('complex_array'):
        payload = arg.complex_array
        data = [complex(cn.real, cn.imag) for cn in payload.data]
        values = np.array(data, dtype=np.complex64)
    elif arg.HasField('real_array'):
        payload = arg.real_array
        values = np.array(payload.data, dtype=np.float32)
    else:
        raise ValueError("Argument does not contain a recognizable array.")

    dims = tuple(payload.shape.dim)
    # Messages without a shape (empty dim list) are returned flat.
    return values.reshape(dims) if dims else values
|
|
80
|
+
|