gnetclisdk 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gnetclisdk/auth.py ADDED
@@ -0,0 +1,41 @@
1
+ import abc
2
+
3
+
4
# Both concrete schemes below send their credentials in this metadata key.
_AUTHORIZATION_HEADER = "authorization"


class ClientAuthentication(abc.ABC):
    """Interface for objects that supply an authentication header (name and value)."""

    @abc.abstractmethod
    def get_authentication_header_key(self) -> str:
        """
        Return the name of the header that carries the credentials.
        :return: Header name.
        """
        raise NotImplementedError("abstract method get_authentication_header_key() not implemented")

    @abc.abstractmethod
    def create_authentication_header_value(self) -> str:
        """
        Build the value to place in the authentication header.
        :return: Authentication header value.
        """
        raise NotImplementedError("abstract method create_authentication_header_value() not implemented")


class OAuthClientAuthentication(ClientAuthentication):
    """Produces ``authorization: OAuth <token>``."""

    def __init__(self, token: str):
        # Double underscore => stored as the name-mangled _OAuthClientAuthentication__token.
        self.__token = token

    def get_authentication_header_key(self) -> str:
        return _AUTHORIZATION_HEADER

    def create_authentication_header_value(self) -> str:
        return "OAuth {}".format(self.__token)


class BasicClientAuthentication(ClientAuthentication):
    """Produces ``authorization: Basic <token>``; the token is used verbatim (no encoding here)."""

    def __init__(self, token: str):
        # Double underscore => stored as the name-mangled _BasicClientAuthentication__token.
        self.__token = token

    def get_authentication_header_key(self) -> str:
        return _AUTHORIZATION_HEADER

    def create_authentication_header_value(self) -> str:
        return "Basic {}".format(self.__token)
gnetclisdk/client.py ADDED
@@ -0,0 +1,463 @@
1
+ import asyncio
2
+ import logging
3
+ import os.path
4
+ import uuid
5
+ from abc import ABC, abstractmethod
6
+ from contextlib import asynccontextmanager
7
+ from dataclasses import dataclass, field
8
+ from functools import partial
9
+ from typing import Any, AsyncIterator, List, Optional, Tuple, Dict
10
+
11
+ import grpc
12
+ from google.protobuf.message import Message
13
+
14
+ from .proto import server_pb2, server_pb2_grpc
15
+ from .auth import BasicClientAuthentication, ClientAuthentication, OAuthClientAuthentication
16
+ from .exceptions import parse_grpc_error
17
+ from .interceptors import get_auth_client_interceptors
18
+
19
_logger = logging.getLogger(__name__)

# Metadata keys attached to outgoing requests.
HEADER_REQUEST_ID = "x-request-id"
HEADER_USER_AGENT = "user-agent"
DEFAULT_USER_AGENT = "Gnetcli SDK"

# Server address used when none is given explicitly; SERVER_ENV names the
# environment variable consulted first.
DEFAULT_SERVER = "localhost:50051"
SERVER_ENV = "GNETCLI_SERVER"

# 130 MiB cap for both send and receive message sizes.
GRPC_MAX_MESSAGE_LENGTH = 130 * 1024 * 1024

# Base channel options; callers extend a copy with per-client settings.
default_grpc_options: List[Tuple[str, Any]] = [
    ("grpc.max_concurrent_streams", 900),
    ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_LENGTH),
    ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_LENGTH),
]
32
+
33
+
34
@dataclass
class QA:
    """One expected prompt and the reply to send for it during command execution."""

    question: str  # text of the prompt
    answer: str  # reply to send
38
+
39
+
40
@dataclass
class Credentials:
    """Device login credentials.

    ``password`` is excluded from the generated ``repr`` so the secret does not
    show up when a Credentials (or a HostParams holding one) is logged or
    appears in a traceback. It still participates in ``==``.
    """

    login: str
    password: str = field(repr=False)

    def make_pb(self) -> Message:
        """Convert to a ``server_pb2.Credentials`` protobuf message."""
        pb = server_pb2.Credentials()
        pb.login = self.login
        pb.password = self.password
        return pb
50
+
51
+
52
@dataclass
class File:
    """A file transferred via ``Gnetcli.upload`` / ``Gnetcli.download``."""

    content: bytes  # raw file data
    status: server_pb2.FileStatus  # status the server reported for this file
