gnetclisdk 1.0.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
gnetclisdk/auth.py ADDED
@@ -0,0 +1,41 @@
1
+ import abc
2
+
3
+
4
class ClientAuthentication(abc.ABC):
    """
    Interface for objects that supply one authentication header.

    Concrete subclasses report the header name and build the header value.
    """

    @abc.abstractmethod
    def get_authentication_header_key(self) -> str:
        """
        Report which header carries the credentials.

        :return: Header name.
        """
        raise NotImplementedError("abstract method get_authentication_header_key() not implemented")

    @abc.abstractmethod
    def create_authentication_header_value(self) -> str:
        """
        Build the value placed in the authentication header.

        :return: Authentication header value.
        """
        raise NotImplementedError("abstract method create_authentication_header_value() not implemented")
20
+
21
+
22
class OAuthClientAuthentication(ClientAuthentication):
    """Authentication via an ``authorization: OAuth <token>`` header."""

    def __init__(self, token: str):
        # Double underscore -> name-mangled, so the token is not exposed as a plain attribute.
        self.__token = token

    def get_authentication_header_key(self) -> str:
        header_name = "authorization"
        return header_name

    def create_authentication_header_value(self) -> str:
        return "OAuth {}".format(self.__token)
31
+
32
+
33
class BasicClientAuthentication(ClientAuthentication):
    """Authentication via an ``authorization: Basic <token>`` header."""

    def __init__(self, token: str):
        # Double underscore -> name-mangled, so the token is not exposed as a plain attribute.
        self.__token = token

    def get_authentication_header_key(self) -> str:
        header_name = "authorization"
        return header_name

    def create_authentication_header_value(self) -> str:
        return "Basic {}".format(self.__token)
gnetclisdk/client.py ADDED
@@ -0,0 +1,463 @@
1
+ import asyncio
2
+ import logging
3
+ import os.path
4
+ import uuid
5
+ from abc import ABC, abstractmethod
6
+ from contextlib import asynccontextmanager
7
+ from dataclasses import dataclass, field
8
+ from functools import partial
9
+ from typing import Any, AsyncIterator, List, Optional, Tuple, Dict
10
+
11
+ import grpc
12
+ from google.protobuf.message import Message
13
+
14
+ from .proto import server_pb2, server_pb2_grpc
15
+ from .auth import BasicClientAuthentication, ClientAuthentication, OAuthClientAuthentication
16
+ from .exceptions import parse_grpc_error
17
+ from .interceptors import get_auth_client_interceptors
18
+
19
_logger = logging.getLogger(__name__)

# Metadata header names attached to outgoing requests (see _get_metadata / grpc_call_wrapper).
HEADER_REQUEST_ID = "x-request-id"
HEADER_USER_AGENT = "user-agent"

DEFAULT_USER_AGENT = "Gnetcli SDK"
# Used when no server is passed and the SERVER_ENV variable is not set.
DEFAULT_SERVER = "localhost:50051"
SERVER_ENV = "GNETCLI_SERVER"
# 130 MiB limit applied to both sent and received gRPC messages.
GRPC_MAX_MESSAGE_LENGTH = 130 * 1024 * 1024

default_grpc_options: List[Tuple[str, Any]] = [
    ("grpc.max_concurrent_streams", 900),
    ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_LENGTH),
    ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_LENGTH),
]
32
+
33
+
34
@dataclass
class QA:
    """Question/answer pair; ``make_cmd`` converts each one into a ``server_pb2.QA``."""

    question: str
    answer: str
38
+
39
+
40
@dataclass
class Credentials:
    """
    Login/password pair for a device.

    ``password`` is excluded from the generated ``repr`` so that logging a
    ``Credentials`` (or a ``HostParams`` that holds one) does not leak it.
    """

    login: str
    password: str = field(repr=False)

    def make_pb(self) -> "Message":
        """
        Convert to the protobuf representation.

        :return: ``server_pb2.Credentials`` with ``login`` and ``password`` set.
        """
        return server_pb2.Credentials(login=self.login, password=self.password)
50
+
51
+
52
@dataclass
class File:
    """
    File contents together with a status value.

    ``Gnetcli.download`` fills both fields from the server's ``FileData``;
    ``make_files_request`` sends only ``content`` on upload.
    """

    content: bytes
    status: server_pb2.FileStatus
