PyAlgoEngine 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- PyAlgoEngine-0.7.4.dist-info/LICENSE +21 -0
- PyAlgoEngine-0.7.4.dist-info/METADATA +27 -0
- PyAlgoEngine-0.7.4.dist-info/RECORD +43 -0
- PyAlgoEngine-0.7.4.dist-info/WHEEL +5 -0
- PyAlgoEngine-0.7.4.dist-info/top_level.txt +1 -0
- algo_engine/__init__.py +41 -0
- algo_engine/apps/__init__.py +17 -0
- algo_engine/apps/backtest/__init__.py +20 -0
- algo_engine/apps/backtest/doc_server.py +331 -0
- algo_engine/apps/backtest/tester.py +254 -0
- algo_engine/apps/backtest/web_app.py +127 -0
- algo_engine/apps/bokeh_server.py +205 -0
- algo_engine/apps/demo/__init__.py +0 -0
- algo_engine/apps/demo/test.py +39 -0
- algo_engine/backtest/__init__.py +19 -0
- algo_engine/backtest/__main__.py +51 -0
- algo_engine/backtest/metrics.py +179 -0
- algo_engine/backtest/replay.py +261 -0
- algo_engine/backtest/sim_match.py +295 -0
- algo_engine/base/__init__.py +40 -0
- algo_engine/base/console_utils.py +1070 -0
- algo_engine/base/finance_decimal.py +258 -0
- algo_engine/base/market_buffer.py +571 -0
- algo_engine/base/market_utils.py +3092 -0
- algo_engine/base/market_utils_nt.py +188 -0
- algo_engine/base/market_utils_posix.py +3004 -0
- algo_engine/base/technical_analysis.py +406 -0
- algo_engine/base/telemetrics.py +78 -0
- algo_engine/base/trade_utils.py +709 -0
- algo_engine/engine/__init__.py +28 -0
- algo_engine/engine/algo_engine.py +901 -0
- algo_engine/engine/event_engine.py +53 -0
- algo_engine/engine/market_engine.py +370 -0
- algo_engine/engine/trade_engine.py +2037 -0
- algo_engine/monitor/__init__.py +15 -0
- algo_engine/monitor/advanced_data_interface.py +239 -0
- algo_engine/profile/__init__.py +121 -0
- algo_engine/profile/cn.py +175 -0
- algo_engine/strategy/__init__.py +44 -0
- algo_engine/strategy/strategy_engine.py +440 -0
- algo_engine/utils/__init__.py +3 -0
- algo_engine/utils/commit_regularizer.py +49 -0
- algo_engine/utils/data_utils.py +251 -0
|
@@ -0,0 +1,1070 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import hashlib
|
|
3
|
+
import io
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import pathlib
|
|
7
|
+
import re
|
|
8
|
+
import shutil
|
|
9
|
+
import subprocess
|
|
10
|
+
import sys
|
|
11
|
+
import threading
|
|
12
|
+
import time
|
|
13
|
+
import uuid
|
|
14
|
+
from enum import Enum
|
|
15
|
+
from typing import Iterable, Sized
|
|
16
|
+
|
|
17
|
+
from . import LOGGER
|
|
18
|
+
|
|
19
|
+
# Child of the package logger (name "<package logger>.Console") used by this module.
LOGGER = LOGGER.getChild('Console')