56
+
57
+
58
@dataclass
class HostParams:
    """Per-host connection settings sent by ``Gnetcli.set_host_params``."""

    device: str  # device type name
    port: Optional[int] = None  # None leaves the port unset in the request
    credentials: Optional[Credentials] = None  # None leaves credentials unset
63
+
64
+
65
def make_auth(auth_token: str) -> ClientAuthentication:
    """Parse an ``"<scheme> <token>"`` string into a ClientAuthentication.

    The scheme is compared case-insensitively and must be exactly ``oauth`` or
    ``basic``. Surrounding whitespace and repeated separators are ignored.

    :param auth_token: e.g. ``"Basic dXNlcjpwYXNz"`` or ``"OAuth <token>"``.
    :return: An OAuthClientAuthentication or BasicClientAuthentication.
    :raises ValueError: if the scheme is unknown or no token follows it
        (ValueError subclasses Exception, so ``except Exception`` still catches it).
    """
    # split(None, 1) tolerates tabs/multiple spaces and never raises IndexError.
    parts = auth_token.split(None, 1)
    scheme = parts[0].lower() if parts else ""
    if scheme not in ("oauth", "basic"):
        raise ValueError("unknown token type")
    token = parts[1].strip() if len(parts) == 2 else ""
    if not token:
        raise ValueError("missing token after auth scheme")
    if scheme == "oauth":
        return OAuthClientAuthentication(token)
    return BasicClientAuthentication(token)