56
+
57
+
58
@dataclass
class HostParams:
    """Per-host parameters that ``Gnetcli.set_host_params`` sends to the server."""

    device: str
    # Both remaining fields are optional and default to None.
    port: Optional[int] = None
    credentials: Optional[Credentials] = None
63
+
64
+
65
def make_auth(auth_token: str) -> ClientAuthentication:
    """
    Build an authentication object from a ``"<scheme> <token>"`` string.

    The scheme is matched case-insensitively against ``OAuth`` and ``Basic``;
    any run of whitespace may separate it from the token.

    :param auth_token: e.g. ``"OAuth abc"`` or ``"Basic dXNlcjpwYXNz"``.
    :return: ``OAuthClientAuthentication`` or ``BasicClientAuthentication``.
    :raises ValueError: if the scheme is unknown or the token part is missing
        (``ValueError`` is an ``Exception``, so existing handlers still catch it).
    """
    schemes = {
        "oauth": OAuthClientAuthentication,
        "basic": BasicClientAuthentication,
    }
    # split(None, 1): first word is the scheme, the remainder is the token.
    parts = auth_token.split(None, 1)
    auth_cls = schemes.get(parts[0].lower()) if parts else None
    if auth_cls is None:
        raise ValueError("unknown token type")
    token = parts[1].strip() if len(parts) > 1 else ""
    if not token:
        raise ValueError("missing token after %r scheme" % parts[0])
    return auth_cls(token)
