PyAlgoEngine 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (43):
  1. PyAlgoEngine-0.7.4.dist-info/LICENSE +21 -0
  2. PyAlgoEngine-0.7.4.dist-info/METADATA +27 -0
  3. PyAlgoEngine-0.7.4.dist-info/RECORD +43 -0
  4. PyAlgoEngine-0.7.4.dist-info/WHEEL +5 -0
  5. PyAlgoEngine-0.7.4.dist-info/top_level.txt +1 -0
  6. algo_engine/__init__.py +41 -0
  7. algo_engine/apps/__init__.py +17 -0
  8. algo_engine/apps/backtest/__init__.py +20 -0
  9. algo_engine/apps/backtest/doc_server.py +331 -0
  10. algo_engine/apps/backtest/tester.py +254 -0
  11. algo_engine/apps/backtest/web_app.py +127 -0
  12. algo_engine/apps/bokeh_server.py +205 -0
  13. algo_engine/apps/demo/__init__.py +0 -0
  14. algo_engine/apps/demo/test.py +39 -0
  15. algo_engine/backtest/__init__.py +19 -0
  16. algo_engine/backtest/__main__.py +51 -0
  17. algo_engine/backtest/metrics.py +179 -0
  18. algo_engine/backtest/replay.py +261 -0
  19. algo_engine/backtest/sim_match.py +295 -0
  20. algo_engine/base/__init__.py +40 -0
  21. algo_engine/base/console_utils.py +1070 -0
  22. algo_engine/base/finance_decimal.py +258 -0
  23. algo_engine/base/market_buffer.py +571 -0
  24. algo_engine/base/market_utils.py +3092 -0
  25. algo_engine/base/market_utils_nt.py +188 -0
  26. algo_engine/base/market_utils_posix.py +3004 -0
  27. algo_engine/base/technical_analysis.py +406 -0
  28. algo_engine/base/telemetrics.py +78 -0
  29. algo_engine/base/trade_utils.py +709 -0
  30. algo_engine/engine/__init__.py +28 -0
  31. algo_engine/engine/algo_engine.py +901 -0
  32. algo_engine/engine/event_engine.py +53 -0
  33. algo_engine/engine/market_engine.py +370 -0
  34. algo_engine/engine/trade_engine.py +2037 -0
  35. algo_engine/monitor/__init__.py +15 -0
  36. algo_engine/monitor/advanced_data_interface.py +239 -0
  37. algo_engine/profile/__init__.py +121 -0
  38. algo_engine/profile/cn.py +175 -0
  39. algo_engine/strategy/__init__.py +44 -0
  40. algo_engine/strategy/strategy_engine.py +440 -0
  41. algo_engine/utils/__init__.py +3 -0
  42. algo_engine/utils/commit_regularizer.py +49 -0
  43. algo_engine/utils/data_utils.py +251 -0
@@ -0,0 +1,1070 @@
1
+ import argparse
2
+ import hashlib
3
+ import io
4
+ import logging
5
+ import os
6
+ import pathlib
7
+ import re
8
+ import shutil
9
+ import subprocess
10
+ import sys
11
+ import threading
12
+ import time
13
+ import uuid
14
+ from enum import Enum
15
+ from typing import Iterable, Sized
16
+
17
+ from . import LOGGER
18
+
19
# Console-specific child of the package-level LOGGER imported from this package above.
LOGGER = LOGGER.getChild('Console')

# Public API of this module.
__all__ = [
    'Progress',
    'GetInput',
    'GetArgs',
    'count_ordinal',
    'TerminalStyle',
    'InteractiveShell',
    'ShellTransfer',
]
21
+
22
+
23
# noinspection SpellCheckingInspection
class TerminalStyle(Enum):
    """ANSI SGR escape sequences for styling terminal text.

    Every member's value is the raw escape string (``'\\33'`` is ESC). Emit a member's value before the
    text and ``CEND.value`` after it to return to the default style.
    """

    # text attributes: reset, bold, italic, underline, slow/fast blink, reverse video
    CEND = '\33[0m'
    CBOLD = '\33[1m'
    CITALIC = '\33[3m'
    CURL = '\33[4m'
    CBLINK = '\33[5m'
    CBLINK2 = '\33[6m'
    CSELECTED = '\33[7m'

    # standard foreground colours (SGR 30-37)
    CBLACK = '\33[30m'
    CRED = '\33[31m'
    CGREEN = '\33[32m'
    CYELLOW = '\33[33m'
    CBLUE = '\33[34m'
    CVIOLET = '\33[35m'
    CBEIGE = '\33[36m'
    CWHITE = '\33[37m'

    # standard background colours (SGR 40-47)
    CBLACKBG = '\33[40m'
    CREDBG = '\33[41m'
    CGREENBG = '\33[42m'
    CYELLOWBG = '\33[43m'
    CBLUEBG = '\33[44m'
    CVIOLETBG = '\33[45m'
    CBEIGEBG = '\33[46m'
    CWHITEBG = '\33[47m'

    # bright foreground colours (SGR 90-97)
    CGREY = '\33[90m'
    CRED2 = '\33[91m'
    CGREEN2 = '\33[92m'
    CYELLOW2 = '\33[93m'
    CBLUE2 = '\33[94m'
    CVIOLET2 = '\33[95m'
    CBEIGE2 = '\33[96m'
    CWHITE2 = '\33[97m'

    # bright background colours (SGR 100-107)
    CGREYBG = '\33[100m'
    CREDBG2 = '\33[101m'
    CGREENBG2 = '\33[102m'
    CYELLOWBG2 = '\33[103m'
    CBLUEBG2 = '\33[104m'
    CVIOLETBG2 = '\33[105m'
    CBEIGEBG2 = '\33[106m'
    CWHITEBG2 = '\33[107m'

    @staticmethod
    def color_table():
        """Print a swatch grid of every (style, foreground, background) SGR combination.

        For each style code 0-7 one row is printed per foreground colour 30-37, holding a cell for each
        background colour 40-47; each cell shows its own ``style;fg;bg`` code rendered in that style.
        """
        for style in range(8):
            for fg in range(30, 38):
                cells = []
                for bg in range(40, 48):
                    code = f'{style};{fg};{bg}'
                    cells.append(f'\x1b[{code}m {code} \x1b[0m')
                print(''.join(cells))
            print('\n')