73
+
74
+
75
class Gnetcli:
    """Asynchronous client for a Gnetcli gRPC server.

    ``cmd`` and ``add_device`` reuse one channel created lazily on first use (or
    by ``connect``) and released by ``close``. ``cmd_netconf``,
    ``set_host_params``, ``upload`` and ``download`` each open and close their
    own channel. ``cmd_session`` / ``netconf_session`` yield streaming sessions
    built with this client's server, TLS and auth settings.
    """

    def __init__(
        self,
        auth_token: Optional[str] = None,  # like 'Basic ...'
        server: Optional[str] = None,
        target_name_override: Optional[str] = None,
        cert_file: Optional[str] = None,
        user_agent: str = DEFAULT_USER_AGENT,
        insecure_grpc: bool = False,
    ):
        """
        :param auth_token: ``"<scheme> <token>"`` parsed by ``make_auth``; when
            empty, no auth interceptor is installed.
        :param server: ``host:port`` of the server; defaults to the
            ``GNETCLI_SERVER`` environment variable, then ``DEFAULT_SERVER``.
        :param target_name_override: Sets ``grpc.ssl_target_name_override``.
        :param cert_file: PEM root certificates passed to
            ``grpc.ssl_channel_credentials``; ``None`` passes no explicit roots.
        :param user_agent: Sent as ``grpc.primary_user_agent``.
        :param insecure_grpc: Use a plaintext channel instead of TLS.
        """
        if server is None:
            self._server = os.getenv(SERVER_ENV, DEFAULT_SERVER)
        else:
            self._server = server
        self._user_agent = user_agent
        # Kept so streaming sessions are created with the same auth/TLS settings.
        self._auth_token = auth_token
        self._cert_file = cert_file

        options: List[Tuple[str, Any]] = [
            *default_grpc_options,
            ("grpc.primary_user_agent", user_agent),
        ]
        if target_name_override:
            _logger.warning("set target_name_override %s", target_name_override)
            options.append(("grpc.ssl_target_name_override", target_name_override))
        self._target_name_override = target_name_override
        cert = get_cert(cert_file=cert_file)
        channel_credentials = grpc.ssl_channel_credentials(root_certificates=cert)
        interceptors = []
        if auth_token:
            authentication: ClientAuthentication = make_auth(auth_token)
            interceptors = get_auth_client_interceptors(authentication)
        if insecure_grpc:
            grpc_channel_fn = partial(grpc.aio.insecure_channel, interceptors=interceptors)
        else:
            grpc_channel_fn = partial(
                grpc.aio.secure_channel, credentials=channel_credentials, interceptors=interceptors
            )
        self._grpc_channel_fn = grpc_channel_fn
        self._options = options
        self._channel: Optional[grpc.aio.Channel] = None
        self._insecure_grpc: bool = insecure_grpc

    def _get_channel(self) -> grpc.aio.Channel:
        """Return the shared channel, creating it on first use."""
        if self._channel is None:
            _logger.debug("connect to %s", self._server)
            self._channel = self._grpc_channel_fn(self._server, options=self._options)
        return self._channel

    async def cmd(
        self,
        hostname: str,
        cmd: str,
        trace: bool = False,
        qa: Optional[List[QA]] = None,
        read_timeout: float = 0.0,
        cmd_timeout: float = 0.0,
    ) -> Message:
        """Execute one CLI command on *hostname* through the unary ``Exec`` RPC.

        :param qa: Question/answer pairs forwarded in the request.
        :param read_timeout: Forwarded as ``CMD.read_timeout``.
        :param cmd_timeout: Forwarded as ``CMD.cmd_timeout``.
        :return: The server's response message.
        """
        pbcmd = make_cmd(
            hostname=hostname,
            cmd=cmd,
            trace=trace,
            qa=qa,
            read_timeout=read_timeout,
            cmd_timeout=cmd_timeout,
        )
        stub = server_pb2_grpc.GnetcliStub(self._get_channel())
        return await grpc_call_wrapper(stub.Exec, pbcmd)

    async def add_device(
        self,
        name: str,
        prompt_expression: str,
        error_expression: Optional[str] = None,
        pager_expression: Optional[str] = None,
    ) -> Message:
        """Register a device type on the server via ``AddDevice``.

        Empty *error_expression* / *pager_expression* are left unset.
        """
        # Build a message *instance*: assigning onto the server_pb2.Device class
        # would never produce a sendable request.
        pbdev = server_pb2.Device()
        pbdev.name = name
        pbdev.prompt_expression = prompt_expression
        if error_expression:
            pbdev.error_expression = error_expression
        if pager_expression:
            pbdev.pager_expression = pager_expression
        stub = server_pb2_grpc.GnetcliStub(self._get_channel())
        return await grpc_call_wrapper(stub.AddDevice, pbdev)

    def connect(self) -> None:
        """Eagerly create the shared channel so sessions can reuse it."""
        if not self._channel:
            _logger.debug("real connect to %s", self._server)
            self._channel = self._grpc_channel_fn(self._server, options=self._options)

    async def close(self) -> None:
        """Close the shared channel, if one was created.

        Sessions opened by ``cmd_session`` while the shared channel existed use
        that same channel and should be closed first.
        """
        if self._channel is not None:
            channel, self._channel = self._channel, None
            await channel.close()

    async def cmd_netconf(self, hostname: str, cmd: str, json: bool = False, trace: bool = False) -> Message:
        """Execute one NETCONF command via ``ExecNetconf`` on a short-lived channel."""
        pbcmd = server_pb2.CMDNetconf(host=hostname, cmd=cmd, json=json, trace=trace)
        _logger.debug("connect to %s", self._server)
        async with self._grpc_channel_fn(self._server, options=self._options) as channel:
            stub = server_pb2_grpc.GnetcliStub(channel)
            _logger.debug("executing netconf cmd: %r", pbcmd)
            try:
                response = await grpc_call_wrapper(stub.ExecNetconf, pbcmd)
            except Exception as e:
                _logger.error("error hostname=%s cmd=%r error=%s", hostname, repr(pbcmd), e)
                raise
            return response

    @asynccontextmanager
    async def cmd_session(self, hostname: str) -> AsyncIterator["GnetcliSessionCmd"]:
        """Yield a connected CLI session for *hostname*; it is closed on exit."""
        sess = GnetcliSessionCmd(
            hostname,
            token=self._auth_token,
            server=self._server,
            channel=self._channel,
            target_name_override=self._target_name_override,
            cert_file=self._cert_file,
            user_agent=self._user_agent,
            insecure_grpc=self._insecure_grpc,
        )
        await sess.connect()
        try:
            yield sess
        finally:
            await sess.close()

    @asynccontextmanager
    async def netconf_session(self, hostname: str) -> AsyncIterator["GnetcliSessionNetconf"]:
        """Yield a connected NETCONF session for *hostname*; it is closed on exit."""
        sess = GnetcliSessionNetconf(
            hostname,
            token=self._auth_token,
            server=self._server,
            target_name_override=self._target_name_override,
            cert_file=self._cert_file,
            user_agent=self._user_agent,
            insecure_grpc=self._insecure_grpc,
        )
        await sess.connect()
        try:
            yield sess
        finally:
            await sess.close()

    async def set_host_params(self, hostname: str, params: HostParams) -> None:
        """Send per-host settings via ``SetupHostParams``.

        ``None`` port/credentials are left unset in the request.
        """
        credentials_pb = params.credentials.make_pb() if params.credentials is not None else None
        pbcmd = server_pb2.HostParams(
            host=hostname,
            port=params.port,
            credentials=credentials_pb,
            device=params.device,
        )
        _logger.debug("connect to %s", self._server)
        async with self._grpc_channel_fn(self._server, options=self._options) as channel:
            _logger.debug("set params for %s", hostname)
            stub = server_pb2_grpc.GnetcliStub(channel)
            await grpc_call_wrapper(stub.SetupHostParams, pbcmd)

    async def upload(self, hostname: str, files: Dict[str, File]) -> None:
        """Upload ``{path: File}`` to *hostname*; only path and content are sent."""
        pbcmd = server_pb2.FileUploadRequest(host=hostname, files=make_files_request(files))
        _logger.debug("connect to %s", self._server)
        async with self._grpc_channel_fn(self._server, options=self._options) as channel:
            _logger.debug("upload %s to %s", files.keys(), hostname)
            stub = server_pb2_grpc.GnetcliStub(channel)
            response: Message = await grpc_call_wrapper(stub.Upload, pbcmd)
            _logger.debug("upload res %s", response)

    async def download(self, hostname: str, paths: List[str]) -> Dict[str, File]:
        """Download *paths* from *hostname* and return them keyed by path."""
        pbcmd = server_pb2.FileDownloadRequest(host=hostname, paths=paths)
        _logger.debug("connect to %s", self._server)
        async with self._grpc_channel_fn(self._server, options=self._options) as channel:
            _logger.debug("download %s from %s", paths, hostname)
            stub = server_pb2_grpc.GnetcliStub(channel)
            response: server_pb2.FilesResult = await grpc_call_wrapper(stub.Download, pbcmd)
            res: Dict[str, File] = {}
            for file in response.files:
                res[file.path] = File(content=file.data, status=file.status)
            return res