73
+
74
+
75
class Gnetcli:
    """
    Asynchronous client for a Gnetcli gRPC server.

    ``cmd`` and ``add_device`` share one lazily created channel. That channel
    can also be created eagerly with ``connect`` and is released by ``close``.
    ``cmd_netconf``, ``set_host_params``, ``upload`` and ``download`` open a
    short-lived channel per call. ``cmd_session`` and ``netconf_session``
    yield streaming sessions.
    """

    def __init__(
        self,
        auth_token: Optional[str] = None,  # like 'Basic ...'
        server: Optional[str] = None,
        target_name_override: Optional[str] = None,
        cert_file: Optional[str] = None,
        user_agent: str = DEFAULT_USER_AGENT,
        insecure_grpc: bool = False,
    ):
        """
        :param auth_token: ``"<scheme> <token>"`` string parsed by ``make_auth``;
            when empty/None no auth interceptors are installed.
        :param server: ``host:port`` of the server; falls back to the
            ``GNETCLI_SERVER`` environment variable, then ``localhost:50051``.
        :param target_name_override: Value for ``grpc.ssl_target_name_override``.
        :param cert_file: Path to root certificates for the TLS channel.
        :param user_agent: Sent as ``grpc.primary_user_agent``.
        :param insecure_grpc: Use a plaintext channel instead of TLS.
        """
        self._server = server if server is not None else os.getenv(SERVER_ENV, DEFAULT_SERVER)
        self._user_agent = user_agent
        # Kept so that sessions can build their own authenticated channel.
        self._auth_token = auth_token
        self._cert_file = cert_file
        # Always set: the session factories read it unconditionally.
        self._target_name_override = target_name_override
        self._insecure_grpc: bool = insecure_grpc

        options: List[Tuple[str, Any]] = [
            *default_grpc_options,
            ("grpc.primary_user_agent", user_agent),
        ]
        if target_name_override:
            _logger.warning("set target_name_override %s", target_name_override)
            options.append(("grpc.ssl_target_name_override", target_name_override))
        self._options = options

        interceptors = []
        if auth_token:
            authentication: ClientAuthentication = make_auth(auth_token)
            interceptors = get_auth_client_interceptors(authentication)

        if insecure_grpc:
            self._grpc_channel_fn = partial(grpc.aio.insecure_channel, interceptors=interceptors)
        else:
            cert = get_cert(cert_file=cert_file)
            channel_credentials = grpc.ssl_channel_credentials(root_certificates=cert)
            self._grpc_channel_fn = partial(
                grpc.aio.secure_channel,
                credentials=channel_credentials,
                interceptors=interceptors,
            )
        self._channel: Optional[grpc.aio.Channel] = None

    def _ensure_channel(self) -> grpc.aio.Channel:
        """Return the shared channel, creating it on first use."""
        if self._channel is None:
            _logger.debug("connect to %s", self._server)
            self._channel = self._grpc_channel_fn(self._server, options=self._options)
        return self._channel

    async def cmd(
        self,
        hostname: str,
        cmd: str,
        trace: bool = False,
        qa: Optional[List[QA]] = None,
        read_timeout: float = 0.0,
        cmd_timeout: float = 0.0,
    ) -> Message:
        """
        Execute one command with a unary ``Exec`` call over the shared channel.

        :return: The server's response message.
        :raises GnetcliException: (or a subclass) on gRPC failure.
        """
        pbcmd = make_cmd(
            hostname=hostname,
            cmd=cmd,
            trace=trace,
            qa=qa,
            read_timeout=read_timeout,
            cmd_timeout=cmd_timeout,
        )
        stub = server_pb2_grpc.GnetcliStub(self._ensure_channel())
        return await grpc_call_wrapper(stub.Exec, pbcmd)

    async def add_device(
        self,
        name: str,
        prompt_expression: str,
        error_expression: Optional[str] = None,
        pager_expression: Optional[str] = None,
    ) -> Message:
        """
        Register a device definition on the server via ``AddDevice``.

        Empty or None ``error_expression``/``pager_expression`` are left unset.
        """
        # Build a message *instance*; assigning to the Device class itself
        # would not produce a request.
        pbdev = server_pb2.Device(name=name, prompt_expression=prompt_expression)
        if error_expression:
            pbdev.error_expression = error_expression
        if pager_expression:
            pbdev.pager_expression = pager_expression
        stub = server_pb2_grpc.GnetcliStub(self._ensure_channel())
        return await grpc_call_wrapper(stub.AddDevice, pbdev)

    def connect(self) -> None:
        """Eagerly create the shared channel; it is also handed to ``cmd_session``/``netconf_session``."""
        self._ensure_channel()

    async def close(self) -> None:
        """
        Close the shared channel, if one was created.

        Sessions opened while the channel existed may be using it, so call this
        only after they are finished.
        """
        if self._channel is not None:
            channel, self._channel = self._channel, None
            await channel.close()

    async def cmd_netconf(self, hostname: str, cmd: str, json: bool = False, trace: bool = False) -> Message:
        """Execute one NETCONF command via ``ExecNetconf`` on a short-lived channel."""
        pbcmd = server_pb2.CMDNetconf(host=hostname, cmd=cmd, json=json, trace=trace)
        _logger.debug("connect to %s", self._server)
        async with self._grpc_channel_fn(self._server, options=self._options) as channel:
            stub = server_pb2_grpc.GnetcliStub(channel)
            _logger.debug("executing netconf cmd: %r", pbcmd)
            try:
                response = await grpc_call_wrapper(stub.ExecNetconf, pbcmd)
            except Exception as e:
                _logger.error("error hostname=%s cmd=%r error=%s", hostname, repr(pbcmd), e)
                raise
        return response

    @asynccontextmanager
    async def cmd_session(self, hostname: str) -> AsyncIterator["GnetcliSessionCmd"]:
        """
        Open an ``ExecChat`` session to ``hostname`` and close it on exit.

        The shared channel is reused if it exists; otherwise the session creates its own.
        """
        sess = GnetcliSessionCmd(
            hostname,
            self._auth_token,
            server=self._server,
            channel=self._channel,
            target_name_override=self._target_name_override,
            cert_file=self._cert_file,
            user_agent=self._user_agent,
            insecure_grpc=self._insecure_grpc,
        )
        await sess.connect()
        try:
            yield sess
        finally:
            await sess.close()

    @asynccontextmanager
    async def netconf_session(self, hostname: str) -> AsyncIterator["GnetcliSessionNetconf"]:
        """
        Open an ``ExecNetconfChat`` session to ``hostname`` and close it on exit.

        The shared channel is reused if it exists; otherwise the session creates its own.
        """
        sess = GnetcliSessionNetconf(
            hostname,
            self._auth_token,
            server=self._server,
            channel=self._channel,
            target_name_override=self._target_name_override,
            cert_file=self._cert_file,
            user_agent=self._user_agent,
            insecure_grpc=self._insecure_grpc,
        )
        await sess.connect()
        try:
            yield sess
        finally:
            await sess.close()

    async def set_host_params(self, hostname: str, params: HostParams) -> None:
        """Send ``params`` for ``hostname`` via ``SetupHostParams``; None fields are omitted."""
        kwargs: Dict[str, Any] = {"host": hostname, "device": params.device}
        if params.port is not None:
            kwargs["port"] = params.port
        # credentials is Optional; calling make_pb() on None would raise AttributeError.
        if params.credentials is not None:
            kwargs["credentials"] = params.credentials.make_pb()
        pbcmd = server_pb2.HostParams(**kwargs)
        _logger.debug("connect to %s", self._server)
        async with self._grpc_channel_fn(self._server, options=self._options) as channel:
            _logger.debug("set params for %s", hostname)
            stub = server_pb2_grpc.GnetcliStub(channel)
            await grpc_call_wrapper(stub.SetupHostParams, pbcmd)

    async def upload(self, hostname: str, files: Dict[str, File]) -> None:
        """Upload ``files`` (path -> File) to ``hostname`` via ``Upload``."""
        pbcmd = server_pb2.FileUploadRequest(host=hostname, files=make_files_request(files))
        _logger.debug("connect to %s", self._server)
        async with self._grpc_channel_fn(self._server, options=self._options) as channel:
            _logger.debug("upload %s to %s", files.keys(), hostname)
            stub = server_pb2_grpc.GnetcliStub(channel)
            response: Message = await grpc_call_wrapper(stub.Upload, pbcmd)
            _logger.debug("upload res %s", response)

    async def download(self, hostname: str, paths: List[str]) -> Dict[str, File]:
        """
        Download ``paths`` from ``hostname`` via ``Download``.

        :return: Mapping of path to ``File`` built from each returned entry.
        """
        pbcmd = server_pb2.FileDownloadRequest(host=hostname, paths=paths)
        _logger.debug("connect to %s", self._server)
        async with self._grpc_channel_fn(self._server, options=self._options) as channel:
            _logger.debug("download %s from %s", paths, hostname)
            stub = server_pb2_grpc.GnetcliStub(channel)
            response: server_pb2.FilesResult = await grpc_call_wrapper(stub.Download, pbcmd)
            return {f.path: File(content=f.data, status=f.status) for f in response.files}