82
+
83
+
84
+ class Progress(object):
85
+ DEFAULT = '{prompt} [{bar}] {progress:>7.2%} {eta}{done}'
86
+ MINI = '{prompt} {progress:.2%}'
87
+ FULL = '{prompt} [{bar}] {done_tasks}/{total_tasks} {progress:>7.2%}, {remaining} to go {eta}{done}'
88
+
89
+ def __init__(self, tasks: int | Iterable, prompt: str = 'Progress:', format_spec: str = DEFAULT, **kwargs):
90
+ self.prompt = prompt
91
+ self.format_spec = format_spec
92
+ self._width = kwargs.pop('width', None)
93
+ self.tick_size = kwargs.pop('tick_size', 0.0001)
94
+ self.progress_symbol = kwargs.pop('progress_symbol', '=')
95
+ self.blank_symbol = kwargs.pop('blank_symbol', ' ')
96
+
97
+ if isinstance(tasks, int):
98
+ self.total_tasks = tasks
99
+ self.tasks = range(self.total_tasks)
100
+ elif isinstance(tasks, (Sized, Iterable)):
101
+ self.total_tasks = len(tasks)
102
+ self.tasks = tasks
103
+
104
+ if 'outputs' not in kwargs:
105
+ self.outputs = [sys.stdout]
106
+ else:
107
+ outputs = kwargs.pop('outputs')
108
+ if outputs is None:
109
+ self.outputs = []
110
+ elif isinstance(outputs, Iterable):
111
+ self.outputs = outputs
112
+ else:
113
+ self.outputs = [outputs]
114
+
115
+ self.start_time = time.time()
116
+ self.done_tasks = 0
117
+ self.done_time = None
118
+ self.iter_task = None
119
+ self.last_output = -1
120
+
121
+ @property
122
+ def eta(self):
123
+ remaining = self.total_tasks - self.done_tasks
124
+ time_cost = time.time() - self.start_time
125
+
126
+ if self.done_tasks == 0:
127
+ eta = float('inf')
128
+ else:
129
+ eta = time_cost / self.done_tasks * remaining
130
+
131
+ return eta
132
+
133
+ @property
134
+ def work_time(self):
135
+ if self.done_time:
136
+ work_time = self.done_time - self.start_time
137
+ else:
138
+ work_time = time.time() - self.start_time
139
+
140
+ return work_time
141
+
142
+ @property
143
+ def is_done(self):
144
+ return self.done_tasks == self.total_tasks
145
+
146
+ @property
147
+ def progress(self):
148
+ if self.total_tasks:
149
+ return self.done_tasks / self.total_tasks
150
+ else:
151
+ return 1.
152
+
153
+ @property
154
+ def remaining(self):
155
+ return self.total_tasks - self.done_tasks
156
+
157
+ @property
158
+ def width(self):
159
+ if self._width:
160
+ width = self._width
161
+ else:
162
+ width = shutil.get_terminal_size().columns
163
+
164
+ return width
165
+
166
+ def format_progress(self):
167
+
168
+ if self.is_done:
169
+ eta = ''
170
+ done = f'All done in {self.work_time:,.2f} seconds'
171
+ else:
172
+ eta = f'ETA: {self.eta:,.2f} seconds'
173
+ done = ''
174
+
175
+ args = {
176
+ 'total_tasks': self.total_tasks,
177
+ 'done_tasks': self.done_tasks,
178
+ 'progress': self.progress,
179
+ 'remaining': self.remaining,
180
+ 'work_time': self.work_time,
181
+ 'eta': eta,
182
+ 'done': done,
183
+ 'prompt': self.prompt,
184
+ 'bar': '',
185
+ }
186
+
187
+ bar_size = max(10, self.width - len(self.format_spec.format_map(args)))
188
+ progress_size = round(bar_size * self.progress)
189
+ args['bar'] = self.progress_symbol * progress_size + self.blank_symbol * (bar_size - progress_size)
190
+ progress_str = self.format_spec.format_map(args)
191
+
192
+ if self.is_done:
193
+ progress_str += '\n'
194
+
195
+ return progress_str
196
+
197
+ def reset(self):
198
+ self.done_tasks = 0
199
+ self.done_time = None
200
+ self.last_output = -1
201
+
202
+ def output(self):
203
+ progress_str = self.format_progress()
204
+ self.last_output = self.progress
205
+
206
+ for output in self.outputs:
207
+ if callable(output):
208
+ output(progress_str)
209
+ elif isinstance(output, (io.TextIOBase, logging.Logger)):
210
+ print('\r' + progress_str, file=output, end='')
211
+ else:
212
+ pass
213
+
214
+ def __call__(self, *args, **kwargs):
215
+ return self.format_progress()
216
+
217
+ def __next__(self):
218
+ try:
219
+ if (not self.tick_size) or self.progress >= self.tick_size + self.last_output:
220
+ self.output()
221
+ self.done_tasks += 1
222
+ return self.iter_task.__next__()
223
+ except StopIteration:
224
+ self.done_tasks = self.total_tasks
225
+ self.output()
226
+ raise StopIteration()
227
+
228
+ def __iter__(self):
229
+ self.reset()
230
+ self.start_time = time.time()
231
+ self.iter_task = self.tasks.__iter__()
232
+ return self
233
+
234
+
235
class GetInput(object):
    """Prompt on stdin with a timeout, falling back to a default value.

    The prompt is shown as soon as the object is constructed; read the result from ``input``.
    """

    def __init__(self, timeout=5, prompt_message: str = None, default_value: str = None):
        """
        Args:
            timeout: seconds to wait for an answer.
            prompt_message: text shown to the user; defaults to a message stating the timeout.
            default_value: value used when no answer arrives in time or stdin cannot be read.
        """
        if prompt_message is None:
            prompt_message = f'Please respond in {timeout} seconds: '

        self.timeout = timeout
        self.default_value = default_value
        self.prompt_message = prompt_message
        self._input = None
        # token of the prompt currently allowed to store an answer (see show())
        self._session = None
        # guards _input/_session between the reader thread and show()
        self._lock = threading.Lock()
        self.input_thread: threading.Thread | None = None
        self.show()

    def show(self):
        """Ask for input in a daemon thread and wait up to ``timeout`` seconds for the answer."""
        session = object()  # identifies this prompt so answers from abandoned earlier prompts are ignored
        with self._lock:
            self._session = session
            self._input = None

        self.input_thread = threading.Thread(target=self.get_input, args=(session,), daemon=True)
        self.input_thread.start()
        # input() cannot be interrupted; on timeout the daemon thread is simply left blocked and abandoned
        self.input_thread.join(timeout=self.timeout)

        with self._lock:
            if self._input is None:
                print(f"No input was given within {self.timeout} seconds. Use {self.default_value} as default value.")
                self._input = self.default_value
            # retire the session so a late answer from the abandoned thread cannot replace the result
            self._session = None

    def get_input(self, _session=None):
        """Read one line from stdin and store it as the answer.

        ``show()`` runs this in a reader thread and passes its session token; the answer is stored only while
        that session is still current. A direct call (no token) stores the answer unconditionally. EOF or an
        unreadable stdin leaves the answer unset so the default applies.
        """
        if _session is None:
            self._input = None

        try:
            answer = input(self.prompt_message)
        except (EOFError, OSError):
            return

        with self._lock:
            if _session is None or _session is self._session:
                self._input = answer

    @property
    def input(self):
        """The user's answer, or ``default_value`` if none arrived in time."""
        return self._input