244
+
245
+
246
class GnetcliSession(ABC):
    """Base class for a streaming session against one device.

    ``connect`` obtains a stub, creating a channel if none was supplied.
    Subclasses open their bidirectional stream there. ``_cmd`` writes one
    request and reads one response. ``close`` half-closes the stream and closes
    the channel only if this session created it.
    """

    def __init__(
        self,
        hostname: str,
        token: Optional[str] = None,
        server: str = DEFAULT_SERVER,
        target_name_override: Optional[str] = None,
        cert_file: Optional[str] = None,
        user_agent: str = DEFAULT_USER_AGENT,
        insecure_grpc: bool = False,
        channel: Optional[grpc.aio.Channel] = None,
        credentials: Optional[Credentials] = None,
    ):
        """
        :param hostname: Device that commands in this session are addressed to.
        :param token: ``"<scheme> <token>"`` parsed by ``make_auth``; when empty
            or ``None``, no auth interceptor is installed.
        :param server: ``host:port`` of the Gnetcli server.
        :param target_name_override: Sets ``grpc.ssl_target_name_override``.
        :param cert_file: PEM root certificates for TLS.
        :param user_agent: Sent as ``grpc.primary_user_agent`` and as a header.
        :param insecure_grpc: Use a plaintext channel instead of TLS.
        :param channel: Existing channel to reuse; ``close`` leaves it open.
        :param credentials: Stored on the session.
        """
        self._hostname = hostname
        self._credentials = credentials
        self._server = server
        self._channel: Optional[grpc.aio.Channel] = channel
        # True only when connect() created the channel; such a channel is ours to close.
        self._owns_channel = False
        self._stub: Optional[server_pb2_grpc.GnetcliStub] = None
        self._stream: Optional[grpc.aio.StreamStreamCall] = None
        self._user_agent = user_agent

        # Same base options as the Gnetcli client, including the user agent.
        options: List[Tuple[str, Any]] = [
            *default_grpc_options,
            ("grpc.primary_user_agent", user_agent),
        ]
        if target_name_override:
            options.append(("grpc.ssl_target_name_override", target_name_override))
        cert = get_cert(cert_file=cert_file)
        channel_credentials = grpc.ssl_channel_credentials(root_certificates=cert)
        interceptors = []
        if token:
            authentication: ClientAuthentication = make_auth(token)
            interceptors = get_auth_client_interceptors(authentication)
        if insecure_grpc:
            grpc_channel_fn = partial(grpc.aio.insecure_channel, interceptors=interceptors)
        else:
            grpc_channel_fn = partial(
                grpc.aio.secure_channel, credentials=channel_credentials, interceptors=interceptors
            )
        self._grpc_channel_fn = grpc_channel_fn
        self._options = options
        self._req_id: Optional[Any] = None

    def _get_metadata(self) -> List[Tuple[str, str]]:
        """Metadata for opening a stream: a fresh request id plus the user agent."""
        req_id = make_req_id()
        metadata = [
            (HEADER_REQUEST_ID, req_id),
            (HEADER_USER_AGENT, self._user_agent),
        ]
        return metadata

    @abstractmethod
    async def connect(self) -> None:
        """Ensure a channel and stub exist; subclasses then open their stream."""
        if self._channel is None:
            _logger.debug("connect to %s self._channel=%s", self._server, self._channel)
            self._channel = self._grpc_channel_fn(self._server, options=self._options)
            self._owns_channel = True
        self._stub = server_pb2_grpc.GnetcliStub(self._channel)
        if self._stub is None:
            raise Exception("empty stub")

    async def _cmd(self, cmdpb: Any) -> Message:
        """Write *cmdpb* to the open stream and return the next response.

        :raises Exception: if the stream is not open or the server ended it.
        :raises GnetcliException: (subclass chosen by ``parse_grpc_error``) on a
            gRPC error; the original error is kept as ``__cause__``.
        """
        # TODO: add connect retry on first cmd
        if not self._stream:
            raise Exception("empty self._stream")
        try:
            _logger.debug("cmd %r on %r", str(cmdpb).replace("\n", ""), self._stream)
            await self._stream.write(cmdpb)
            response: Message = await self._stream.read()
        except grpc.aio.AioRpcError as e:
            gn_exc, verbose = parse_grpc_error(e)
            _logger.debug("caught exception %s %s", e, (gn_exc, verbose))
            # "from e" keeps the gRPC error as __cause__ ("from None" would erase it).
            raise gn_exc(
                message=f"{e.__class__.__name__} {e.details()}",
                imetadata=e.initial_metadata(),  # type: ignore
                verbose=verbose,
            ) from e
        # read() returns the EOF sentinel, not a message, once the server closes the stream.
        if response is grpc.aio.EOF:
            raise Exception("stream closed by server")
        _logger.debug("response %s", format_long_msg(str(response), 100))
        return response

    async def close(self) -> None:
        """Half-close the stream, then close the channel if this session created it."""
        _logger.debug("close stream %s", self._stream)
        try:
            if self._stream:
                stream, self._stream = self._stream, None
                await stream.done_writing()
        finally:
            if self._owns_channel and self._channel is not None:
                channel, self._channel = self._channel, None
                self._owns_channel = False
                self._stub = None
                await channel.close()