244
+
245
+
246
class GnetcliSession(ABC):
    """
    Base class for a bidirectional-streaming session with the server.

    Subclasses call ``super().connect()`` and then open their concrete stream.
    ``_cmd`` writes one request to that stream and reads one response.
    """

    def __init__(
        self,
        hostname: str,
        token: Optional[str] = None,
        server: str = DEFAULT_SERVER,
        target_name_override: Optional[str] = None,
        cert_file: Optional[str] = None,
        user_agent: str = DEFAULT_USER_AGENT,
        insecure_grpc: bool = False,
        channel: Optional[grpc.aio.Channel] = None,
        credentials: Optional[Credentials] = None,
    ):
        """
        :param hostname: Host the session's commands target.
        :param token: ``"<scheme> <token>"`` string parsed by ``make_auth``;
            empty/None means no auth interceptors.
        :param channel: Existing channel to reuse. It is never closed by the
            session; a channel the session creates itself is closed in ``close``.
        """
        self._hostname = hostname
        self._credentials = credentials
        self._server = server
        self._channel: Optional[grpc.aio.Channel] = channel
        # True only when connect() created self._channel, so close() must release it.
        self._owns_channel = False
        self._stub: Optional[server_pb2_grpc.GnetcliStub] = None
        self._stream: Optional[grpc.aio.StreamStreamCall] = None
        self._user_agent = user_agent

        # Same base options and user agent as the Gnetcli client.
        options: List[Tuple[str, Any]] = [
            *default_grpc_options,
            ("grpc.primary_user_agent", user_agent),
        ]
        if target_name_override:
            options.append(("grpc.ssl_target_name_override", target_name_override))
        interceptors = []
        if token:
            authentication: ClientAuthentication = make_auth(token)
            interceptors = get_auth_client_interceptors(authentication)
        if insecure_grpc:
            self._grpc_channel_fn = partial(grpc.aio.insecure_channel, interceptors=interceptors)
        else:
            cert = get_cert(cert_file=cert_file)
            channel_credentials = grpc.ssl_channel_credentials(root_certificates=cert)
            self._grpc_channel_fn = partial(
                grpc.aio.secure_channel,
                credentials=channel_credentials,
                interceptors=interceptors,
            )
        self._options = options
        self._req_id: Optional[Any] = None

    def _get_metadata(self) -> List[Tuple[str, str]]:
        """Return per-stream metadata: a fresh request id and the user agent."""
        req_id = make_req_id()
        metadata = [
            (HEADER_REQUEST_ID, req_id),
            (HEADER_USER_AGENT, self._user_agent),
        ]
        return metadata

    @abstractmethod
    async def connect(self) -> None:
        """Create the channel (unless one was supplied) and the stub."""
        if self._channel is None:
            _logger.debug("connect to %s self._channel=%s", self._server, self._channel)
            self._channel = self._grpc_channel_fn(self._server, options=self._options)
            self._owns_channel = True
        self._stub = server_pb2_grpc.GnetcliStub(self._channel)

    async def _cmd(self, cmdpb: Any) -> Message:
        """
        Write ``cmdpb`` to the open stream and await one response.

        :raises Exception: if no stream has been opened by ``connect``.
        :raises GnetcliException: (or a subclass chosen by ``parse_grpc_error``)
            on gRPC failure; the ``AioRpcError`` is kept as ``__cause__``.
        """
        # TODO: add connect retry on first cmd
        if self._stream is None:
            raise Exception("empty self._stream")
        try:
            _logger.debug("cmd %r on %r", str(cmdpb).replace("\n", ""), self._stream)
            await self._stream.write(cmdpb)
            response: Message = await self._stream.read()
        except grpc.aio.AioRpcError as e:
            gn_exc, verbose = parse_grpc_error(e)
            _logger.debug("caught exception %s %s", e, (gn_exc, verbose))
            # "from e" keeps the gRPC error as __cause__; "from None" would reset it.
            raise gn_exc(
                message=f"{e.__class__.__name__} {e.details()}",
                imetadata=e.initial_metadata(),  # type: ignore
                verbose=verbose,
            ) from e
        _logger.debug("response %s", format_long_msg(str(response), 100))
        return response

    async def close(self) -> None:
        """Half-close the stream, then close the channel if this session created it."""
        _logger.debug("close stream %s", self._stream)
        stream, self._stream = self._stream, None
        try:
            if stream is not None:
                await stream.done_writing()
        finally:
            if self._owns_channel and self._channel is not None:
                channel, self._channel = self._channel, None
                self._owns_channel = False
                self._stub = None
                await channel.close()