# Public API of this module.
__all__ = [
    'Progress',
    'GetInput',
    'GetArgs',
    'count_ordinal',
    'TerminalStyle',
    'InteractiveShell',
    'ShellTransfer',
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# noinspection SpellCheckingInspection
class TerminalStyle(Enum):
    """ANSI SGR escape sequences for styling terminal text.

    Wrap text as ``f'{TerminalStyle.CRED.value}text{TerminalStyle.CEND.value}'``.
    """
    # reset and text attributes
    CEND = '\33[0m'
    CBOLD = '\33[1m'
    CITALIC = '\33[3m'
    CURL = '\33[4m'
    CBLINK = '\33[5m'
    CBLINK2 = '\33[6m'
    CSELECTED = '\33[7m'

    # standard foreground colours
    CBLACK = '\33[30m'
    CRED = '\33[31m'
    CGREEN = '\33[32m'
    CYELLOW = '\33[33m'
    CBLUE = '\33[34m'
    CVIOLET = '\33[35m'
    CBEIGE = '\33[36m'
    CWHITE = '\33[37m'

    # standard background colours
    CBLACKBG = '\33[40m'
    CREDBG = '\33[41m'
    CGREENBG = '\33[42m'
    CYELLOWBG = '\33[43m'
    CBLUEBG = '\33[44m'
    CVIOLETBG = '\33[45m'
    CBEIGEBG = '\33[46m'
    CWHITEBG = '\33[47m'

    # bright foreground colours
    CGREY = '\33[90m'
    CRED2 = '\33[91m'
    CGREEN2 = '\33[92m'
    CYELLOW2 = '\33[93m'
    CBLUE2 = '\33[94m'
    CVIOLET2 = '\33[95m'
    CBEIGE2 = '\33[96m'
    CWHITE2 = '\33[97m'

    # bright background colours
    CGREYBG = '\33[100m'
    CREDBG2 = '\33[101m'
    CGREENBG2 = '\33[102m'
    CYELLOWBG2 = '\33[103m'
    CBLUEBG2 = '\33[104m'
    CVIOLETBG2 = '\33[105m'
    CBEIGEBG2 = '\33[106m'
    CWHITEBG2 = '\33[107m'

    @staticmethod
    def color_table():
        """Print a preview of every style (0-7) x foreground (30-37) x background (40-47) code.

        Each cell shows its own ``style;fg;bg`` code rendered with that code; one row is
        printed per foreground and each style section is followed by blank lines.
        """
        for style in range(8):
            for foreground in range(30, 38):
                cells = []
                for background in range(40, 48):
                    code = f'{style};{foreground};{background}'
                    cells.append(f'\x1b[{code}m {code} \x1b[0m')
                print(''.join(cells))
            print('\n')
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class Progress(object):
|
|
85
|
+
DEFAULT = '{prompt} [{bar}] {progress:>7.2%} {eta}{done}'
|
|
86
|
+
MINI = '{prompt} {progress:.2%}'
|
|
87
|
+
FULL = '{prompt} [{bar}] {done_tasks}/{total_tasks} {progress:>7.2%}, {remaining} to go {eta}{done}'
|
|
88
|
+
|
|
89
|
+
def __init__(self, tasks: int | Iterable, prompt: str = 'Progress:', format_spec: str = DEFAULT, **kwargs):
|
|
90
|
+
self.prompt = prompt
|
|
91
|
+
self.format_spec = format_spec
|
|
92
|
+
self._width = kwargs.pop('width', None)
|
|
93
|
+
self.tick_size = kwargs.pop('tick_size', 0.0001)
|
|
94
|
+
self.progress_symbol = kwargs.pop('progress_symbol', '=')
|
|
95
|
+
self.blank_symbol = kwargs.pop('blank_symbol', ' ')
|
|
96
|
+
|
|
97
|
+
if isinstance(tasks, int):
|
|
98
|
+
self.total_tasks = tasks
|
|
99
|
+
self.tasks = range(self.total_tasks)
|
|
100
|
+
elif isinstance(tasks, (Sized, Iterable)):
|
|
101
|
+
self.total_tasks = len(tasks)
|
|
102
|
+
self.tasks = tasks
|
|
103
|
+
|
|
104
|
+
if 'outputs' not in kwargs:
|
|
105
|
+
self.outputs = [sys.stdout]
|
|
106
|
+
else:
|
|
107
|
+
outputs = kwargs.pop('outputs')
|
|
108
|
+
if outputs is None:
|
|
109
|
+
self.outputs = []
|
|
110
|
+
elif isinstance(outputs, Iterable):
|
|
111
|
+
self.outputs = outputs
|
|
112
|
+
else:
|
|
113
|
+
self.outputs = [outputs]
|
|
114
|
+
|
|
115
|
+
self.start_time = time.time()
|
|
116
|
+
self.done_tasks = 0
|
|
117
|
+
self.done_time = None
|
|
118
|
+
self.iter_task = None
|
|
119
|
+
self.last_output = -1
|
|
120
|
+
|
|
121
|
+
@property
|
|
122
|
+
def eta(self):
|
|
123
|
+
remaining = self.total_tasks - self.done_tasks
|
|
124
|
+
time_cost = time.time() - self.start_time
|
|
125
|
+
|
|
126
|
+
if self.done_tasks == 0:
|
|
127
|
+
eta = float('inf')
|
|
128
|
+
else:
|
|
129
|
+
eta = time_cost / self.done_tasks * remaining
|
|
130
|
+
|
|
131
|
+
return eta
|
|
132
|
+
|
|
133
|
+
@property
|
|
134
|
+
def work_time(self):
|
|
135
|
+
if self.done_time:
|
|
136
|
+
work_time = self.done_time - self.start_time
|
|
137
|
+
else:
|
|
138
|
+
work_time = time.time() - self.start_time
|
|
139
|
+
|
|
140
|
+
return work_time
|
|
141
|
+
|
|
142
|
+
@property
|
|
143
|
+
def is_done(self):
|
|
144
|
+
return self.done_tasks == self.total_tasks
|
|
145
|
+
|
|
146
|
+
@property
|
|
147
|
+
def progress(self):
|
|
148
|
+
if self.total_tasks:
|
|
149
|
+
return self.done_tasks / self.total_tasks
|
|
150
|
+
else:
|
|
151
|
+
return 1.
|
|
152
|
+
|
|
153
|
+
@property
|
|
154
|
+
def remaining(self):
|
|
155
|
+
return self.total_tasks - self.done_tasks
|
|
156
|
+
|
|
157
|
+
@property
|
|
158
|
+
def width(self):
|
|
159
|
+
if self._width:
|
|
160
|
+
width = self._width
|
|
161
|
+
else:
|
|
162
|
+
width = shutil.get_terminal_size().columns
|
|
163
|
+
|
|
164
|
+
return width
|
|
165
|
+
|
|
166
|
+
def format_progress(self):
|
|
167
|
+
|
|
168
|
+
if self.is_done:
|
|
169
|
+
eta = ''
|
|
170
|
+
done = f'All done in {self.work_time:,.2f} seconds'
|
|
171
|
+
else:
|
|
172
|
+
eta = f'ETA: {self.eta:,.2f} seconds'
|
|
173
|
+
done = ''
|
|
174
|
+
|
|
175
|
+
args = {
|
|
176
|
+
'total_tasks': self.total_tasks,
|
|
177
|
+
'done_tasks': self.done_tasks,
|
|
178
|
+
'progress': self.progress,
|
|
179
|
+
'remaining': self.remaining,
|
|
180
|
+
'work_time': self.work_time,
|
|
181
|
+
'eta': eta,
|
|
182
|
+
'done': done,
|
|
183
|
+
'prompt': self.prompt,
|
|
184
|
+
'bar': '',
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
bar_size = max(10, self.width - len(self.format_spec.format_map(args)))
|
|
188
|
+
progress_size = round(bar_size * self.progress)
|
|
189
|
+
args['bar'] = self.progress_symbol * progress_size + self.blank_symbol * (bar_size - progress_size)
|
|
190
|
+
progress_str = self.format_spec.format_map(args)
|
|
191
|
+
|
|
192
|
+
if self.is_done:
|
|
193
|
+
progress_str += '\n'
|
|
194
|
+
|
|
195
|
+
return progress_str
|
|
196
|
+
|
|
197
|
+
def reset(self):
|
|
198
|
+
self.done_tasks = 0
|
|
199
|
+
self.done_time = None
|
|
200
|
+
self.last_output = -1
|
|
201
|
+
|
|
202
|
+
def output(self):
|
|
203
|
+
progress_str = self.format_progress()
|
|
204
|
+
self.last_output = self.progress
|
|
205
|
+
|
|
206
|
+
for output in self.outputs:
|
|
207
|
+
if callable(output):
|
|
208
|
+
output(progress_str)
|
|
209
|
+
elif isinstance(output, (io.TextIOBase, logging.Logger)):
|
|
210
|
+
print('\r' + progress_str, file=output, end='')
|
|
211
|
+
else:
|
|
212
|
+
pass
|
|
213
|
+
|
|
214
|
+
def __call__(self, *args, **kwargs):
|
|
215
|
+
return self.format_progress()
|
|
216
|
+
|
|
217
|
+
def __next__(self):
|
|
218
|
+
try:
|
|
219
|
+
if (not self.tick_size) or self.progress >= self.tick_size + self.last_output:
|
|
220
|
+
self.output()
|
|
221
|
+
self.done_tasks += 1
|
|
222
|
+
return self.iter_task.__next__()
|
|
223
|
+
except StopIteration:
|
|
224
|
+
self.done_tasks = self.total_tasks
|
|
225
|
+
self.output()
|
|
226
|
+
raise StopIteration()
|
|
227
|
+
|
|
228
|
+
def __iter__(self):
|
|
229
|
+
self.reset()
|
|
230
|
+
self.start_time = time.time()
|
|
231
|
+
self.iter_task = self.tasks.__iter__()
|
|
232
|
+
return self
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class GetInput(object):
    """Prompt for one line of console input, falling back to a default value after a timeout.

    The prompt is shown on construction, which blocks for up to ``timeout`` seconds. The answer
    (or ``default_value`` if none arrived in time, or if stdin is unavailable) is then available
    as :attr:`input`.
    """

    def __init__(self, timeout=5, prompt_message: str = None, default_value: str = None):
        if prompt_message is None:
            prompt_message = f'Please respond in {timeout} seconds: '

        self.timeout = timeout
        self.default_value = default_value
        self.prompt_message = prompt_message
        self._input = None
        self._expired = False  # set once the default has been applied for the current prompt
        self.input_thread: threading.Thread | None = None
        self.show()

    def show(self):
        """Read input on a daemon thread and wait up to ``timeout`` seconds for it."""
        self._expired = False
        self.input_thread = threading.Thread(target=self.get_input)
        self.input_thread.daemon = True
        self.input_thread.start()
        self.input_thread.join(timeout=self.timeout)
        # The reader thread cannot be cancelled; it stays blocked in input() as a daemon.

        if self._input is None:
            # Mark the prompt as expired so a late answer cannot replace the default afterwards.
            self._expired = True
            print(f"No input was given within {self.timeout} seconds. Use {self.default_value} as default value.")
            self._input = self.default_value

    def get_input(self):
        """Thread target: store one line read from stdin, unless the prompt has already expired."""
        self._input = None
        try:
            answer = input(self.prompt_message)
        except (EOFError, OSError):
            # stdin is closed or unavailable (e.g. a non-interactive run): keep None so the default applies.
            return
        if not self._expired:
            self._input = answer
        return

    @property
    def input(self):
        return self._input
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class GetArgs(object):
    """Collect declared command-line options plus any undeclared ``<identifier>name`` options.

    Option kinds:
    - *Required* options are added to the parser before parsing. argparse does not enforce
      them unless ``required=True`` is passed in their kwargs.
    - *Optional* options are only added when they appear on the command line.
    - Any other unknown option starting with ``identifier`` is accepted as a plain string option.
    """

    class ExpectedArgument(object):
        """One declared option: its ``name`` plus the keyword arguments forwarded to ``add_argument``."""

        def __init__(self, name: str, **kwargs):
            self.name = name
            self.kwargs = kwargs
            # The destination always mirrors the declared name; any ``dest`` passed in is overridden.
            self.kwargs['dest'] = name

    def __init__(self, parser: argparse.ArgumentParser = None, required_args: list[ExpectedArgument] = None, optional_args: list[ExpectedArgument] = None, identifier="--"):
        self.parser = argparse.ArgumentParser() if parser is None else parser
        self.identifier = identifier

        self.required_args: dict[str, GetArgs.ExpectedArgument] = {}
        self.optional_args: dict[str, GetArgs.ExpectedArgument] = {}
        # Option strings this instance has already added to ``self.parser``.
        self._registered_options: set[str] = set()

        for argument in required_args or ():
            self.add_argument(argument, optional=False)

        for argument in optional_args or ():
            self.add_argument(argument, optional=True)

    def add_flag(self, name: str, flag_value=True):
        """Declare a boolean flag that stores ``flag_value`` when present."""
        action = 'store_true' if flag_value else 'store_false'
        self.add_argument(argument=self.ExpectedArgument(name=name, action=action), optional=False)

    def add_name(self, name: str, optional=False, **kwargs):
        """Declare an option by name; ``kwargs`` are forwarded to ``ArgumentParser.add_argument``."""
        self.add_argument(argument=self.ExpectedArgument(name=name, **kwargs), optional=optional)

    def add_argument(self, argument: ExpectedArgument, optional=False):
        """Register ``argument`` under its name without the leading ``identifier``."""
        # removeprefix, not lstrip: lstrip removes any leading characters found in the identifier.
        name = argument.name.removeprefix(self.identifier)
        if argument.kwargs.get('dest') == argument.name:
            # A name given as '--foo' would otherwise produce the unusable attribute '--foo'.
            argument.kwargs['dest'] = name

        if optional:
            self.optional_args[name] = argument
        else:
            self.required_args[name] = argument

    def parse(self):
        """Parse ``sys.argv``, accepting undeclared options, and return the ``argparse.Namespace``."""
        for name, argument in self.required_args.items():
            self._register(f'{self.identifier}{name}', **argument.kwargs)

        _, unknown = self.parser.parse_known_args()

        for arg_str in unknown:
            if not arg_str.startswith(self.identifier):
                continue

            option = arg_str.partition('=')[0]
            name = option.removeprefix(self.identifier)
            if not name:
                # A bare identifier (e.g. the '--' end-of-options marker) names no option.
                continue

            if name in self.optional_args:
                self._register(option, **self.optional_args[name].kwargs)
            else:
                self._register(option, dest=name)

        return self.parser.parse_args()

    def _register(self, option: str, **kwargs):
        """Add ``option`` to the parser unless this instance already added it."""
        # argparse raises on duplicate option strings. Duplicates occur for repeated options
        # ('--x=1 --x=2') and when parse() is called more than once.
        if option in self._registered_options:
            return
        self.parser.add_argument(option, **kwargs)
        self._registered_options.add(option)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
class InteractiveShell(object):
|
|
332
|
+
def __init__(self, **kwargs):
|
|
333
|
+
self.encoding = kwargs.pop('encoding', 'utf-8')
|
|
334
|
+
self.logger = kwargs.pop('logger', LOGGER)
|
|
335
|
+
self.strip_ansi = kwargs.pop('strip_ansi', True)
|
|
336
|
+
self.external_fg = kwargs.pop('external_fg', True)
|
|
337
|
+
self.mode = kwargs.pop('mode', 'posix' if os.name == 'posix' else 'cmd')
|
|
338
|
+
self.cols = kwargs.pop('cols', 80)
|
|
339
|
+
self.rows = kwargs.pop('rows', 20)
|
|
340
|
+
|
|
341
|
+
self.process = None
|
|
342
|
+
self.is_running = False
|
|
343
|
+
self.await_response = False
|
|
344
|
+
self.stdin = None
|
|
345
|
+
self.stdout = []
|
|
346
|
+
self.stderr = []
|
|
347
|
+
self._process_output = None
|
|
348
|
+
self._command = None
|
|
349
|
+
self._raw = ''
|
|
350
|
+
|
|
351
|
+
self.callback = {
|
|
352
|
+
'on_character': [],
|
|
353
|
+
'on_linebreak': []
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
self.lock = threading.Lock()
|
|
357
|
+
|
|
358
|
+
if (command := kwargs.get('command')) is not None:
|
|
359
|
+
self.attach(command=command)
|
|
360
|
+
|
|
361
|
+
@classmethod
|
|
362
|
+
def trim_ansi(cls, text):
|
|
363
|
+
cis_pattern = re.compile(r'\x1B\[[^@-~]*[@-~]')
|
|
364
|
+
_ = cis_pattern.sub('', text)
|
|
365
|
+
|
|
366
|
+
osc_pattern = re.compile(r'\x1B][^\x07]*\x07')
|
|
367
|
+
_ = osc_pattern.sub('', _)
|
|
368
|
+
|
|
369
|
+
xterm_win_title_bel = '\x1b]0;'
|
|
370
|
+
_ = _.replace(xterm_win_title_bel, '')
|
|
371
|
+
|
|
372
|
+
bel = '\x07'
|
|
373
|
+
_ = _.replace(bel, ' ')
|
|
374
|
+
|
|
375
|
+
return _
|
|
376
|
+
|
|
377
|
+
def attach_local(self, command: list[str] = None):
|
|
378
|
+
if command is None:
|
|
379
|
+
if self.mode == 'cmd':
|
|
380
|
+
command = [r'C:\windows\system32\cmd.exe']
|
|
381
|
+
elif self.mode == 'ps':
|
|
382
|
+
command = [r'PowerShell.exe']
|
|
383
|
+
elif self.mode == 'posix':
|
|
384
|
+
command = ['bash']
|
|
385
|
+
else:
|
|
386
|
+
raise ValueError(f'Invalid mode {self.mode}')
|
|
387
|
+
|
|
388
|
+
self.attach(command=command)
|
|
389
|
+
self.logger.info('local shell connected')
|
|
390
|
+
|
|
391
|
+
def attach_remote(self, command: list[str] = None, host: str = None, user: str = None, password: str = None):
|
|
392
|
+
if command is None:
|
|
393
|
+
if self.mode == 'cmd':
|
|
394
|
+
command = [r'C:\windows\system32\cmd.exe', f'ssh {user}@{host}', password]
|
|
395
|
+
elif self.mode == 'ps':
|
|
396
|
+
command = [r'C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe', f'ssh {user}@{host}', password]
|
|
397
|
+
elif self.mode == 'posix':
|
|
398
|
+
command = ['sshpass', '-p', password, 'ssh', f'{user}@{host}']
|
|
399
|
+
elif self.mode == 'paramiko':
|
|
400
|
+
command = [user, host, password]
|
|
401
|
+
else:
|
|
402
|
+
raise ValueError(f'Invalid mode {self.mode}')
|
|
403
|
+
|
|
404
|
+
self.attach(command)
|
|
405
|
+
self.logger.info(f'remote shell {user}@{host} connected')
|
|
406
|
+
|
|
407
|
+
def attach(self, command: list[str]):
|
|
408
|
+
self._command = command
|
|
409
|
+
if self.process is not None:
|
|
410
|
+
self.terminate()
|
|
411
|
+
|
|
412
|
+
self.is_running = True
|
|
413
|
+
|
|
414
|
+
if self.mode == 'cmd':
|
|
415
|
+
return self._attach_cmd(command)
|
|
416
|
+
elif self.mode == 'ps':
|
|
417
|
+
return self._attach_ps(command)
|
|
418
|
+
elif self.mode == 'posix':
|
|
419
|
+
return self._attach_posix(command)
|
|
420
|
+
elif self.mode == 'paramiko':
|
|
421
|
+
return self._attach_paramiko(command)
|
|
422
|
+
else:
|
|
423
|
+
raise ValueError(f'Invalid mode {self.mode}')
|
|
424
|
+
|
|
425
|
+
def _attach_posix(self, command: list[str]):
|
|
426
|
+
import pty, fcntl, struct, termios
|
|
427
|
+
primary, replica = pty.openpty()
|
|
428
|
+
self.stdin = os.fdopen(primary, 'w')
|
|
429
|
+
fcntl.ioctl(self.stdin, termios.TIOCSWINSZ, struct.pack("HHHH", self.rows, self.cols, 0, 0))
|
|
430
|
+
|
|
431
|
+
self.process = subprocess.Popen(
|
|
432
|
+
command,
|
|
433
|
+
shell=False,
|
|
434
|
+
stdin=replica,
|
|
435
|
+
stdout=subprocess.PIPE,
|
|
436
|
+
stderr=subprocess.PIPE,
|
|
437
|
+
bufsize=0,
|
|
438
|
+
# preexec_fn=os.setsid,
|
|
439
|
+
# start_new_session=True,
|
|
440
|
+
text=True,
|
|
441
|
+
encoding=self.encoding,
|
|
442
|
+
# close_fds=True
|
|
443
|
+
)
|
|
444
|
+
|
|
445
|
+
self._process_output = (self.process.stdout, self.process.stderr)
|
|
446
|
+
threading.Thread(target=self.listen, name='shell.listen.stdout', args=('stdout',)).start()
|
|
447
|
+
threading.Thread(target=self.listen, name='shell.listen.stderr', args=('stderr',)).start()
|
|
448
|
+
|
|
449
|
+
def _attach_paramiko(self, command: list[str]):
|
|
450
|
+
import paramiko
|
|
451
|
+
|
|
452
|
+
user, host, password = command
|
|
453
|
+
client = paramiko.SSHClient()
|
|
454
|
+
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
|
455
|
+
client.connect(host, username=user, password=password)
|
|
456
|
+
channel = client.invoke_shell()
|
|
457
|
+
|
|
458
|
+
stdin = channel.makefile_stdin("wb", -1)
|
|
459
|
+
stdout = channel.makefile("r", -1)
|
|
460
|
+
stderr = channel.makefile_stderr("r", -1)
|
|
461
|
+
|
|
462
|
+
self.process = client
|
|
463
|
+
self.stdin = stdin
|
|
464
|
+
self._process_output = (stdout, stderr)
|
|
465
|
+
threading.Thread(target=self.listen, name='shell.listen.stdout', args=('stdout',)).start()
|
|
466
|
+
threading.Thread(target=self.listen, name='shell.listen.stderr', args=('stderr',)).start()
|
|
467
|
+
|
|
468
|
+
def _attach_cmd(self, command: list[str]):
|
|
469
|
+
from winpty import PTY
|
|
470
|
+
self.process = PTY(self.cols, self.rows)
|
|
471
|
+
# self.is_running = True
|
|
472
|
+
self.stdin = self.process
|
|
473
|
+
self._process_output = self.process
|
|
474
|
+
|
|
475
|
+
command = command[:]
|
|
476
|
+
console = command.pop(0)
|
|
477
|
+
remote = command.pop(0)
|
|
478
|
+
password = command.pop(0)
|
|
479
|
+
|
|
480
|
+
self.process.spawn(console.encode(self.encoding))
|
|
481
|
+
|
|
482
|
+
threading.Thread(target=self.listen, name='shell.listen.stdout', args=('stdout',)).start()
|
|
483
|
+
# threading.Thread(target=self.listen, name='shell.listen.stderr', args=('stderr',)).start()
|
|
484
|
+
|
|
485
|
+
while not self.stdout:
|
|
486
|
+
time.sleep(0.1)
|
|
487
|
+
|
|
488
|
+
self.write(remote)
|
|
489
|
+
|
|
490
|
+
while True:
|
|
491
|
+
last_output = self.stdout[-1] if self.stdout else ''
|
|
492
|
+
if 'password' in last_output:
|
|
493
|
+
break
|
|
494
|
+
time.sleep(0.1)
|
|
495
|
+
|
|
496
|
+
self.write(password)
|
|
497
|
+
|
|
498
|
+
for _ in command:
|
|
499
|
+
self.write(_)
|
|
500
|
+
|
|
501
|
+
def _attach_ps(self, command: list[str]):
|
|
502
|
+
from winpty import PTY
|
|
503
|
+
self.process = PTY(self.cols, self.rows)
|
|
504
|
+
# self.is_running = True
|
|
505
|
+
self.stdin = self.process
|
|
506
|
+
self._process_output = self.process
|
|
507
|
+
|
|
508
|
+
command = command[:]
|
|
509
|
+
console = command.pop(0)
|
|
510
|
+
remote = command.pop(0)
|
|
511
|
+
password = command.pop(0)
|
|
512
|
+
|
|
513
|
+
self.process.spawn(console.encode(self.encoding))
|
|
514
|
+
|
|
515
|
+
threading.Thread(target=self.listen, name='shell.listen.stdout', args=('stdout',)).start()
|
|
516
|
+
# threading.Thread(target=self.listen, name='shell.listen.stderr', args=('stderr',)).start()
|
|
517
|
+
|
|
518
|
+
while not self.stdout:
|
|
519
|
+
time.sleep(0.1)
|
|
520
|
+
|
|
521
|
+
self.write(remote)
|
|
522
|
+
|
|
523
|
+
while True:
|
|
524
|
+
last_output = self.stdout[-1] if self.stdout else ''
|
|
525
|
+
if 'password' in last_output:
|
|
526
|
+
break
|
|
527
|
+
time.sleep(0.1)
|
|
528
|
+
|
|
529
|
+
self.write(password)
|
|
530
|
+
|
|
531
|
+
for _ in command:
|
|
532
|
+
self.write(_)
|
|
533
|
+
|
|
534
|
+
def write(self, message: str | bytes, end=None):
|
|
535
|
+
self.lock.acquire()
|
|
536
|
+
end = os.linesep if end is None else end
|
|
537
|
+
|
|
538
|
+
if isinstance(message, str):
|
|
539
|
+
bytes_payload = f'{message}{end}'.encode(self.encoding)
|
|
540
|
+
elif isinstance(message, bytes):
|
|
541
|
+
bytes_payload = message + end.encode(self.encoding)
|
|
542
|
+
else:
|
|
543
|
+
raise TypeError(f'Invalid message type {type(message)}, must be str or bytes')
|
|
544
|
+
|
|
545
|
+
if self.mode == 'cmd':
|
|
546
|
+
self.stdin.write(bytes_payload)
|
|
547
|
+
elif self.mode == 'ps':
|
|
548
|
+
self.stdin.write(bytes_payload)
|
|
549
|
+
elif self.mode == 'posix':
|
|
550
|
+
self.stdin.write(bytes_payload.decode(self.encoding))
|
|
551
|
+
elif self.mode == 'paramiko':
|
|
552
|
+
self.stdin.write(bytes_payload.decode(self.encoding))
|
|
553
|
+
|
|
554
|
+
self.lock.release()
|
|
555
|
+
|
|
556
|
+
def on_character(self, character, flag):
|
|
557
|
+
if flag == 'stdout':
|
|
558
|
+
storage = self.stdout
|
|
559
|
+
else:
|
|
560
|
+
storage = self.stderr
|
|
561
|
+
|
|
562
|
+
if not storage:
|
|
563
|
+
storage.append(character)
|
|
564
|
+
else:
|
|
565
|
+
storage[-1] += character
|
|
566
|
+
|
|
567
|
+
self._raw += character
|
|
568
|
+
|
|
569
|
+
for callback in self.callback['on_character']:
|
|
570
|
+
callback(character, flag)
|
|
571
|
+
|
|
572
|
+
def on_linebreak(self, line: str, flag: str):
|
|
573
|
+
if flag == 'stdout':
|
|
574
|
+
storage = self.stdout
|
|
575
|
+
else:
|
|
576
|
+
storage = self.stderr
|
|
577
|
+
|
|
578
|
+
if self.strip_ansi:
|
|
579
|
+
continued_line = False
|
|
580
|
+
|
|
581
|
+
if os.name == 'nt':
|
|
582
|
+
if f'\x1b[{self.rows - 1};{self.cols}H' in line: # scroll screen and continue from the previous line
|
|
583
|
+
strip_line = self.trim_ansi(line)
|
|
584
|
+
strip_line = strip_line[1:]
|
|
585
|
+
continued_line = True
|
|
586
|
+
else:
|
|
587
|
+
strip_line = self.trim_ansi(line)
|
|
588
|
+
|
|
589
|
+
if strip_line.endswith('\r\n'):
|
|
590
|
+
content = strip_line.replace('\r\n', '').rstrip()
|
|
591
|
+
else:
|
|
592
|
+
content = strip_line.replace('\n', '').rstrip()
|
|
593
|
+
else:
|
|
594
|
+
strip_line = self.trim_ansi(line)
|
|
595
|
+
content = strip_line.replace('\n', '').rstrip()
|
|
596
|
+
|
|
597
|
+
if not content:
|
|
598
|
+
storage[-1] = ''
|
|
599
|
+
return
|
|
600
|
+
elif continued_line:
|
|
601
|
+
storage[-1] = ''
|
|
602
|
+
storage[-2] += content
|
|
603
|
+
else:
|
|
604
|
+
storage[-1] = content
|
|
605
|
+
storage.append('')
|
|
606
|
+
else:
|
|
607
|
+
content = line
|
|
608
|
+
storage[-1] = line
|
|
609
|
+
storage.append('')
|
|
610
|
+
|
|
611
|
+
if flag == 'stdout':
|
|
612
|
+
self.logger.debug(line)
|
|
613
|
+
else:
|
|
614
|
+
self.logger.error(line)
|
|
615
|
+
|
|
616
|
+
for callback in self.callback['on_linebreak']:
|
|
617
|
+
callback(line, flag)
|
|
618
|
+
|
|
619
|
+
return content
|
|
620
|
+
|
|
621
|
+
def listen(self, flag: str):
|
|
622
|
+
while self.is_running:
|
|
623
|
+
if flag == 'stdout':
|
|
624
|
+
storage = self.stdout
|
|
625
|
+
else:
|
|
626
|
+
storage = self.stderr
|
|
627
|
+
|
|
628
|
+
if self.mode == 'cmd':
|
|
629
|
+
if flag == 'stdout':
|
|
630
|
+
try:
|
|
631
|
+
outputs = self.process.read().decode(self.encoding)
|
|
632
|
+
except Exception as _:
|
|
633
|
+
break
|
|
634
|
+
else:
|
|
635
|
+
try:
|
|
636
|
+
outputs = self.process.read_stderr(1).decode(self.encoding)
|
|
637
|
+
except Exception as _:
|
|
638
|
+
outputs = ''
|
|
639
|
+
elif self.mode == 'ps':
|
|
640
|
+
if flag == 'stdout':
|
|
641
|
+
try:
|
|
642
|
+
outputs = self.process.read().decode(self.encoding)
|
|
643
|
+
except Exception as _:
|
|
644
|
+
break
|
|
645
|
+
else:
|
|
646
|
+
try:
|
|
647
|
+
outputs = self.process.read_stderr(1).decode(self.encoding)
|
|
648
|
+
except Exception as _:
|
|
649
|
+
outputs = ''
|
|
650
|
+
elif self.mode == 'posix':
|
|
651
|
+
if flag == 'stdout':
|
|
652
|
+
_in = self._process_output[0]
|
|
653
|
+
else:
|
|
654
|
+
_in = self._process_output[1]
|
|
655
|
+
|
|
656
|
+
_in.flush()
|
|
657
|
+
outputs = _in.read(1)
|
|
658
|
+
elif self.mode == 'paramiko':
|
|
659
|
+
if flag == 'stdout':
|
|
660
|
+
_in = self._process_output[0]
|
|
661
|
+
else:
|
|
662
|
+
_in = self._process_output[1]
|
|
663
|
+
|
|
664
|
+
_in.flush()
|
|
665
|
+
outputs = _in.read(1).decode(self.encoding, errors='ignore')
|
|
666
|
+
else:
|
|
667
|
+
raise ValueError(f'Invalid mode {self.mode}')
|
|
668
|
+
|
|
669
|
+
if outputs == '':
|
|
670
|
+
time.sleep(0.1)
|
|
671
|
+
continue
|
|
672
|
+
|
|
673
|
+
for output in outputs:
|
|
674
|
+
self.on_character(output, flag)
|
|
675
|
+
|
|
676
|
+
if output == '\n':
|
|
677
|
+
line = storage[-1]
|
|
678
|
+
self.on_linebreak(line, flag)
|
|
679
|
+
|
|
680
|
+
def execute(self, command: str | list[str], interval=0.1, timeout=None):
|
|
681
|
+
command_id = uuid.uuid4().hex[:8]
|
|
682
|
+
output = ([], [])
|
|
683
|
+
|
|
684
|
+
if self.await_response:
|
|
685
|
+
raise IOError('Blocked! Still waiting for response of the last command!')
|
|
686
|
+
|
|
687
|
+
start_time = time.time()
|
|
688
|
+
start_idx = len(self.stdout)
|
|
689
|
+
stderr_idx = len(self.stderr)
|
|
690
|
+
start_marker = f"<--- START of the command id={command_id} --->"
|
|
691
|
+
end_marker = f"<--- END of the command id={command_id} --->"
|
|
692
|
+
self.await_response = True
|
|
693
|
+
|
|
694
|
+
if isinstance(command, str):
|
|
695
|
+
command = [command]
|
|
696
|
+
|
|
697
|
+
self.write(f'echo "{start_marker}";{"&&".join(command)};echo "{end_marker}"', end='\n')
|
|
698
|
+
|
|
699
|
+
while True:
|
|
700
|
+
if timeout and time.time() - start_time > timeout:
|
|
701
|
+
LOGGER.error(output)
|
|
702
|
+
raise TimeoutError(f'No response within timeout {timeout}')
|
|
703
|
+
stdout_length = len(self.stdout)
|
|
704
|
+
|
|
705
|
+
for _ in range(start_idx, stdout_length):
|
|
706
|
+
line = self.stdout[_]
|
|
707
|
+
|
|
708
|
+
if start_marker in line and '"' not in line:
|
|
709
|
+
start_idx = _
|
|
710
|
+
|
|
711
|
+
if end_marker in line and '"' not in line:
|
|
712
|
+
self.await_response = False
|
|
713
|
+
end_idx = _
|
|
714
|
+
output[0].extend(self.stdout[start_idx + 1: end_idx])
|
|
715
|
+
output[1].extend(self.stderr[stderr_idx:])
|
|
716
|
+
break
|
|
717
|
+
|
|
718
|
+
if not self.await_response:
|
|
719
|
+
break
|
|
720
|
+
|
|
721
|
+
time.sleep(interval)
|
|
722
|
+
|
|
723
|
+
return output
|
|
724
|
+
|
|
725
|
+
def query(self, command: str, **kwargs) -> str:
|
|
726
|
+
timeout = kwargs.pop('timeout', None)
|
|
727
|
+
interval = kwargs.pop('interval', 0.1)
|
|
728
|
+
|
|
729
|
+
_ = self.execute([f'r=$({command})', 'echo "${r}"'], timeout=timeout, interval=interval)
|
|
730
|
+
output = _[0]
|
|
731
|
+
|
|
732
|
+
if len(output) > 1:
|
|
733
|
+
LOGGER.warning(f'Multi-line output received! {output}')
|
|
734
|
+
elif len(output) == 0:
|
|
735
|
+
LOGGER.warning(f'No output received!')
|
|
736
|
+
|
|
737
|
+
content = output[-1]
|
|
738
|
+
return content
|
|
739
|
+
|
|
740
|
+
def duplicate(self):
|
|
741
|
+
new_shell = self.__class__(
|
|
742
|
+
encoding=self.encoding,
|
|
743
|
+
logger=self.logger,
|
|
744
|
+
strip_ansi=self.strip_ansi,
|
|
745
|
+
external_fg=self.external_fg,
|
|
746
|
+
mode=self.mode,
|
|
747
|
+
cols=self.cols,
|
|
748
|
+
rows=self.rows
|
|
749
|
+
)
|
|
750
|
+
new_shell.attach(self._command)
|
|
751
|
+
return new_shell
|
|
752
|
+
|
|
753
|
+
def terminate(self):
|
|
754
|
+
if self.mode == 'cmd':
|
|
755
|
+
del self.process
|
|
756
|
+
self.process = None
|
|
757
|
+
elif self.mode == 'ps':
|
|
758
|
+
del self.process
|
|
759
|
+
self.process = None
|
|
760
|
+
elif self.mode == 'paramiko':
|
|
761
|
+
self.stdin.close()
|
|
762
|
+
self.process.close()
|
|
763
|
+
else:
|
|
764
|
+
self.process.terminate()
|
|
765
|
+
self.process = None
|
|
766
|
+
self.is_running = False
|
|
767
|
+
|
|
768
|
+
def disconnect(self):
|
|
769
|
+
self.write("exit")
|
|
770
|
+
|
|
771
|
+
|
|
772
|
+
class ShellTransfer(object):
|
|
773
|
+
def __init__(self, **kwargs):
    """Bind the transfer helper to an interactive shell.

    Keyword Args:
        shell: an existing ``InteractiveShell`` to reuse. When omitted, a new shell is created
            and connected with ``InteractiveShell.attach_remote``, which needs ``host``,
            ``user`` and ``password``.

    Raises:
        KeyError: if ``shell`` is omitted and one of ``host``/``user``/``password`` is missing.
    """
    if 'shell' in kwargs:
        self.shell = kwargs.pop('shell')
        return

    self.shell = InteractiveShell()
    credentials = {key: kwargs.pop(key) for key in ('host', 'user', 'password')}
    self.shell.attach_remote(**credentials)
|
|
783
|
+
|
|
784
|
+
def push(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """Upload ``local_path`` to ``remote_path`` through the attached shell.

    ``mode`` selects the transport:
    - ``'text'``: printf of the decoded text.
    - ``'hex'``: hex chunks decoded by ``xxd``.
    - ``'sftp'``: paramiko SFTP.

    When ``mode`` is omitted, text transfer is tried first. If it raises, hex transfer is used
    instead. Remaining keyword arguments (e.g. ``chunk_size``, ``encoding``) are forwarded.

    Raises:
        ValueError: if ``mode`` is not one of the values above.
    """
    mode = kwargs.pop('mode', None)

    if not mode:
        try:
            self._push_text(local_path=local_path, remote_path=remote_path, **kwargs)
        except Exception as e:
            # Text mode fails e.g. on bytes that cannot be decoded; hex mode handles any content.
            LOGGER.warning(f'Text push of {local_path} failed ({e!r}), falling back to hex mode.')
            self._push_hex(local_path=local_path, remote_path=remote_path, **kwargs)
    elif mode == 'text':
        self._push_text(local_path=local_path, remote_path=remote_path, **kwargs)
    elif mode == 'hex':
        self._push_hex(local_path=local_path, remote_path=remote_path, **kwargs)
    elif mode == 'sftp':
        self._push_sftp(local_path=local_path, remote_path=remote_path, **kwargs)
    else:
        raise ValueError(f'Invalid push mode {mode}')
|
|
800
|
+
|
|
801
|
+
def _push_text(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """Upload a text file by appending ``printf %b`` chunks to ``remote_path`` via the shell.

    The remote file is removed first. Afterwards the local and remote MD5 digests are compared
    and the result is logged; a mismatch is logged as an error, not raised.

    Keyword Args:
        chunk_size: number of file bytes sent per shell command (default 1024).
        encoding: encoding used to decode each chunk (default ``'utf-8'``).

    Raises:
        UnicodeDecodeError: if a chunk cannot be decoded with ``encoding``. This happens with
            binary data, or when a multi-byte character is split at a chunk boundary.
    """
    import shlex

    def printf_escape(text: str) -> str:
        # Make ``text`` safe as a ``printf %b`` argument sent through an interactive terminal.
        # Backslashes are doubled so %b reproduces them literally, and control characters become
        # escapes so the terminal line discipline never sees them raw.
        short_escapes = {'\\': '\\\\', '\n': '\\n', '\t': '\\t', '\r': '\\r'}
        pieces = []
        for ch in text:
            if ch in short_escapes:
                pieces.append(short_escapes[ch])
            elif ch < ' ' or ch == '\x7f':
                pieces.append(f'\\0{ord(ch):03o}')
            else:
                pieces.append(ch)
        return ''.join(pieces)

    chunk_size = kwargs.get('chunk_size', 1024)
    encoding = kwargs.get('encoding', 'utf-8')
    local_md5 = self.md5(local_path)
    # Single-quoted so the remote shell applies no expansion or word splitting to the path.
    quoted_remote = shlex.quote(str(remote_path))

    with open(local_path, 'rb') as f:
        file_bytes = f.read()

    file_size = len(file_bytes)

    LOGGER.info(f'Push transfer size {file_size:,}, MD5 {local_md5}')

    self.shell.execute(f'rm {quoted_remote}')

    for _ in Progress(range(0, file_size, chunk_size), prompt='Push: '):
        package = file_bytes[_:_ + chunk_size]
        payload = printf_escape(package.decode(encoding))
        # shlex.quote single-quotes the payload, so '$', backticks, '!' and quotes in the file
        # content are passed to printf verbatim instead of being interpreted by the shell.
        self.shell.write(f'printf %b {shlex.quote(payload)} >> {quoted_remote}')

    remote_md5 = self.md5_remote(remote_path)

    if remote_md5 == local_md5:
        LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
    else:
        LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}. Please reduce chunk_size and try again!')
|
|
826
|
+
|
|
827
|
+
def _push_hex(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """
    Upload ``local_path`` as hex text that the remote side decodes with ``xxd -r -p``.

    Each chunk of the file is hex-encoded locally and appended to ``remote_path`` with
    ``echo -n <hex> | xxd -r -p >> "remote_path"``. Afterwards the local MD5 is compared with
    the remote ``md5sum``; a mismatch is logged as an error, not raised.

    :param local_path: file to upload.
    :param remote_path: destination on the remote host; any existing file is removed first.
    :param kwargs: ``chunk_size`` (default 1024) raw bytes per command.
    """
    chunk_size = kwargs.get('chunk_size', 1024)
    local_md5 = self.md5(local_path)

    with open(local_path, 'rb') as f:
        file_bytes = f.read()

    file_size = len(file_bytes)

    LOGGER.info(f'Push transfer size {file_size:,}, MD5 {local_md5}')

    # Quote the path like every other remote command; -f keeps a missing target from being an error.
    self.shell.execute(f'rm -f "{remote_path}"')

    for offset in Progress(range(0, file_size, chunk_size), prompt='Push: '):
        package = file_bytes[offset:offset + chunk_size]
        self.shell.write(f'echo -n {package.hex()} | xxd -r -p >> "{remote_path}"')

    remote_md5 = self.md5_remote(remote_path)

    if remote_md5 == local_md5:
        LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
    else:
        LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}. Please reduce chunk_size and try again!')
|
|
850
|
+
|
|
851
|
+
def _push_sftp(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """
    Upload ``local_path`` over SFTP using the paramiko client behind ``self.shell``.

    Transfer progress is reported through ``Progress``. Afterwards the local MD5 is compared
    with the remote ``md5sum``; a mismatch is logged as an error, not raised.

    :param local_path: file to upload.
    :param remote_path: destination on the remote host.
    :raises NotImplementedError: if the shell is not running in ``'paramiko'`` mode.
    """
    import paramiko
    progress = None

    def sftp_callback(transferred_bytes, total_bytes):
        # The progress bar is created lazily because the total size is only known from the first callback.
        nonlocal progress

        if progress is None:
            progress = Progress(tasks=total_bytes)

        progress.done_tasks = transferred_bytes
        progress.output()

    if self.shell.mode != 'paramiko':
        raise NotImplementedError(f'sftp mode not available in {self.shell.mode} mode')

    ssh_client: paramiko.SSHClient = self.shell.process
    # Close the SFTP channel when done, including when the transfer raises.
    with ssh_client.open_sftp() as sftp_client:
        sftp_client.put(remotepath=str(remote_path), localpath=str(local_path), callback=sftp_callback)

    local_md5 = self.md5(local_path)
    remote_md5 = self.md5_remote(remote_path)

    if remote_md5 == local_md5:
        LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
    else:
        LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}.')
|
|
878
|
+
|
|
879
|
+
def _pull_thread(self, remote_path: str | pathlib.Path, start_bytes: int, chunk_size: int, free_pool, occupied_pool):
|
|
880
|
+
command = f'xxd -s {start_bytes} -l {chunk_size} -c {chunk_size} -p "{remote_path}"'
|
|
881
|
+
result = None
|
|
882
|
+
i = 0
|
|
883
|
+
try:
|
|
884
|
+
new_shell = free_pool.pop(0)
|
|
885
|
+
except:
|
|
886
|
+
new_shell = self.shell.duplicate()
|
|
887
|
+
|
|
888
|
+
occupied_pool.append(new_shell)
|
|
889
|
+
|
|
890
|
+
while i < 10:
|
|
891
|
+
try:
|
|
892
|
+
hex_str = new_shell.query(command, timeout=6)
|
|
893
|
+
result = bytes.fromhex(hex_str)
|
|
894
|
+
break
|
|
895
|
+
except:
|
|
896
|
+
pass
|
|
897
|
+
|
|
898
|
+
i += 1
|
|
899
|
+
|
|
900
|
+
if i == 10:
|
|
901
|
+
new_shell.terminate()
|
|
902
|
+
else:
|
|
903
|
+
free_pool.append(new_shell)
|
|
904
|
+
occupied_pool.remove(new_shell)
|
|
905
|
+
return result
|
|
906
|
+
|
|
907
|
+
def pull(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
|
|
908
|
+
mode = kwargs.pop('mode', None)
|
|
909
|
+
|
|
910
|
+
if not mode or mode == 'hex':
|
|
911
|
+
self._pull_hex(local_path, remote_path, **kwargs)
|
|
912
|
+
elif mode == 'sftp':
|
|
913
|
+
self._pull_sftp(local_path=local_path, remote_path=remote_path, **kwargs)
|
|
914
|
+
else:
|
|
915
|
+
raise ValueError(f'Invalid push mode {mode}')
|
|
916
|
+
|
|
917
|
+
def _pull_hex(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """
    Download ``remote_path`` by reading it chunk by chunk through ``xxd -p`` on the remote shell.

    The complete hex dump is gathered in memory, decoded, and written to ``local_path`` in one
    go. The local MD5 is then compared with the remote ``md5sum``; a mismatch is logged as an
    error, not raised.

    :param local_path: destination file on this machine (overwritten).
    :param remote_path: file on the remote host.
    :param kwargs: ``chunk_size`` (default 1024) bytes requested per ``xxd`` call.
    """
    chunk_size = kwargs.pop('chunk_size', 1024)
    remote_md5 = self.md5_remote(remote_path)
    size_reply = self.shell.query(f'wc -c < "{remote_path}"')
    file_size = int(size_reply)

    LOGGER.info(f'Pull transfer size {file_size:,}, MD5 {remote_md5}')

    hex_parts = [
        self.shell.query(f'xxd -s {offset} -l {chunk_size} -c {chunk_size} -p "{remote_path}"')
        for offset in Progress(range(0, file_size, chunk_size))
    ]

    # Decode before opening the file so a malformed dump leaves any existing local file untouched.
    payload = bytes.fromhex(''.join(hex_parts))
    with open(local_path, 'wb') as out:
        out.write(payload)

    local_md5 = self.md5(local_path)

    if remote_md5 != local_md5:
        LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}. Please reduce chunk_size and try again!')
    else:
        LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
|
|
938
|
+
|
|
939
|
+
def _pull_sftp(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """
    Download ``remote_path`` over SFTP using the paramiko client behind ``self.shell``.

    Transfer progress is reported through ``Progress``. Afterwards the local MD5 is compared
    with the remote ``md5sum``; a mismatch is logged as an error, not raised.

    :param local_path: destination file on this machine.
    :param remote_path: file on the remote host.
    :raises NotImplementedError: if the shell is not running in ``'paramiko'`` mode.
    """
    import paramiko
    progress = None

    def sftp_callback(transferred_bytes, total_bytes):
        # The progress bar is created lazily because the total size is only known from the first callback.
        nonlocal progress

        if progress is None:
            progress = Progress(tasks=total_bytes)

        progress.done_tasks = transferred_bytes
        progress.output()

    if self.shell.mode != 'paramiko':
        raise NotImplementedError(f'sftp mode not available in {self.shell.mode} mode')

    ssh_client: paramiko.SSHClient = self.shell.process
    # Close the SFTP channel when done, including when the transfer raises.
    with ssh_client.open_sftp() as sftp_client:
        sftp_client.get(remotepath=str(remote_path), localpath=str(local_path), callback=sftp_callback)

    local_md5 = self.md5(local_path)
    remote_md5 = self.md5_remote(remote_path)

    if remote_md5 == local_md5:
        LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
    else:
        LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}')
|
|
966
|
+
|
|
967
|
+
def pull_multi_threads(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """
    Download ``remote_path`` with several worker threads, each reading chunks through ``xxd -p``.

    Chunks are fetched by ``self._pull_thread``, which shares a pool of duplicated shells. The
    chunks are reassembled in offset order and written to ``local_path``. The local MD5 is then
    compared with the remote ``md5sum``; a mismatch is logged as an error, not raised.

    :param local_path: destination file on this machine (overwritten).
    :param remote_path: file on the remote host.
    :param kwargs: ``chunk_size`` (default 1024) bytes per chunk and ``workers`` (default 8)
        maximum number of threads.
    :raises RuntimeError: if any chunk could not be fetched after its retries; nothing is written then.
    """
    import concurrent.futures

    chunk_size = kwargs.pop('chunk_size', 1024)
    workers = kwargs.pop('workers', 8)
    remote_md5 = self.md5_remote(remote_path)
    file_size = int(self.shell.query(f'wc -c < "{remote_path}"'))
    result = {}
    free_pool, occupied_pool = [], []

    LOGGER.info(f'Pull transfer size {file_size:,}, MD5 {remote_md5}')

    try:
        # The context manager shuts the executor down (joining its threads) once all chunks are in.
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            tasks = {
                executor.submit(self._pull_thread, remote_path, start_bytes, chunk_size, free_pool, occupied_pool): start_bytes
                for start_bytes in range(0, file_size, chunk_size)
            }

            progress = Progress(tasks=len(tasks))
            for task in concurrent.futures.as_completed(tasks):
                result[tasks[task]] = task.result()
                progress.done_tasks += 1
                progress.prompt = f'pulling with {len(free_pool) + len(occupied_pool)} workers... '
                progress.output()
    finally:
        # Idle worker shells are not needed any more, even if a task raised.
        for shell in free_pool:
            shell.terminate()

    # _pull_thread returns None for a chunk whose retries were all exhausted.
    failed = sorted(offset for offset, data in result.items() if data is None)
    if failed:
        raise RuntimeError(f'Failed to pull {len(failed)} chunk(s) of "{remote_path}", first offsets {failed[:10]}')

    # A single join instead of repeated bytes concatenation (which copies the whole buffer each time).
    bytes_data = b''.join(result[offset] for offset in sorted(result))

    with open(local_path, 'wb') as f:
        f.write(bytes_data)

    local_md5 = self.md5(local_path)

    if remote_md5 == local_md5:
        LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
    else:
        LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}. Please reduce chunk_size and try again!')
|
|
1008
|
+
|
|
1009
|
+
def monitor(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """
    Mirror a growing remote file into ``local_path``, similar to ``tail -f``. Never returns.

    On each pass the remote size is read with ``wc -c``. Bytes appended since the previous pass
    are fetched with ``xxd -p`` and appended to ``local_path``. If the remote file shrinks, the
    new size simply becomes the baseline.

    :param local_path: destination file on this machine.
    :param remote_path: file on the remote host.
    :param kwargs: ``chunk_size`` (default 1024) bytes per ``xxd`` call, ``interval`` (default 1)
        seconds to sleep between passes, ``fetch_all`` (default False) to first overwrite
        ``local_path`` with the current remote content. Otherwise new bytes are appended to
        whatever ``local_path`` already holds.
    """
    chunk_size = kwargs.pop('chunk_size', 1024)
    interval = kwargs.pop('interval', 1)
    fetch_all = kwargs.pop('fetch_all', False)

    def fetch_range(start: int, end: int, prompt: str) -> bytes:
        # Read [start, end) in chunk_size pieces. The final read is clamped to `end` so bytes
        # written after the size snapshot are not fetched now and then fetched again next pass.
        hex_parts = []
        for offset in Progress(range(start, end, chunk_size), prompt=prompt):
            length = min(chunk_size, end - offset)
            hex_parts.append(self.shell.query(f'xxd -s {offset} -l {length} -c {chunk_size} -p "{remote_path}"'))
        return bytes.fromhex(''.join(hex_parts))

    last_size = int(self.shell.query(f'wc -c < "{remote_path}"'))

    if fetch_all:
        bytes_data = fetch_range(0, last_size, 'fetching... ')
        with open(local_path, 'wb') as f:
            f.write(bytes_data)

    while True:
        current_size = int(self.shell.query(f'wc -c < "{remote_path}"'))

        if current_size > last_size:
            bytes_data = fetch_range(last_size, current_size, 'updating... ')
            with open(local_path, 'ab') as f:
                f.write(bytes_data)

        last_size = current_size
        # Sleep on every pass, including when nothing changed, so an idle file is not polled in a tight loop.
        time.sleep(interval)
|
|
1042
|
+
|
|
1043
|
+
@classmethod
def md5(cls, file_path) -> str:
    """
    Return the hexadecimal MD5 digest of a local file.

    The file is read in binary mode in 4096-byte blocks, so it never has to fit in memory at once.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while block := stream.read(4096):
            digest.update(block)
    return digest.hexdigest()
|
|
1050
|
+
|
|
1051
|
+
def md5_remote(self, remote_path) -> str:
    """Return what the remote shell reports as the MD5 of ``remote_path`` (first field of ``md5sum``)."""
    md5_command = f'md5sum "{remote_path}" | cut -f 1 -d " "'
    return self.shell.query(command=md5_command)
|
|
1053
|
+
|
|
1054
|
+
def terminate(self):
    """Terminate the shell session held in ``self.shell``."""
    shell = self.shell
    shell.terminate()
|
|
1056
|
+
|
|
1057
|
+
|
|
1058
|
+
def count_ordinal(n: int) -> str:
    """
    Convert an integer into its English ordinal representation::

        count_ordinal(0) => '0th'
        count_ordinal(3) => '3rd'
        count_ordinal(122) => '122nd'
        count_ordinal(213) => '213th'
        count_ordinal(-2) => '-2nd'

    ``n`` is passed through ``int()`` first. The suffix is chosen from the magnitude of ``n``,
    so negative numbers get the same suffix as their absolute value.
    """
    n = int(n)
    magnitude = abs(n)

    # 11, 12 and 13 (and 111, 212, ...) are irregular: always 'th'.
    if 11 <= magnitude % 100 <= 13:
        suffix = 'th'
    else:
        suffix = ['th', 'st', 'nd', 'rd', 'th'][min(magnitude % 10, 4)]

    return f'{n}{suffix}'
|