335
+
336
+
337
class GnetcliSessionCmd(GnetcliSession):
    """CLI session: every ``cmd`` travels over a single ``ExecChat`` stream."""

    async def cmd(
        self,
        cmd: str,
        trace: bool = False,
        qa: Optional[List[QA]] = None,
        cmd_timeout: float = 0.0,
        read_timeout: float = 0.0,
    ) -> Message:
        """Run *cmd* on the session's host and return the server's response."""
        _logger.debug("session cmd %r", cmd)
        request = make_cmd(
            hostname=self._hostname,
            cmd=cmd,
            trace=trace,
            qa=qa,
            read_timeout=read_timeout,
            cmd_timeout=cmd_timeout,
        )
        return await self._cmd(request)

    async def connect(self) -> None:
        """Create the stub (via the base class) and open the ExecChat stream."""
        await super().connect()
        if not self._stub:
            raise Exception()
        self._stream = self._stub.ExecChat(metadata=self._get_metadata())
363
+
364
+
365
class GnetcliSessionNetconf(GnetcliSession):
    """NETCONF session: every ``cmd`` travels over a single ``ExecNetconfChat`` stream."""

    async def cmd(self, cmd: str, trace: bool = False, json: bool = False) -> Message:
        """Send a NETCONF command on the session's host.

        :param cmd: NETCONF request to execute.
        :param trace: Forwarded as ``CMDNetconf.trace``, as ``Gnetcli.cmd_netconf`` does.
        :param json: Forwarded as ``CMDNetconf.json``.
        :return: The server's response message.
        """
        _logger.debug("netconf session cmd %r", cmd)
        cmdpb = server_pb2.CMDNetconf(host=self._hostname, cmd=cmd, json=json, trace=trace)
        return await self._cmd(cmdpb)

    async def connect(self) -> None:
        """Create the stub (via the base class) and open the ExecNetconfChat stream."""
        await super().connect()
        if self._stub:
            self._stream = self._stub.ExecNetconfChat(metadata=self._get_metadata())
        else:
            raise Exception()