267
+
268
+
269
+ class GetArgs(object):
270
+ class ExpectedArgument(object):
271
+ def __init__(self, name: str, **kwargs):
272
+ self.name = name
273
+ self.kwargs = kwargs
274
+
275
+ self.kwargs['dest'] = name
276
+
277
+ def __init__(self, parser: argparse.ArgumentParser = None, required_args: list[ExpectedArgument] = None, optional_args: list[ExpectedArgument] = None, identifier="--"):
278
+ self.parser = argparse.ArgumentParser() if parser is None else parser
279
+ self.identifier = identifier
280
+
281
+ self.required_args: dict[str, GetArgs.ExpectedArgument] = {}
282
+ self.optional_args: dict[str, GetArgs.ExpectedArgument] = {}
283
+
284
+ if required_args:
285
+ for argument in required_args:
286
+ self.add_argument(argument, optional=False)
287
+
288
+ if optional_args:
289
+ for argument in optional_args:
290
+ self.add_argument(argument, optional=True)
291
+
292
+ def add_flag(self, name: str, flag_value=True):
293
+ if flag_value:
294
+ action = 'store_true'
295
+ else:
296
+ action = 'store_false'
297
+
298
+ self.add_argument(argument=self.ExpectedArgument(name=name, action=action), optional=False)
299
+
300
+ def add_name(self, name: str, optional=False, **kwargs):
301
+ self.add_argument(argument=self.ExpectedArgument(name=name, **kwargs), optional=optional)
302
+
303
+ def add_argument(self, argument: ExpectedArgument, optional=False):
304
+ name = argument.name.lstrip(self.identifier)
305
+
306
+ if optional:
307
+ self.optional_args[name] = argument
308
+ else:
309
+ self.required_args[name] = argument
310
+
311
+ def parse(self):
312
+ for name in self.required_args:
313
+ self.parser.add_argument(f'{self.identifier}{name}', **self.required_args[name].kwargs)
314
+
315
+ parsed, unknown = self.parser.parse_known_args()
316
+
317
+ for arg_str in unknown:
318
+ if arg_str.startswith(self.identifier):
319
+ arg = arg_str.split('=')[0]
320
+ name = arg.strip(self.identifier)
321
+
322
+ if name in self.optional_args:
323
+ self.parser.add_argument(arg, **self.optional_args[name].kwargs)
324
+ else:
325
+ self.parser.add_argument(arg, dest=name)
326
+
327
+ args = self.parser.parse_args()
328
+ return args
329
+
330
+
331
+ class InteractiveShell(object):
332
+ def __init__(self, **kwargs):
333
+ self.encoding = kwargs.pop('encoding', 'utf-8')
334
+ self.logger = kwargs.pop('logger', LOGGER)
335
+ self.strip_ansi = kwargs.pop('strip_ansi', True)
336
+ self.external_fg = kwargs.pop('external_fg', True)
337
+ self.mode = kwargs.pop('mode', 'posix' if os.name == 'posix' else 'cmd')
338
+ self.cols = kwargs.pop('cols', 80)
339
+ self.rows = kwargs.pop('rows', 20)
340
+
341
+ self.process = None
342
+ self.is_running = False
343
+ self.await_response = False
344
+ self.stdin = None
345
+ self.stdout = []
346
+ self.stderr = []
347
+ self._process_output = None
348
+ self._command = None
349
+ self._raw = ''
350
+
351
+ self.callback = {
352
+ 'on_character': [],
353
+ 'on_linebreak': []
354
+ }
355
+
356
+ self.lock = threading.Lock()
357
+
358
+ if (command := kwargs.get('command')) is not None:
359
+ self.attach(command=command)
360
+
361
+ @classmethod
362
+ def trim_ansi(cls, text):
363
+ cis_pattern = re.compile(r'\x1B\[[^@-~]*[@-~]')
364
+ _ = cis_pattern.sub('', text)
365
+
366
+ osc_pattern = re.compile(r'\x1B][^\x07]*\x07')
367
+ _ = osc_pattern.sub('', _)
368
+
369
+ xterm_win_title_bel = '\x1b]0;'
370
+ _ = _.replace(xterm_win_title_bel, '')
371
+
372
+ bel = '\x07'
373
+ _ = _.replace(bel, ' ')
374
+
375
+ return _
376
+
377
+ def attach_local(self, command: list[str] = None):
378
+ if command is None:
379
+ if self.mode == 'cmd':
380
+ command = [r'C:\windows\system32\cmd.exe']
381
+ elif self.mode == 'ps':
382
+ command = [r'PowerShell.exe']
383
+ elif self.mode == 'posix':
384
+ command = ['bash']
385
+ else:
386
+ raise ValueError(f'Invalid mode {self.mode}')
387
+
388
+ self.attach(command=command)
389
+ self.logger.info('local shell connected')
390
+
391
+ def attach_remote(self, command: list[str] = None, host: str = None, user: str = None, password: str = None):
392
+ if command is None:
393
+ if self.mode == 'cmd':
394
+ command = [r'C:\windows\system32\cmd.exe', f'ssh {user}@{host}', password]
395
+ elif self.mode == 'ps':
396
+ command = [r'C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe', f'ssh {user}@{host}', password]
397
+ elif self.mode == 'posix':
398
+ command = ['sshpass', '-p', password, 'ssh', f'{user}@{host}']
399
+ elif self.mode == 'paramiko':
400
+ command = [user, host, password]
401
+ else:
402
+ raise ValueError(f'Invalid mode {self.mode}')
403
+
404
+ self.attach(command)
405
+ self.logger.info(f'remote shell {user}@{host} connected')
406
+
407
+ def attach(self, command: list[str]):
408
+ self._command = command
409
+ if self.process is not None:
410
+ self.terminate()
411
+
412
+ self.is_running = True
413
+
414
+ if self.mode == 'cmd':
415
+ return self._attach_cmd(command)
416
+ elif self.mode == 'ps':
417
+ return self._attach_ps(command)
418
+ elif self.mode == 'posix':
419
+ return self._attach_posix(command)
420
+ elif self.mode == 'paramiko':
421
+ return self._attach_paramiko(command)
422
+ else:
423
+ raise ValueError(f'Invalid mode {self.mode}')
424
+
425
+ def _attach_posix(self, command: list[str]):
426
+ import pty, fcntl, struct, termios
427
+ primary, replica = pty.openpty()
428
+ self.stdin = os.fdopen(primary, 'w')
429
+ fcntl.ioctl(self.stdin, termios.TIOCSWINSZ, struct.pack("HHHH", self.rows, self.cols, 0, 0))
430
+
431
+ self.process = subprocess.Popen(
432
+ command,
433
+ shell=False,
434
+ stdin=replica,
435
+ stdout=subprocess.PIPE,
436
+ stderr=subprocess.PIPE,
437
+ bufsize=0,
438
+ # preexec_fn=os.setsid,
439
+ # start_new_session=True,
440
+ text=True,
441
+ encoding=self.encoding,
442
+ # close_fds=True
443
+ )
444
+
445
+ self._process_output = (self.process.stdout, self.process.stderr)
446
+ threading.Thread(target=self.listen, name='shell.listen.stdout', args=('stdout',)).start()
447
+ threading.Thread(target=self.listen, name='shell.listen.stderr', args=('stderr',)).start()
448
+
449
+ def _attach_paramiko(self, command: list[str]):
450
+ import paramiko
451
+
452
+ user, host, password = command
453
+ client = paramiko.SSHClient()
454
+ client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
455
+ client.connect(host, username=user, password=password)
456
+ channel = client.invoke_shell()
457
+
458
+ stdin = channel.makefile_stdin("wb", -1)
459
+ stdout = channel.makefile("r", -1)
460
+ stderr = channel.makefile_stderr("r", -1)
461
+
462
+ self.process = client
463
+ self.stdin = stdin
464
+ self._process_output = (stdout, stderr)
465
+ threading.Thread(target=self.listen, name='shell.listen.stdout', args=('stdout',)).start()
466
+ threading.Thread(target=self.listen, name='shell.listen.stderr', args=('stderr',)).start()
467
+
468
+ def _attach_cmd(self, command: list[str]):
469
+ from winpty import PTY
470
+ self.process = PTY(self.cols, self.rows)
471
+ # self.is_running = True
472
+ self.stdin = self.process
473
+ self._process_output = self.process
474
+
475
+ command = command[:]
476
+ console = command.pop(0)
477
+ remote = command.pop(0)
478
+ password = command.pop(0)
479
+
480
+ self.process.spawn(console.encode(self.encoding))
481
+
482
+ threading.Thread(target=self.listen, name='shell.listen.stdout', args=('stdout',)).start()
483
+ # threading.Thread(target=self.listen, name='shell.listen.stderr', args=('stderr',)).start()
484
+
485
+ while not self.stdout:
486
+ time.sleep(0.1)
487
+
488
+ self.write(remote)
489
+
490
+ while True:
491
+ last_output = self.stdout[-1] if self.stdout else ''
492
+ if 'password' in last_output:
493
+ break
494
+ time.sleep(0.1)
495
+
496
+ self.write(password)
497
+
498
+ for _ in command:
499
+ self.write(_)
500
+
501
+ def _attach_ps(self, command: list[str]):
502
+ from winpty import PTY
503
+ self.process = PTY(self.cols, self.rows)
504
+ # self.is_running = True
505
+ self.stdin = self.process
506
+ self._process_output = self.process
507
+
508
+ command = command[:]
509
+ console = command.pop(0)
510
+ remote = command.pop(0)
511
+ password = command.pop(0)
512
+
513
+ self.process.spawn(console.encode(self.encoding))
514
+
515
+ threading.Thread(target=self.listen, name='shell.listen.stdout', args=('stdout',)).start()
516
+ # threading.Thread(target=self.listen, name='shell.listen.stderr', args=('stderr',)).start()
517
+
518
+ while not self.stdout:
519
+ time.sleep(0.1)
520
+
521
+ self.write(remote)
522
+
523
+ while True:
524
+ last_output = self.stdout[-1] if self.stdout else ''
525
+ if 'password' in last_output:
526
+ break
527
+ time.sleep(0.1)
528
+
529
+ self.write(password)
530
+
531
+ for _ in command:
532
+ self.write(_)
533
+
534
+ def write(self, message: str | bytes, end=None):
535
+ self.lock.acquire()
536
+ end = os.linesep if end is None else end
537
+
538
+ if isinstance(message, str):
539
+ bytes_payload = f'{message}{end}'.encode(self.encoding)
540
+ elif isinstance(message, bytes):
541
+ bytes_payload = message + end.encode(self.encoding)
542
+ else:
543
+ raise TypeError(f'Invalid message type {type(message)}, must be str or bytes')
544
+
545
+ if self.mode == 'cmd':
546
+ self.stdin.write(bytes_payload)
547
+ elif self.mode == 'ps':
548
+ self.stdin.write(bytes_payload)
549
+ elif self.mode == 'posix':
550
+ self.stdin.write(bytes_payload.decode(self.encoding))
551
+ elif self.mode == 'paramiko':
552
+ self.stdin.write(bytes_payload.decode(self.encoding))
553
+
554
+ self.lock.release()
555
+
556
+ def on_character(self, character, flag):
557
+ if flag == 'stdout':
558
+ storage = self.stdout
559
+ else:
560
+ storage = self.stderr
561
+
562
+ if not storage:
563
+ storage.append(character)
564
+ else:
565
+ storage[-1] += character
566
+
567
+ self._raw += character
568
+
569
+ for callback in self.callback['on_character']:
570
+ callback(character, flag)
571
+
572
+ def on_linebreak(self, line: str, flag: str):
573
+ if flag == 'stdout':
574
+ storage = self.stdout
575
+ else:
576
+ storage = self.stderr
577
+
578
+ if self.strip_ansi:
579
+ continued_line = False
580
+
581
+ if os.name == 'nt':
582
+ if f'\x1b[{self.rows - 1};{self.cols}H' in line: # scroll screen and continue from the previous line
583
+ strip_line = self.trim_ansi(line)
584
+ strip_line = strip_line[1:]
585
+ continued_line = True
586
+ else:
587
+ strip_line = self.trim_ansi(line)
588
+
589
+ if strip_line.endswith('\r\n'):
590
+ content = strip_line.replace('\r\n', '').rstrip()
591
+ else:
592
+ content = strip_line.replace('\n', '').rstrip()
593
+ else:
594
+ strip_line = self.trim_ansi(line)
595
+ content = strip_line.replace('\n', '').rstrip()
596
+
597
+ if not content:
598
+ storage[-1] = ''
599
+ return
600
+ elif continued_line:
601
+ storage[-1] = ''
602
+ storage[-2] += content
603
+ else:
604
+ storage[-1] = content
605
+ storage.append('')
606
+ else:
607
+ content = line
608
+ storage[-1] = line
609
+ storage.append('')
610
+
611
+ if flag == 'stdout':
612
+ self.logger.debug(line)
613
+ else:
614
+ self.logger.error(line)
615
+
616
+ for callback in self.callback['on_linebreak']:
617
+ callback(line, flag)
618
+
619
+ return content
620
+
621
+ def listen(self, flag: str):
622
+ while self.is_running:
623
+ if flag == 'stdout':
624
+ storage = self.stdout
625
+ else:
626
+ storage = self.stderr
627
+
628
+ if self.mode == 'cmd':
629
+ if flag == 'stdout':
630
+ try:
631
+ outputs = self.process.read().decode(self.encoding)
632
+ except Exception as _:
633
+ break
634
+ else:
635
+ try:
636
+ outputs = self.process.read_stderr(1).decode(self.encoding)
637
+ except Exception as _:
638
+ outputs = ''
639
+ elif self.mode == 'ps':
640
+ if flag == 'stdout':
641
+ try:
642
+ outputs = self.process.read().decode(self.encoding)
643
+ except Exception as _:
644
+ break
645
+ else:
646
+ try:
647
+ outputs = self.process.read_stderr(1).decode(self.encoding)
648
+ except Exception as _:
649
+ outputs = ''
650
+ elif self.mode == 'posix':
651
+ if flag == 'stdout':
652
+ _in = self._process_output[0]
653
+ else:
654
+ _in = self._process_output[1]
655
+
656
+ _in.flush()
657
+ outputs = _in.read(1)
658
+ elif self.mode == 'paramiko':
659
+ if flag == 'stdout':
660
+ _in = self._process_output[0]
661
+ else:
662
+ _in = self._process_output[1]
663
+
664
+ _in.flush()
665
+ outputs = _in.read(1).decode(self.encoding, errors='ignore')
666
+ else:
667
+ raise ValueError(f'Invalid mode {self.mode}')
668
+
669
+ if outputs == '':
670
+ time.sleep(0.1)
671
+ continue
672
+
673
+ for output in outputs:
674
+ self.on_character(output, flag)
675
+
676
+ if output == '\n':
677
+ line = storage[-1]
678
+ self.on_linebreak(line, flag)
679
+
680
+ def execute(self, command: str | list[str], interval=0.1, timeout=None):
681
+ command_id = uuid.uuid4().hex[:8]
682
+ output = ([], [])
683
+
684
+ if self.await_response:
685
+ raise IOError('Blocked! Still waiting for response of the last command!')
686
+
687
+ start_time = time.time()
688
+ start_idx = len(self.stdout)
689
+ stderr_idx = len(self.stderr)
690
+ start_marker = f"<--- START of the command id={command_id} --->"
691
+ end_marker = f"<--- END of the command id={command_id} --->"
692
+ self.await_response = True
693
+
694
+ if isinstance(command, str):
695
+ command = [command]
696
+
697
+ self.write(f'echo "{start_marker}";{"&&".join(command)};echo "{end_marker}"', end='\n')
698
+
699
+ while True:
700
+ if timeout and time.time() - start_time > timeout:
701
+ LOGGER.error(output)
702
+ raise TimeoutError(f'No response within timeout {timeout}')
703
+ stdout_length = len(self.stdout)
704
+
705
+ for _ in range(start_idx, stdout_length):
706
+ line = self.stdout[_]
707
+
708
+ if start_marker in line and '"' not in line:
709
+ start_idx = _
710
+
711
+ if end_marker in line and '"' not in line:
712
+ self.await_response = False
713
+ end_idx = _
714
+ output[0].extend(self.stdout[start_idx + 1: end_idx])
715
+ output[1].extend(self.stderr[stderr_idx:])
716
+ break
717
+
718
+ if not self.await_response:
719
+ break
720
+
721
+ time.sleep(interval)
722
+
723
+ return output
724
+
725
+ def query(self, command: str, **kwargs) -> str:
726
+ timeout = kwargs.pop('timeout', None)
727
+ interval = kwargs.pop('interval', 0.1)
728
+
729
+ _ = self.execute([f'r=$({command})', 'echo "${r}"'], timeout=timeout, interval=interval)
730
+ output = _[0]
731
+
732
+ if len(output) > 1:
733
+ LOGGER.warning(f'Multi-line output received! {output}')
734
+ elif len(output) == 0:
735
+ LOGGER.warning(f'No output received!')
736
+
737
+ content = output[-1]
738
+ return content
739
+
740
+ def duplicate(self):
741
+ new_shell = self.__class__(
742
+ encoding=self.encoding,
743
+ logger=self.logger,
744
+ strip_ansi=self.strip_ansi,
745
+ external_fg=self.external_fg,
746
+ mode=self.mode,
747
+ cols=self.cols,
748
+ rows=self.rows
749
+ )
750
+ new_shell.attach(self._command)
751
+ return new_shell
752
+
753
+ def terminate(self):
754
+ if self.mode == 'cmd':
755
+ del self.process
756
+ self.process = None
757
+ elif self.mode == 'ps':
758
+ del self.process
759
+ self.process = None
760
+ elif self.mode == 'paramiko':
761
+ self.stdin.close()
762
+ self.process.close()
763
+ else:
764
+ self.process.terminate()
765
+ self.process = None
766
+ self.is_running = False
767
+
768
+ def disconnect(self):
769
+ self.write("exit")
770
+
771
+
772
+ class ShellTransfer(object):
773
+ def __init__(self, **kwargs):
774
+
775
+ if 'shell' in kwargs:
776
+ self.shell = kwargs.pop('shell')
777
+ else:
778
+ self.shell = InteractiveShell()
779
+ host = kwargs.pop('host')
780
+ user = kwargs.pop('user')
781
+ password = kwargs.pop('password')
782
+ self.shell.attach_remote(host=host, user=user, password=password)
783
+
784
+ def push(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
785
+ mode = kwargs.pop('mode', None)
786
+
787
+ if not mode:
788
+ try:
789
+ self._push_text(local_path, remote_path, **kwargs)
790
+ except Exception as _:
791
+ self._push_hex(local_path, remote_path, **kwargs)
792
+ elif mode == 'text':
793
+ self._push_text(local_path=local_path, remote_path=remote_path, **kwargs)
794
+ elif mode == 'hex':
795
+ self._push_hex(local_path=local_path, remote_path=remote_path, **kwargs)
796
+ elif mode == 'sftp':
797
+ self._push_sftp(local_path=local_path, remote_path=remote_path, **kwargs)
798
+ else:
799
+ raise ValueError(f'Invalid push mode {mode}')
800
+
801
+ def _push_text(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
802
+ chunk_size = kwargs.get('chunk_size', 1024)
803
+ encoding = kwargs.get('encoding', 'utf-8')
804
+ local_md5 = self.md5(local_path)
805
+
806
+ with open(local_path, 'rb') as f:
807
+ file_bytes = f.read()
808
+
809
+ file_size = len(file_bytes)
810
+
811
+ LOGGER.info(f'Push transfer size {file_size:,}, MD5 {local_md5}')
812
+
813
+ self.shell.execute(f'rm {remote_path}')
814
+
815
+ for _ in Progress(range(0, file_size, chunk_size), prompt='Push: '):
816
+ package = file_bytes[_:_ + chunk_size]
817
+ command = f'printf %b "{package.decode(encoding)}" >> "{remote_path}"'
818
+ self.shell.write(command.encode('unicode-escape').decode().replace('\\\\', '\\'))
819
+
820
+ remote_md5 = self.md5_remote(remote_path)
821
+
822
+ if remote_md5 == local_md5:
823
+ LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
824
+ else:
825
+ LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}. Please reduce chunk_size and try again!')
826
+
827
+ def _push_hex(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
828
+ chunk_size = kwargs.get('chunk_size', 1024)
829
+ local_md5 = self.md5(local_path)
830
+
831
+ with open(local_path, 'rb') as f:
832
+ file_bytes = f.read()
833
+
834
+ file_size = len(file_bytes)
835
+
836
+ LOGGER.info(f'Push transfer size {file_size:,}, MD5 {local_md5}')
837
+
838
+ self.shell.execute(f'rm {remote_path}')
839
+
840
+ for _ in Progress(range(0, file_size, chunk_size), prompt='Push: '):
841
+ package = file_bytes[_:_ + chunk_size]
842
+ self.shell.write(f'echo -n {package.hex()} | xxd -r -p >> "{remote_path}"')
843
+
844
+ remote_md5 = self.md5_remote(remote_path)
845
+
846
+ if remote_md5 == local_md5:
847
+ LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
848
+ else:
849
+ LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}. Please reduce chunk_size and try again!')
850
+
851
+ def _push_sftp(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
852
+ import paramiko
853
+ progress = None
854
+
855
+ def sftp_callback(transferred_bytes, total_bytes):
856
+ nonlocal progress
857
+
858
+ if progress is None:
859
+ progress = Progress(tasks=total_bytes)
860
+
861
+ progress.done_tasks = transferred_bytes
862
+ progress.output()
863
+
864
+ if self.shell.mode == 'paramiko':
865
+ ssh_client: paramiko.SSHClient = self.shell.process
866
+ sftp_client = ssh_client.open_sftp()
867
+ sftp_client.put(remotepath=str(remote_path), localpath=str(local_path), callback=sftp_callback)
868
+ else:
869
+ raise NotImplementedError(f'sftp mode not available in {self.shell.mode} mode')
870
+
871
+ local_md5 = self.md5(local_path)
872
+ remote_md5 = self.md5_remote(remote_path)
873
+
874
+ if remote_md5 == local_md5:
875
+ LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
876
+ else:
877
+ LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}.')
878
+
879
+ def _pull_thread(self, remote_path: str | pathlib.Path, start_bytes: int, chunk_size: int, free_pool, occupied_pool):
880
+ command = f'xxd -s {start_bytes} -l {chunk_size} -c {chunk_size} -p "{remote_path}"'
881
+ result = None
882
+ i = 0
883
+ try:
884
+ new_shell = free_pool.pop(0)
885
+ except:
886
+ new_shell = self.shell.duplicate()
887
+
888
+ occupied_pool.append(new_shell)
889
+
890
+ while i < 10:
891
+ try:
892
+ hex_str = new_shell.query(command, timeout=6)
893
+ result = bytes.fromhex(hex_str)
894
+ break
895
+ except:
896
+ pass
897
+
898
+ i += 1
899
+
900
+ if i == 10:
901
+ new_shell.terminate()
902
+ else:
903
+ free_pool.append(new_shell)
904
+ occupied_pool.remove(new_shell)
905
+ return result
906
+
907
+ def pull(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
908
+ mode = kwargs.pop('mode', None)
909
+
910
+ if not mode or mode == 'hex':
911
+ self._pull_hex(local_path, remote_path, **kwargs)
912
+ elif mode == 'sftp':
913
+ self._pull_sftp(local_path=local_path, remote_path=remote_path, **kwargs)
914
+ else:
915
+ raise ValueError(f'Invalid push mode {mode}')
916
+
917
+ def _pull_hex(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
918
+ chunk_size = kwargs.pop('chunk_size', 1024)
919
+ remote_md5 = self.md5_remote(remote_path)
920
+ file_size = int(self.shell.query(f'wc -c < "{remote_path}"'))
921
+ hex_str = ''
922
+
923
+ LOGGER.info(f'Pull transfer size {file_size:,}, MD5 {remote_md5}')
924
+
925
+ for _ in Progress(range(0, file_size, chunk_size)):
926
+ hex_str += self.shell.query(f'xxd -s {_} -l {chunk_size} -c {chunk_size} -p "{remote_path}"')
927
+
928
+ bytes_data = bytes.fromhex(hex_str)
929
+ with open(local_path, 'wb') as f:
930
+ f.write(bytes_data)
931
+
932
+ local_md5 = self.md5(local_path)
933
+
934
+ if remote_md5 == local_md5:
935
+ LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
936
+ else:
937
+ LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}. Please reduce chunk_size and try again!')
938
+
939
+ def _pull_sftp(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
940
+ import paramiko
941
+ progress = None
942
+
943
+ def sftp_callback(transferred_bytes, total_bytes):
944
+ nonlocal progress
945
+
946
+ if progress is None:
947
+ progress = Progress(tasks=total_bytes)
948
+
949
+ progress.done_tasks = transferred_bytes
950
+ progress.output()
951
+
952
+ if self.shell.mode == 'paramiko':
953
+ ssh_client: paramiko.SSHClient = self.shell.process
954
+ sftp_client = ssh_client.open_sftp()
955
+ sftp_client.get(remotepath=str(remote_path), localpath=str(local_path), callback=sftp_callback)
956
+ else:
957
+ raise NotImplementedError(f'sftp mode not available in {self.shell.mode} mode')
958
+
959
+ local_md5 = self.md5(local_path)
960
+ remote_md5 = self.md5_remote(remote_path)
961
+
962
+ if remote_md5 == local_md5:
963
+ LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
964
+ else:
965
+ LOGGER.error(f'Fail transfer failed! local md5 {local_md5}, remote md5 {remote_md5}')
966
+
967
def pull_multi_threads(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """
    Download ``remote_path`` in parallel chunks, each fetched by ``self._pull_thread``.

    The file is split into ``chunk_size`` pieces submitted to a thread pool. Results are
    reassembled in offset order and written to ``local_path``, and the MD5 is verified and
    logged (a mismatch is logged, not raised).

    Args:
        local_path: Destination file; it is overwritten.
        remote_path: File to read on the remote host.
        **kwargs: ``chunk_size`` (int, default 1024) and ``workers`` (int, default 8,
            the thread-pool size).
    """
    import concurrent.futures

    chunk_size = kwargs.pop('chunk_size', 1024)
    workers = kwargs.pop('workers', 8)
    remote_md5 = self.md5_remote(remote_path)
    file_size = int(self.shell.query(f'wc -c < "{remote_path}"'))
    tasks = {}
    result = {}
    # Shared lists handed to every ``_pull_thread`` call; entries of free_pool expose ``terminate()``.
    free_pool, occupied_pool = [], []

    LOGGER.info(f'Pull transfer size {file_size:,}, MD5 {remote_md5}')

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers)
    try:
        for start_bytes in range(0, file_size, chunk_size):
            task = executor.submit(self._pull_thread, remote_path, start_bytes, chunk_size, free_pool, occupied_pool)
            tasks[task] = start_bytes

        progress = Progress(tasks=len(tasks))
        for task in concurrent.futures.as_completed(tasks):
            result[tasks[task]] = task.result()
            progress.done_tasks += 1
            progress.prompt = f'pulling with {len(free_pool) + len(occupied_pool)} workers... '
            progress.output()
    finally:
        # Drop still-queued chunks on failure and wait for in-flight ones, so that no thread
        # is using a pooled worker when the workers are terminated below.
        executor.shutdown(wait=True, cancel_futures=True)
        for worker in free_pool:
            worker.terminate()

    # Join once in offset order; repeated ``bytes +=`` is quadratic.
    bytes_data = b''.join(result[offset] for offset in sorted(result))

    with open(local_path, 'wb') as f:
        f.write(bytes_data)

    local_md5 = self.md5(local_path)

    if remote_md5 == local_md5:
        LOGGER.info(f'Transfer complete! MD5 = {remote_md5} match!')
    else:
        LOGGER.error(f'File transfer failed! local md5 {local_md5}, remote md5 {remote_md5}. Please reduce chunk_size and try again!')
1008
+
1009
def monitor(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path, **kwargs):
    """
    Follow a growing remote file and append its new bytes to ``local_path``; never returns.

    The remote size is polled with ``wc -c`` every ``interval`` seconds. Bytes between the last
    seen size and the current size are fetched via ``xxd -p`` and appended locally.

    Args:
        local_path: Local mirror file. It is overwritten first when ``fetch_all`` is set,
            otherwise appended to.
        remote_path: File to follow on the remote host.
        **kwargs: ``chunk_size`` (int, default 1024) is the number of bytes per ``xxd`` call;
            ``interval`` (default 1) is the seconds between polls; ``fetch_all`` (default False)
            downloads the existing content before following.
    """
    chunk_size = kwargs.pop('chunk_size', 1024)
    interval = kwargs.pop('interval', 1)
    fetch_all = kwargs.pop('fetch_all', False)

    def fetch_range(start: int, end: int, prompt: str) -> bytes:
        # Fetch bytes [start, end) of the remote file as hex chunks and decode them once.
        hex_chunks = []
        for offset in Progress(range(start, end, chunk_size), prompt=prompt):
            # Cap the read at ``end``: the file may grow while we read, and bytes past ``end``
            # would be fetched again on the next poll, duplicating them locally.
            length = min(chunk_size, end - offset)
            hex_chunks.append(self.shell.query(f'xxd -s {offset} -l {length} -c {chunk_size} -p "{remote_path}"'))
        return bytes.fromhex(''.join(hex_chunks))

    last_size = int(self.shell.query(f'wc -c < "{remote_path}"'))

    if fetch_all:
        bytes_data = fetch_range(0, last_size, 'fetching... ')
        with open(local_path, 'wb') as f:
            f.write(bytes_data)

    while True:
        current_size = int(self.shell.query(f'wc -c < "{remote_path}"'))

        if current_size > last_size:
            bytes_data = fetch_range(last_size, current_size, 'updating... ')
            with open(local_path, 'ab') as f:
                f.write(bytes_data)

        # NOTE(review): if the remote file shrinks (truncation / rotation) only the offset is reset;
        # the local copy keeps its old content.
        last_size = current_size
        # Sleep on every poll, including polls that found nothing new, to avoid busy-polling the shell.
        time.sleep(interval)
1042
+
1043
@classmethod
def md5(cls, file_path) -> str:
    """
    Return the hexadecimal MD5 digest of a local file.

    The file is read in 4096-byte blocks, so it is never loaded into memory all at once.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while block := stream.read(4096):
            digest.update(block)
    return digest.hexdigest()
1050
+
1051
def md5_remote(self, remote_path) -> str:
    """
    Return the MD5 digest of ``remote_path`` as computed by ``md5sum`` on the remote host.

    The file is fed to ``md5sum`` on stdin rather than by name. With a file-name argument, GNU
    ``md5sum`` starts its output line with a backslash when the name contains a backslash or a
    newline, and that backslash would then be returned as part of the digest by ``cut``.
    """
    return self.shell.query(command=f'md5sum < "{remote_path}" | cut -f 1 -d " "')
1053
+
1054
def terminate(self):
    """Terminate the underlying shell by delegating to ``self.shell.terminate()``."""
    shell = self.shell
    shell.terminate()
1056
+
1057
+
1058
def count_ordinal(n: int) -> str:
    """
    Convert an integer into its ordinal representation::

        count_ordinal(0) => '0th'
        count_ordinal(3) => '3rd'
        count_ordinal(122) => '122nd'
        count_ordinal(213) => '213th'
        count_ordinal(-1) => '-1st'

    ``n`` is passed through ``int()`` first, so e.g. the string ``'7'`` is accepted. The suffix
    is chosen from the absolute value, so negative numbers get the same suffix as their
    positive counterparts.
    """
    n = int(n)
    # Python's % is non-negative for negative operands (-1 % 10 == 9), so use the magnitude.
    magnitude = abs(n)
    if 11 <= magnitude % 100 <= 13:
        # 11th, 12th, 13th (and 111th, 212th, ...) are irregular.
        suffix = 'th'
    else:
        suffix = ['th', 'st', 'nd', 'rd', 'th'][min(magnitude % 10, 4)]
    return str(n) + suffix