335
+
336
+
337
class GnetcliSessionCmd(GnetcliSession):
    """Session that sends CLI commands over the ``ExecChat`` stream."""

    async def cmd(
        self,
        cmd: str,
        trace: bool = False,
        qa: Optional[List[QA]] = None,
        cmd_timeout: float = 0.0,
        read_timeout: float = 0.0,
    ) -> Message:
        """
        Run ``cmd`` on the session's host and return the server's reply.

        :param qa: Question/answer pairs forwarded in the request.
        """
        _logger.debug("session cmd %r", cmd)
        request = make_cmd(
            hostname=self._hostname,
            cmd=cmd,
            trace=trace,
            qa=qa,
            read_timeout=read_timeout,
            cmd_timeout=cmd_timeout,
        )
        reply = await self._cmd(request)
        return reply

    async def connect(self) -> None:
        """Set up channel and stub via the base class, then open ``ExecChat``."""
        await super().connect()
        if not self._stub:
            raise Exception()
        self._stream = self._stub.ExecChat(metadata=self._get_metadata())
363
+
364
+
365
class GnetcliSessionNetconf(GnetcliSession):
    """Session that sends NETCONF commands over the ``ExecNetconfChat`` stream."""

    async def cmd(self, cmd: str, trace: bool = False, json: bool = False) -> Message:
        """
        Run a NETCONF ``cmd`` on the session's host.

        :param trace: Forwarded to the request (as ``Gnetcli.cmd_netconf`` does);
            it was previously ignored.
        :param json: Forwarded to the request's ``json`` field.
        :return: The server's reply.
        """
        _logger.debug("netconf session cmd %r", cmd)
        cmdpb = server_pb2.CMDNetconf(host=self._hostname, cmd=cmd, json=json, trace=trace)
        return await self._cmd(cmdpb)

    async def connect(self) -> None:
        """Set up channel and stub via the base class, then open ``ExecNetconfChat``."""
        await super().connect()
        if not self._stub:
            raise Exception("empty stub")
        self._stream = self._stub.ExecNetconfChat(metadata=self._get_metadata())
