pyinfra 3.4__py2.py3-none-any.whl → 3.5__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyinfra/api/arguments.py CHANGED
@@ -72,6 +72,11 @@ class ConnectorArguments(TypedDict, total=False):
72
72
  _get_pty: bool
73
73
  _stdin: Union[str, Iterable[str]]
74
74
 
75
+ # Retry arguments
76
+ _retries: int
77
+ _retry_delay: Union[int, float]
78
+ _retry_until: Optional[Callable[[dict], bool]]
79
+
75
80
 
76
81
  def generate_env(config: "Config", value: dict) -> dict:
77
82
  env = config.ENV.copy()
@@ -232,11 +237,28 @@ def all_global_arguments() -> List[tuple[str, Type]]:
232
237
  return list(get_type_hints(AllArguments).items())
233
238
 
234
239
 
240
+ # Create a dictionary for retry arguments
241
+ retry_argument_meta: dict[str, ArgumentMeta] = {
242
+ "_retries": ArgumentMeta(
243
+ "Number of times to retry failed operations.",
244
+ default=lambda config: config.RETRY,
245
+ ),
246
+ "_retry_delay": ArgumentMeta(
247
+ "Delay in seconds between retry attempts.",
248
+ default=lambda config: config.RETRY_DELAY,
249
+ ),
250
+ "_retry_until": ArgumentMeta(
251
+ "Callable taking output data that returns True to continue retrying.",
252
+ default=lambda config: None,
253
+ ),
254
+ }
255
+
235
256
  all_argument_meta: dict[str, ArgumentMeta] = {
236
257
  **auth_argument_meta,
237
258
  **shell_argument_meta,
238
259
  **meta_argument_meta,
239
260
  **execution_argument_meta,
261
+ **retry_argument_meta, # Add retry arguments
240
262
  }
241
263
 
242
264
  EXECUTION_KWARG_KEYS = list(ExecutionArguments.__annotations__.keys())
@@ -286,6 +308,45 @@ __argument_docs__ = {
286
308
  ),
287
309
  "Operation meta & callbacks": (meta_argument_meta, "", ""),
288
310
  "Execution strategy": (execution_argument_meta, "", ""),
311
+ "Retry behavior": (
312
+ retry_argument_meta,
313
+ """
314
+ Retry arguments allow you to automatically retry operations that fail. You can specify
315
+ how many times to retry, the delay between retries, and optionally a condition
316
+ function that decides whether to retry again (it returns True to keep retrying).
317
+ """,
318
+ """
319
+ .. code:: python
320
+
321
+ # Retry a command up to 3 times with the default 5 second delay
322
+ server.shell(
323
+ name="Run flaky command with retries",
324
+ commands=["flaky_command"],
325
+ _retries=3,
326
+ )
327
+ # Retry with a custom delay
328
+ server.shell(
329
+ name="Run flaky command with custom delay",
330
+ commands=["flaky_command"],
331
+ _retries=2,
332
+ _retry_delay=10, # 10 second delay between retries
333
+ )
334
+ # Retry with a custom condition
335
+ def retry_on_specific_error(output_data):
336
+ # Retry if stderr contains "temporary failure"
337
+ for line in output_data["stderr_lines"]:
338
+ if "temporary failure" in line.lower():
339
+ return True
340
+ return False
341
+
342
+ server.shell(
343
+ name="Run command with conditional retry",
344
+ commands=["flaky_command"],
345
+ _retries=5,
346
+ _retry_until=retry_on_specific_error,
347
+ )
348
+ """,
349
+ ),
289
350
  }
290
351
 
291
352
 