377
+
378
+
379
async def grpc_call_wrapper(stub: grpc.UnaryUnaryMultiCallable, request: Any) -> Message:
    """Invoke a unary gRPC method with a fresh request id and translate errors.

    Exactly one attempt is made; failures are not retried.

    :param stub: Bound unary-unary method of a ``GnetcliStub`` (e.g. ``stub.Exec``).
    :param request: Request message to send.
    :return: The response message.
    :raises GnetcliException: (or the subclass chosen by ``parse_grpc_error``) when
        the call fails with ``grpc.aio.AioRpcError``; the gRPC error is kept as
        ``__cause__``.
    :raises Exception: if the call returns ``None``.
    """
    req_id = make_req_id()
    metadata = [
        (HEADER_REQUEST_ID, req_id),
    ]
    # NOTE(review): repr(request) can include secrets (e.g. HostParams
    # credentials) in DEBUG logs — consider redacting.
    _logger.debug("executing %s: %r, req_id=%s", type(request), repr(request), req_id)
    try:
        response: Optional[Message] = await stub(request=request, metadata=metadata)
    except grpc.aio.AioRpcError as e:
        gn_exc, verbose = parse_grpc_error(e)
        _logger.debug("caught exception %s req_id=%s %s", e, req_id, (gn_exc, verbose))
        # "from e" keeps the gRPC error as __cause__ ("from None" would erase it).
        raise gn_exc(
            message=f"{e.__class__.__name__} {e.details()}",
            imetadata=e.initial_metadata(),  # type: ignore
            request_id=req_id,
            verbose=verbose,
        ) from e
    if response is None:
        raise Exception(f"empty response req_id={req_id}")
    return response
412
+
413
+
414
def make_req_id() -> str:
    """Return a new random (version 4) UUID in canonical hyphenated form."""
    return f"{uuid.uuid4()}"
416
+
417
+
418
def get_cert(cert_file: Optional[str]) -> Optional[bytes]:
    """Read *cert_file* as bytes, or return ``None`` when no path is given (None or "")."""
    if not cert_file:
        return None
    _logger.debug("open cert_file %s", cert_file)
    with open(cert_file, "rb") as cert_fh:
        return cert_fh.read()
425
+
426
+
427
def format_long_msg(msg: str, max_len: int) -> str:
    """Truncate *msg* to *max_len* characters, noting how many were dropped."""
    overflow = len(msg) - max_len
    if overflow <= 0:
        return msg
    return f"{msg[:max_len]}... and {overflow} more"
431
+
432
+
433
def make_cmd(
    hostname: str,
    cmd: str,
    trace: bool = False,
    qa: Optional[List[QA]] = None,
    read_timeout: float = 0.0,
    cmd_timeout: float = 0.0,
) -> Message:
    """Build a ``server_pb2.CMD`` request.

    :param hostname: Target device.
    :param cmd: Command text.
    :param trace: Forwarded as ``CMD.trace``.
    :param qa: Optional question/answer pairs, each converted to ``server_pb2.QA``.
    :param read_timeout: Forwarded as ``CMD.read_timeout``.
    :param cmd_timeout: Forwarded as ``CMD.cmd_timeout``.
    """
    answers: List[Message] = [
        server_pb2.QA(question=pair.question, answer=pair.answer) for pair in qa or ()
    ]
    return server_pb2.CMD(  # type: ignore
        host=hostname,
        cmd=cmd,
        trace=trace,
        qa=answers,
        read_timeout=read_timeout,
        cmd_timeout=cmd_timeout,
    )