377
+
378
+
379
async def grpc_call_wrapper(stub: grpc.UnaryUnaryMultiCallable, request: Any) -> Message:
    """
    Invoke a unary RPC with a fresh request id and translate gRPC errors.

    The call is attempted exactly once; errors are not retried.

    :param stub: Bound unary method, e.g. ``GnetcliStub(channel).Exec``.
    :param request: Request message passed to ``stub``.
    :return: The response message.
    :raises GnetcliException: (or a subclass chosen by ``parse_grpc_error``)
        when the call fails; the ``AioRpcError`` is kept as ``__cause__``.
    """
    req_id = make_req_id()
    metadata = [
        (HEADER_REQUEST_ID, req_id),
    ]
    _logger.debug("executing %s: %r, req_id=%s", type(request), repr(request), req_id)
    try:
        response = await stub(request=request, metadata=metadata)
    except grpc.aio.AioRpcError as e:
        gn_exc, verbose = parse_grpc_error(e)
        _logger.debug("caught exception %s req_id=%s %s", e, req_id, (gn_exc, verbose))
        # "from e" keeps the gRPC error as __cause__; "from None" would reset it to None.
        raise gn_exc(
            message=f"{e.__class__.__name__} {e.details()}",
            imetadata=e.initial_metadata(),  # type: ignore
            request_id=req_id,
            verbose=verbose,
        ) from e
    if response is None:
        raise Exception()
    return response
412
+
413
+
414
def make_req_id() -> str:
    """Return a new random request id: a UUID4 in canonical 36-character form."""
    request_uuid = uuid.uuid4()
    return request_uuid.__str__()
416
+
417
+
418
def get_cert(cert_file: Optional[str]) -> Optional[bytes]:
    """
    Read a certificate file as raw bytes.

    :param cert_file: Path to the file; None or empty means "no file".
    :return: File contents, or None when no path was given.
    """
    if not cert_file:
        return None
    _logger.debug("open cert_file %s", cert_file)
    with open(cert_file, "rb") as handle:
        return handle.read()
425
+
426
+
427
def format_long_msg(msg: str, max_len: int) -> str:
    """Return ``msg`` cut to ``max_len`` characters plus a note of how many were dropped."""
    overflow = len(msg) - max_len
    if overflow <= 0:
        return msg
    return f"{msg[:max_len]}... and {overflow} more"
431
+
432
+
433
def make_cmd(
    hostname: str,
    cmd: str,
    trace: bool = False,
    qa: Optional[List[QA]] = None,
    read_timeout: float = 0.0,
    cmd_timeout: float = 0.0,
) -> Message:
    """
    Build a ``server_pb2.CMD`` request.

    :param qa: Optional question/answer pairs; each becomes a ``server_pb2.QA``.
    :return: The populated ``server_pb2.CMD`` message.
    """
    qa_items: List[Message] = [
        server_pb2.QA(question=pair.question, answer=pair.answer) for pair in (qa or [])
    ]
    request = server_pb2.CMD(
        host=hostname,
        cmd=cmd,
        trace=trace,
        qa=qa_items,
        read_timeout=read_timeout,
        cmd_timeout=cmd_timeout,
    )
    return request  # type: ignore
457
+
458
+
459
def make_files_request(files: Dict[str, File]) -> List[server_pb2.FileData]:
    """Convert a path -> ``File`` mapping into ``FileData`` messages (only path and content are set)."""
    return [server_pb2.FileData(path=path, data=entry.content) for path, entry in files.items()]
gnetclisdk/exceptions.py ADDED
@@ -0,0 +1,116 @@
1
+ from typing import Optional, Sequence, Tuple, Type, Union
2
+
3
+ import grpc.aio
4
+
5
MetadataType = Sequence[Tuple[str, Union[str, bytes]]]