@@ -305,7 +366,8 @@ def pop_global_arguments(
305
366
 
306
367
  config = state.config
307
368
  if ctx_config.isset():
308
- config = config
369
+ config = ctx_config.get()
370
+ assert config is not None
309
371
 
310
372
  cdkwargs = host.current_deploy_kwargs
311
373
  meta_kwargs: dict[str, Any] = cdkwargs or {} # type: ignore[assignment]
pyinfra/api/config.py CHANGED
@@ -53,6 +53,12 @@ class ConfigDefaults:
53
53
  IGNORE_ERRORS: bool = False
54
54
  # Shell to use to execute commands
55
55
  SHELL: str = "sh"
56
+ # Whether to display full diffs for files
57
+ DIFF: bool = False
58
+ # Number of times to retry failed operations
59
+ RETRY: int = 0
60
+ # Delay in seconds between retry attempts
61
+ RETRY_DELAY: int = 5
56
62
 
57
63
 
58
64
  config_defaults = {key: value for key, value in ConfigDefaults.__dict__.items() if key.isupper()}
pyinfra/api/connect.py CHANGED
@@ -46,5 +46,22 @@ def connect_all(state: "State"):
46
46
 
47
47
 
48
48
  def disconnect_all(state: "State"):
49
- for host in state.activated_hosts: # only hosts we connected to please!
50
- host.disconnect() # normally a noop
49
+ """
50
+ Disconnect from all of the configured servers in parallel. Reads/writes state.inventory.
51
+
52
+ Args:
53
+ state (``pyinfra.api.State`` obj): the state containing an inventory to disconnect from
54
+ """
55
+ greenlet_to_host = {
56
+ state.pool.spawn(host.disconnect): host
57
+ for host in state.activated_hosts # only hosts we connected to please!
58
+ }
59
+
60
+ with progress_spinner(greenlet_to_host.values()) as progress:
61
+ for greenlet in gevent.iwait(greenlet_to_host.keys()):
62
+ host = greenlet_to_host[greenlet]
63
+ progress(host)
64
+
65
+ for greenlet, host in greenlet_to_host.items():
66
+ # Raise any unexpected exception
67
+ greenlet.get()
pyinfra/api/operation.py CHANGED
@@ -47,6 +47,9 @@ class OperationMeta:
47
47
  _commands: Optional[list[Any]] = None
48
48
  _maybe_is_change: Optional[bool] = None
49
49
  _success: Optional[bool] = None
50
+ _retry_attempts: int = 0
51
+ _max_retries: int = 0
52
+ _retry_succeeded: Optional[bool] = None
50
53
 
51
54
  def __init__(self, hash, is_change: Optional[bool]):
52
55
  self._hash = hash
@@ -59,9 +62,17 @@ class OperationMeta:
59
62
  """
60
63
 
61
64
  if self._commands is not None:
65
+ retry_info = ""
66
+ if self._retry_attempts > 0:
67
+ retry_result = "succeeded" if self._retry_succeeded else "failed"
68
+ retry_info = (
69
+ f", retries={self._retry_attempts}/{self._max_retries} ({retry_result})"
70
+ )
71
+
62
72
  return (
63
73
  "OperationMeta(executed=True, "
64
- f"success={self.did_succeed()}, hash={self._hash}, commands={len(self._commands)})"
74
+ f"success={self.did_succeed()}, hash={self._hash}, "
75
+ f"commands={len(self._commands)}{retry_info})"
65
76
  )
66
77
  return (
67
78
  "OperationMeta(executed=False, "
@@ -74,12 +85,20 @@ class OperationMeta:
74
85
  success: bool,
75
86
  commands: list[Any],
76
87
  combined_output: "CommandOutput",
88
+ retry_attempts: int = 0,
89
+ max_retries: int = 0,
77
90
  ) -> None:
78
91
  if self.is_complete():
79
92
  raise RuntimeError("Cannot complete an already complete operation")
80
93
  self._success = success
81
94
  self._commands = commands
82
95
  self._combined_output = combined_output
96
+ self._retry_attempts = retry_attempts
97
+ self._max_retries = max_retries
98
+
99
+ # Determine if operation succeeded after retries
100
+ if retry_attempts > 0:
101
+ self._retry_succeeded = success
83
102
 
84
103
  def is_complete(self) -> bool:
85
104
  return self._success is not None
@@ -150,6 +169,40 @@ class OperationMeta:
150
169
  def stderr(self) -> str:
151
170
  return "\n".join(self.stderr_lines)
152
171
 
172
+ @property
173
+ def retry_attempts(self) -> int:
174
+ return self._retry_attempts
175
+
176
+ @property
177
+ def max_retries(self) -> int:
178
+ return self._max_retries
179
+
180
+ @property
181
+ def was_retried(self) -> bool:
182
+ """
183
+ Returns whether this operation was retried at least once.
184
+ """
185
+ return self._retry_attempts > 0
186
+
187
+ @property
188
+ def retry_succeeded(self) -> Optional[bool]:
189
+ """
190
+ Returns whether this operation succeeded after retries.
191
+ Returns None if the operation was not retried.
192
+ """
193
+ return self._retry_succeeded
194
+
195
+ def get_retry_info(self) -> dict[str, Any]:
196
+ """
197
+ Returns a dictionary with all retry-related information.
198
+ """
199
+ return {
200
+ "retry_attempts": self._retry_attempts,
201
+ "max_retries": self._max_retries,
202
+ "was_retried": self.was_retried,
203
+ "retry_succeeded": self._retry_succeeded,
204
+ }
205
+
153
206
 
154
207
  def add_op(state: State, op_func, *args, **kwargs):
155
208
  """
pyinfra/api/operations.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import time
3
4
  import traceback
4
5
  from itertools import product
5
6
  from socket import error as socket_error, timeout as timeout_error
@@ -66,6 +67,11 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
66
67
  continue_on_error = global_arguments["_continue_on_error"]
67
68
  timeout = global_arguments.get("_timeout", 0)
68
69
 
70
+ # Extract retry arguments
71
+ retries = global_arguments.get("_retries", 0)
72
+ retry_delay = global_arguments.get("_retry_delay", 5)
73
+ retry_until = global_arguments.get("_retry_until", None)
74
+
69
75
  executor_kwarg_keys = CONNECTOR_ARGUMENT_KEYS
70
76
  # See: https://github.com/python/mypy/issues/10371
71
77
  base_connector_arguments: ConnectorArguments = cast(
@@ -73,67 +79,114 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
73
79
  {key: global_arguments[key] for key in executor_kwarg_keys if key in global_arguments}, # type: ignore[literal-required] # noqa
74
80
  )
75
81
 
82
+ retry_attempt = 0
76
83
  did_error = False
77
84
  executed_commands = 0
78
- commands = []
85
+ commands: list[PyinfraCommand] = []
79
86
  all_output_lines: list[OutputLine] = []
80
87
 
81
- for command in op_data.command_generator():
82
- commands.append(command)
83
-
84
- status = False
85
-
86
- connector_arguments = base_connector_arguments.copy()
87
- connector_arguments.update(command.connector_arguments)
88
-
89
- if not isinstance(command, PyinfraCommand):
90
- raise TypeError("{0} is an invalid pyinfra command!".format(command))
91
-
92
- if isinstance(command, FunctionCommand):
93
- try:
94
- status = command.execute(state, host, connector_arguments)
95
- except Exception as e:
96
- # Custom functions could do anything, so expect anything!
97
- logger.warning(traceback.format_exc())
98
- host.log_styled(
99
- f"Unexpected error in Python callback: {format_exception(e)}",
100
- fg="red",
101
- log_func=logger.warning,
102
- )
103
-
104
- elif isinstance(command, StringCommand):
105
- output_lines = CommandOutput([])
106
- try:
107
- status, output_lines = command.execute(
108
- state,
109
- host,
110
- connector_arguments,
111
- )
112
- except (timeout_error, socket_error, SSHException) as e:
113
- log_host_command_error(host, e, timeout=timeout)
114
- all_output_lines.extend(output_lines)
115
- # If we failed and have not already printed the stderr, print it
116
- if status is False and not state.print_output:
117
- print_host_combined_output(host, output_lines)
118
-
119
- else:
120
- try:
121
- status = command.execute(state, host, connector_arguments)
122
- except (timeout_error, socket_error, SSHException, IOError) as e:
123
- log_host_command_error(host, e, timeout=timeout)
124
-
125
- # Break the loop to trigger a failure
126
- if status is False:
127
- did_error = True
128
- if continue_on_error is True:
129
- continue
130
- break
88
+ # Retry loop
89
+ while retry_attempt <= retries:
90
+ did_error = False
91
+ executed_commands = 0
92
+ commands = []
93
+ all_output_lines = []
94
+
95
+ for command in op_data.command_generator():
96
+ commands.append(command)
97
+ status = False
98
+ connector_arguments = base_connector_arguments.copy()
99
+ connector_arguments.update(command.connector_arguments)
100
+
101
+ if not isinstance(command, PyinfraCommand):
102
+ raise TypeError("{0} is an invalid pyinfra command!".format(command))
103
+
104
+ if isinstance(command, FunctionCommand):
105
+ try:
106
+ status = command.execute(state, host, connector_arguments)
107
+ except Exception as e:
108
+ # Custom functions could do anything, so expect anything!
109
+ logger.warning(traceback.format_exc())
110
+ host.log_styled(
111
+ f"Unexpected error in Python callback: {format_exception(e)}",
112
+ fg="red",
113
+ log_func=logger.warning,
114
+ )
115
+
116
+ elif isinstance(command, StringCommand):
117
+ output_lines = CommandOutput([])
118
+ try:
119
+ status, output_lines = command.execute(
120
+ state,
121
+ host,
122
+ connector_arguments,
123
+ )
124
+ except (timeout_error, socket_error, SSHException) as e:
125
+ log_host_command_error(host, e, timeout=timeout)
126
+ all_output_lines.extend(output_lines)
127
+ # If we failed and have not already printed the stderr, print it
128
+ if status is False and not state.print_output:
129
+ print_host_combined_output(host, output_lines)
130
+
131
+ else:
132
+ try:
133
+ status = command.execute(state, host, connector_arguments)
134
+ except (timeout_error, socket_error, SSHException, IOError) as e:
135
+ log_host_command_error(host, e, timeout=timeout)
136
+
137
+ # Break the loop to trigger a failure
138
+ if status is False:
139
+ did_error = True
140
+ if continue_on_error is True:
141
+ continue
142
+ break
143
+
144
+ executed_commands += 1
145
+
146
+ # Check if we should retry
147
+ should_retry = False
148
+ if retry_attempt < retries:
149
+ # Retry on error
150
+ if did_error:
151
+ should_retry = True
152
+ # Retry on condition if no error
153
+ elif retry_until and not did_error:
154
+ try:
155
+ output_data = {
156
+ "stdout_lines": [
157
+ line.line for line in all_output_lines if line.buffer_name == "stdout"
158
+ ],
159
+ "stderr_lines": [
160
+ line.line for line in all_output_lines if line.buffer_name == "stderr"
161
+ ],
162
+ "commands": [str(command) for command in commands],
163
+ "executed_commands": executed_commands,
164
+ "host": host.name,
165
+ "operation": ", ".join(state.get_op_meta(op_hash).names) or "Operation",
166
+ }
167
+ should_retry = retry_until(output_data)
168
+ except Exception as e:
169
+ host.log_styled(
170
+ f"Error in retry_until function: {format_exception(e)}",
171
+ fg="red",
172
+ log_func=logger.warning,
173
+ )
174
+
175
+ if should_retry:
176
+ retry_attempt += 1
177
+ state.trigger_callbacks("operation_host_retry", host, op_hash, retry_attempt, retries)
178
+ op_name = ", ".join(state.get_op_meta(op_hash).names) or "Operation"
179
+ host.log_styled(
180
+ f"Retrying {op_name} (attempt {retry_attempt}/{retries}) after {retry_delay}s...",
181
+ fg="yellow",
182
+ log_func=logger.info,
183
+ )
184
+ time.sleep(retry_delay)
185
+ continue
131
186
 
132
- executed_commands += 1
187
+ break
133
188
 
134
189
  # Handle results
135
- #
136
-
137
190
  op_success = return_status = not did_error
138
191
  host_results = state.get_results_for_host(host)
139
192
 
@@ -142,10 +195,13 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
142
195
  host_results.success_ops += 1
143
196
 
144
197
  _status_log = "Success" if executed_commands > 0 else "No changes"
198
+ if retry_attempt > 0:
199
+ _status_log = f"{_status_log} on retry {retry_attempt}"
200
+
145
201
  _click_log_status = click.style(_status_log, "green")
146
202
  logger.info("{0}{1}".format(host.print_prefix, _click_log_status))
147
203
 
148
- state.trigger_callbacks("operation_host_success", host, op_hash)
204
+ state.trigger_callbacks("operation_host_success", host, op_hash, retry_attempt)
149
205
  else:
150
206
  if ignore_errors:
151
207
  host_results.ignored_error_ops += 1
@@ -156,6 +212,11 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
156
212
  host_results.partial_ops += 1
157
213
 
158
214
  _command_description = f"executed {executed_commands} commands"
215
+ if retry_attempt > 0:
216
+ _command_description = (
217
+ f"{_command_description} (failed after {retry_attempt}/{retries} retries)"
218
+ )
219
+
159
220
  log_error_or_warning(host, ignore_errors, _command_description, continue_on_error)
160
221
 
161
222
  # Ignored, op "completes" w/ ignored error
@@ -164,12 +225,14 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
164
225
  return_status = True
165
226
 
166
227
  # Unignored error -> False
167
- state.trigger_callbacks("operation_host_error", host, op_hash)
228
+ state.trigger_callbacks("operation_host_error", host, op_hash, retry_attempt, retries)
168
229
 
169
230
  op_data.operation_meta.set_complete(
170
231
  op_success,
171
232
  commands,
172
233
  CommandOutput(all_output_lines),
234
+ retry_attempts=retry_attempt,
235
+ max_retries=retries,
173
236
  )
174
237
 
175
238
  return return_status
pyinfra/api/state.py CHANGED
@@ -70,11 +70,19 @@ class BaseStateCallback:
70
70
  pass
71
71
 
72
72
  @staticmethod
73
- def operation_host_success(state: "State", host: "Host", op_hash):
73
+ def operation_host_success(state: "State", host: "Host", op_hash, retry_count: int = 0):
74
74
  pass
75
75
 
76
76
  @staticmethod
77
- def operation_host_error(state: "State", host: "Host", op_hash):
77
+ def operation_host_error(
78
+ state: "State", host: "Host", op_hash, retry_count: int = 0, max_retries: int = 0
79
+ ):
80
+ pass
81
+
82
+ @staticmethod
83
+ def operation_host_retry(
84
+ state: "State", host: "Host", op_hash, retry_num: int, max_retries: int
85
+ ):
78
86
  pass
79
87
 
80
88
  @staticmethod
@@ -0,0 +1 @@
1
+ from .client import SCPClient # noqa: F401
@@ -0,0 +1,204 @@
1
+ from __future__ import annotations
2
+
3
+ import ntpath
4
+ import os
5
+ from pathlib import PurePath
6
+ from shlex import quote
7
+ from socket import timeout as SocketTimeoutError
8
+ from typing import IO, AnyStr
9
+
10
+ from paramiko import Channel
11
+ from paramiko.transport import Transport
12
+
13
+ SCP_COMMAND = b"scp"
14
+
15
+
16
+ # Unicode conversion functions; assume UTF-8
17
+ def asbytes(s: bytes | str | PurePath) -> bytes:
18
+ """Turns unicode into bytes, if needed.
19
+
20
+ Assumes UTF-8.
21
+ """
22
+ if isinstance(s, bytes):
23
+ return s
24
+ elif isinstance(s, PurePath):
25
+ return bytes(s)
26
+ else:
27
+ return s.encode("utf-8")
28
+
29
+
30
+ def asunicode(s: bytes | str) -> str:
31
+ """Turns bytes into unicode, if needed.
32
+
33
+ Uses UTF-8.
34
+ """
35
+ if isinstance(s, bytes):
36
+ return s.decode("utf-8", "replace")
37
+ else:
38
+ return s
39
+
40
+
41
+ class SCPClient:
42
+ """
43
+ An scp1 implementation, compatible with openssh scp.
44
+ Raises SCPException for all transport related errors. Local filesystem
45
+ and OS errors pass through.
46
+
47
+ Main public methods are .putfo and .getfo
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ transport: Transport,
53
+ buff_size: int = 16384,
54
+ socket_timeout: float = 10.0,
55
+ ):
56
+ self.transport = transport
57
+ self.buff_size = buff_size
58
+ self.socket_timeout = socket_timeout
59
+ self._channel: Channel | None = None
60
+ self.scp_command = SCP_COMMAND
61
+
62
+ @property
63
+ def channel(self) -> Channel:
64
+ """Return an open Channel, (re)opening if needed."""
65
+ if self._channel is None or self._channel.closed:
66
+ self._channel = self.transport.open_session()
67
+ return self._channel
68
+
69
+ def __enter__(self):
70
+ _ = self.channel # triggers opening if not already open
71
+ return self
72
+
73
+ def __exit__(self, type, value, traceback):
74
+ self.close()
75
+
76
+ def putfo(
77
+ self,
78
+ fl: IO[AnyStr],
79
+ remote_path: str | bytes,
80
+ mode: str | bytes = "0644",
81
+ size: int | None = None,
82
+ ) -> None:
83
+ if size is None:
84
+ pos = fl.tell()
85
+ fl.seek(0, os.SEEK_END) # Seek to end
86
+ size = fl.tell() - pos
87
+ fl.seek(pos, os.SEEK_SET) # Seek back
88
+
89
+ self.channel.settimeout(self.socket_timeout)
90
+ self.channel.exec_command(
91
+ self.scp_command + b" -t " + asbytes(quote(asunicode(remote_path)))
92
+ )
93
+ self._recv_confirm()
94
+ self._send_file(fl, remote_path, mode, size=size)
95
+ self.close()
96
+
97
+ def getfo(self, remote_path: str, fl: IO):
98
+ remote_path_sanitized = quote(remote_path)
99
+ if os.name == "nt":
100
+ remote_file_name = ntpath.basename(remote_path_sanitized)
101
+ else:
102
+ remote_file_name = os.path.basename(remote_path_sanitized)
103
+ self.channel.settimeout(self.socket_timeout)
104
+ self.channel.exec_command(self.scp_command + b" -f " + asbytes(remote_path_sanitized))
105
+ self._recv_all(fl, remote_file_name)
106
+ self.close()
107
+ return fl
108
+
109
+ def close(self):
110
+ """close scp channel"""
111
+ if self._channel is not None:
112
+ self._channel.close()
113
+ self._channel = None
114
+
115
+ def _send_file(self, fl, name, mode, size):
116
+ basename = asbytes(os.path.basename(name))
117
+ # The protocol can't handle \n in the filename.
118
+ # Quote them as the control sequence \^J for now,
119
+ # which is how openssh handles it.
120
+ self.channel.sendall(
121
+ ("C%s %d " % (mode, size)).encode("ascii") + basename.replace(b"\n", b"\\^J") + b"\n"
122
+ )
123
+ self._recv_confirm()
124
+ file_pos = 0
125
+ buff_size = self.buff_size
126
+ chan = self.channel
127
+ while file_pos < size:
128
+ chan.sendall(fl.read(buff_size))
129
+ file_pos = fl.tell()
130
+ chan.sendall(b"\x00")
131
+ self._recv_confirm()
132
+
133
+ def _recv_confirm(self):
134
+ # read scp response
135
+ msg = b""
136
+ try:
137
+ msg = self.channel.recv(512)
138
+ except SocketTimeoutError:
139
+ raise SCPException("Timeout waiting for scp response")
140
+ # slice off the first byte, so this compare will work in py2 and py3
141
+ if msg and msg[0:1] == b"\x00":
142
+ return
143
+ elif msg and msg[0:1] == b"\x01":
144
+ raise SCPException(asunicode(msg[1:]))
145
+ elif self.channel.recv_stderr_ready():
146
+ msg = self.channel.recv_stderr(512)
147
+ raise SCPException(asunicode(msg))
148
+ elif not msg:
149
+ raise SCPException("No response from server")
150
+ else:
151
+ raise SCPException("Invalid response from server", msg)
152
+
153
+ def _recv_all(self, fh: IO, remote_file_name: str) -> None:
154
+ # loop over scp commands, and receive as necessary
155
+ commands = (b"C",)
156
+ while not self.channel.closed:
157
+ # wait for command as long as we're open
158
+ self.channel.sendall(b"\x00")
159
+ msg = self.channel.recv(1024)
160
+ if not msg: # chan closed while receiving
161
+ break
162
+ assert msg[-1:] == b"\n"
163
+ msg = msg[:-1]
164
+ code = msg[0:1]
165
+ if code not in commands:
166
+ raise SCPException(asunicode(msg[1:]))
167
+ self._recv_file(msg[1:], fh, remote_file_name)
168
+
169
+ def _recv_file(self, cmd: bytes, fh: IO, remote_file_name: str) -> None:
170
+ chan = self.channel
171
+ parts = cmd.strip().split(b" ", 2)
172
+
173
+ try:
174
+ size = int(parts[1])
175
+ except (ValueError, IndexError):
176
+ chan.send(b"\x01")
177
+ chan.close()
178
+ raise SCPException("Bad file format")
179
+
180
+ buff_size = self.buff_size
181
+ pos = 0
182
+ chan.send(b"\x00")
183
+ try:
184
+ while pos < size:
185
+ # we have to make sure we don't read the final byte
186
+ if size - pos <= buff_size:
187
+ buff_size = size - pos
188
+ data = chan.recv(buff_size)
189
+ if not data:
190
+ raise SCPException("Underlying channel was closed")
191
+ fh.write(data)
192
+ pos = fh.tell()
193
+ msg = chan.recv(512)
194
+ if msg and msg[0:1] != b"\x00":
195
+ raise SCPException(asunicode(msg[1:]))
196
+ except SocketTimeoutError:
197
+ chan.close()
198
+ raise SCPException("Error receiving, socket.timeout")
199
+
200
+
201
+ class SCPException(Exception):
202
+ """SCP exception class"""
203
+
204
+ pass