457
+
458
+
459
def make_files_request(files: Dict[str, File]) -> List[server_pb2.FileData]:
    """Convert ``{path: File}`` to ``FileData`` messages (only path and content are sent)."""
    return [server_pb2.FileData(path=path, data=entry.content) for path, entry in files.items()]
@@ -0,0 +1,116 @@
1
+ from typing import Optional, Sequence, Tuple, Type, Union
2
+
3
+ import grpc.aio
4
+
5
# gRPC call metadata as an ordered sequence of (key, value) pairs.
MetadataType = Sequence[Tuple[str, Union[str, bytes]]]


def extract_metadata(m: MetadataType) -> dict:
    """Collapse ``(key, value)`` pairs into a dict; later duplicates win.

    The pairs are iterated instead of using ``.get`` on *m*; per the original
    note, ``get`` on the metadata object raised KeyError.
    """
    return {key: value for key, value in m}


class GnetcliException(Exception):
    """Base class for Gnetcli SDK errors.

    The final message is *message* followed, when present, by
    ``RS:<real-server>`` (from the ``real-server`` initial-metadata entry),
    ``req_id:<request_id>`` and ``verbose:<verbose>``, each prefixed by a space.
    """

    def __init__(
        self,
        message: str = "",
        imetadata: Optional[MetadataType] = None,
        request_id: Optional[str] = None,
        verbose: Optional[str] = "",
    ):
        suffixes = []
        if imetadata:
            real_server = extract_metadata(imetadata).get("real-server")
            if real_server:
                suffixes.append(f"RS:{real_server}")
        if request_id:
            suffixes.append(f"req_id:{request_id}")
        if verbose:
            suffixes.append(f"verbose:{verbose}")
        text = message
        for suffix in suffixes:
            text = f"{text} {suffix}"
        self.message = text
        super().__init__(self.message)
34
+
35
+
36
class DeviceConnectError(GnetcliException):
    """The server could not connect to the device."""


class UnknownDevice(GnetcliException):
    """The host was not found in the inventory."""


class DeviceAuthError(DeviceConnectError):
    """The server could not authenticate on the device."""


class ExecError(GnetcliException):
    """An error occurred while executing the command."""


class NotReady(GnetcliException):
    """The Gnetcli server is not ready."""


class Unauthenticated(GnetcliException):
    """The client could not authenticate on the Gnetcli server."""


class PermissionDenied(GnetcliException):
    """The Gnetcli server denied permission for the request."""
90
+
91
+
92
def parse_grpc_error(grpc_error: grpc.aio.AioRpcError) -> Tuple[Type[GnetcliException], str]:
    """Map a gRPC error to a Gnetcli exception class and extra "verbose" text.

    UNAUTHENTICATED, PERMISSION_DENIED and OUT_OF_RANGE carry the gRPC details
    as verbose text. UNAVAILABLE "not ready" and the recognized INTERNAL
    detail strings carry none. Anything else maps to ``GnetcliException``.
    """
    code = grpc_error.code()
    detail = grpc_error.details() or ""

    if code == grpc.StatusCode.UNAVAILABLE and detail == "not ready":
        return NotReady, ""

    detail_carrying = (
        (grpc.StatusCode.UNAUTHENTICATED, Unauthenticated),
        (grpc.StatusCode.PERMISSION_DENIED, PermissionDenied),
        (grpc.StatusCode.OUT_OF_RANGE, UnknownDevice),
    )
    for status, exc_cls in detail_carrying:
        if code == status:
            return exc_cls, detail

    if code == grpc.StatusCode.INTERNAL:
        internal_by_detail = {
            "auth_device_error": DeviceAuthError,
            "connection_error": DeviceConnectError,
            "busy_error": DeviceConnectError,
            "exec_error": ExecError,
            "generic_error": ExecError,
        }
        exc_cls = internal_by_detail.get(detail)
        if exc_cls is not None:
            return exc_cls, ""

    return GnetcliException, ""