def extract_metadata(m: MetadataType) -> dict:
    """
    Collapse a sequence of (key, value) metadata pairs into a plain dict.

    The pairs are iterated instead of looked up by key, because calling get
    on the metadata object throws KeyError. If a key repeats, its last value wins.
    """
    return {key: value for key, value in m}
14
+
15
+
16
class GnetcliException(Exception):
    """
    Base class for SDK errors.

    The final message is ``message`` followed by optional suffixes:
    ``RS:<real-server>`` (from the ``real-server`` initial-metadata key),
    ``req_id:<request_id>`` and ``verbose:<verbose>``. It is stored on
    ``self.message`` and passed to ``Exception``.
    """

    def __init__(
        self,
        message: str = "",
        imetadata: Optional[MetadataType] = None,
        request_id: Optional[str] = None,
        verbose: Optional[str] = "",
    ):
        real_server = extract_metadata(imetadata).get("real-server") if imetadata else None
        suffixes = []
        if real_server:
            suffixes.append(f"RS:{real_server}")
        if request_id:
            suffixes.append(f"req_id:{request_id}")
        if verbose:
            suffixes.append(f"verbose:{verbose}")
        text = message
        for suffix in suffixes:
            text = f"{text} {suffix}"
        self.message = text
        super().__init__(self.message)
34
+
35
+
36
class DeviceConnectError(GnetcliException):
    """
    Problem with connection to a device.

    ``parse_grpc_error`` returns it for INTERNAL errors whose details are
    ``connection_error`` or ``busy_error``.
    """
42
+
43
+
44
class UnknownDevice(GnetcliException):
    """
    Host is not found in inventory.

    ``parse_grpc_error`` returns it for the OUT_OF_RANGE status code.
    """
50
+
51
+
52
class DeviceAuthError(DeviceConnectError):
    """
    Unable to authenticate on a device.

    ``parse_grpc_error`` returns it for INTERNAL errors with details
    ``auth_device_error``. Handlers for DeviceConnectError catch it as well.
    """
58
+
59
+
60
class ExecError(GnetcliException):
    """
    Error happened during execution.

    ``parse_grpc_error`` returns it for INTERNAL errors whose details are
    ``exec_error`` or ``generic_error``.
    """
66
+
67
+
68
class NotReady(GnetcliException):
    """
    Server is not ready.

    ``parse_grpc_error`` returns it for UNAVAILABLE with details ``not ready``.
    """
74
+
75
+
76
class Unauthenticated(GnetcliException):
    """
    Unable to authenticate on Gnetcli server.

    ``parse_grpc_error`` returns it for the UNAUTHENTICATED status code.
    """
82
+
83
+
84
class PermissionDenied(GnetcliException):
    """
    Permission denied.

    ``parse_grpc_error`` returns it for the PERMISSION_DENIED status code.
    """
90
+
91
+
92
def parse_grpc_error(grpc_error: grpc.aio.AioRpcError) -> Tuple[Type[GnetcliException], str]:
    """
    Map a gRPC error to an SDK exception class and a verbose string.

    :return: ``(exception_class, verbose)``. ``verbose`` carries the gRPC
        details only for UNAUTHENTICATED, PERMISSION_DENIED and OUT_OF_RANGE;
        it is empty otherwise. Unrecognised errors map to GnetcliException.
    """
    code = grpc_error.code()
    detail = grpc_error.details() or ""  # type: ignore

    if code == grpc.StatusCode.UNAVAILABLE and detail == "not ready":
        return NotReady, ""

    # Status codes whose details are passed through as ``verbose``.
    detailed_codes = (
        (grpc.StatusCode.UNAUTHENTICATED, Unauthenticated),
        (grpc.StatusCode.PERMISSION_DENIED, PermissionDenied),
        (grpc.StatusCode.OUT_OF_RANGE, UnknownDevice),
    )
    for status, exc_cls in detailed_codes:
        if code == status:
            return exc_cls, detail

    if code == grpc.StatusCode.INTERNAL:
        # The server encodes the failure kind in the details string.
        internal_kinds = {
            "auth_device_error": DeviceAuthError,
            "connection_error": DeviceConnectError,
            "busy_error": DeviceConnectError,
            "exec_error": ExecError,
            "generic_error": ExecError,
        }
        internal_cls = internal_kinds.get(detail)
        if internal_cls is not None:
            return internal_cls, ""

    return GnetcliException, ""