man-spider 1.1.1__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff compares publicly available package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- man_spider/lib/errors.py +8 -4
- man_spider/lib/file.py +16 -19
- man_spider/lib/logger.py +27 -32
- man_spider/lib/parser/__init__.py +1 -1
- man_spider/lib/parser/parser.py +102 -57
- man_spider/lib/processpool.py +24 -31
- man_spider/lib/smb.py +71 -63
- man_spider/lib/spider.py +69 -70
- man_spider/lib/spiderling.py +188 -141
- man_spider/lib/util.py +95 -29
- man_spider/manspider.py +170 -55
- {man_spider-1.1.1.dist-info → man_spider-2.0.0.dist-info}/METADATA +101 -44
- man_spider-2.0.0.dist-info/RECORD +18 -0
- {man_spider-1.1.1.dist-info → man_spider-2.0.0.dist-info}/WHEEL +1 -1
- man_spider-2.0.0.dist-info/entry_points.txt +2 -0
- man_spider-1.1.1.dist-info/RECORD +0 -18
- man_spider-1.1.1.dist-info/entry_points.txt +0 -3
- {man_spider-1.1.1.dist-info → man_spider-2.0.0.dist-info/licenses}/LICENSE +0 -0
man_spider/lib/errors.py
CHANGED
@@ -5,21 +5,25 @@ from impacket.smb import SessionError, UnsupportedFeature
 from impacket.smbconnection import SessionError as CSessionError

 # set up logging
-log = logging.getLogger(…
+log = logging.getLogger("manspider")


 class MANSPIDERError(Exception):
     pass

+
 class FileRetrievalError(MANSPIDERError):
     pass

+
 class ShareListError(MANSPIDERError):
     pass

+
 class FileListError(MANSPIDERError):
     pass

+
 class LogonFailure(MANSPIDERError):
     pass

@@ -41,9 +45,9 @@ impacket_errors = (


 def impacket_error(e):
-    …
+    """
     Tries to format impacket exceptions nicely
-    …
+    """
     if type(e) in (SessionError, CSessionError):
         try:
             error_str = e.getErrorString()[0]
@@ -51,5 +55,5 @@ def impacket_error(e):
         except (IndexError,):
             pass
     if not e.args:
-        e.args = (…
+        e.args = ("",)
     return e
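The impacket_error() helper normalizes impacket's SessionError variants into readable messages. A minimal sketch (not part of the diff) of how such a helper is typically used around an SMB call; safe_list_shares and the smb_client wrapper are illustrative names:

from man_spider.lib.errors import impacket_error, impacket_errors, ShareListError

def safe_list_shares(smb_client):
    # smb_client is assumed to expose an authenticated impacket SMBConnection as .conn
    try:
        return smb_client.conn.listShares()
    except impacket_errors as e:
        # impacket_error() rewrites e.args into a readable string and returns the exception
        raise ShareListError(f"share listing failed: {impacket_error(e)}")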
man_spider/lib/file.py
CHANGED
@@ -1,14 +1,14 @@
-import io
-from .util import *
-from .errors import *
 from pathlib import Path

+from man_spider.lib.util import *
+from man_spider.lib.errors import *

-…
-…
-…
-…
-…
+
+class RemoteFile:
+    """
+    Represents a file on an SMB share
+    Passed from a spiderling up to its parent spider
+    """

     def __init__(self, name, share, target, size=0):

@@ -19,23 +19,21 @@
         self.smb_client = None

         file_suffix = Path(name).suffix.lower()
-        self.tmp_filename = Path(…
-…
+        self.tmp_filename = Path("/tmp/.manspider") / (random_string(15) + file_suffix)

     def get(self, smb_client=None):
-        …
+        """
         Downloads file to self.tmp_filename

         NOTE: SMBConnection() can't be passed through a multiprocessing queue
         This means that smb_client must be set after the file arrives at Spider()
-        …
+        """

         if smb_client is None and self.smb_client is None:
-            raise FileRetrievalError(…
-…
-        #memfile = io.BytesIO()
-        with open(str(self.tmp_filename), 'wb') as f:
+            raise FileRetrievalError("Please specify smb_client")

+        # memfile = io.BytesIO()
+        with open(str(self.tmp_filename), "wb") as f:
             try:
                 smb_client.conn.getFile(self.share, self.name, f.write)
             except Exception as e:
@@ -43,9 +41,8 @@
                 raise FileRetrievalError(f'Error retrieving file "{str(self)}": {str(e)[:150]}')

         # reset cursor back to zero so .read() will return the whole file
-        #memfile.seek(0)
-…
+        # memfile.seek(0)

     def __str__(self):

-        return f…
+        return f"{self.target}\\{self.share}\\{self.name}"
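A minimal usage sketch, not from the diff: RemoteFile is built with a share-relative path and downloaded once an authenticated SMB client is available. The target, share, filename, and fetch_one() wrapper are illustrative.

from man_spider.lib.file import RemoteFile
from man_spider.lib.errors import FileRetrievalError

def fetch_one(smb_client):
    # smb_client is assumed to be man-spider's authenticated SMB wrapper (exposes .conn)
    remote_file = RemoteFile("IT\\passwords.xlsx", "C$", "192.0.2.10", size=4096)
    try:
        remote_file.get(smb_client)    # writes to remote_file.tmp_filename under /tmp/.manspider
    except FileRetrievalError as e:
        print(f"download failed: {e}")
        return None
    print(remote_file)                 # 192.0.2.10\C$\IT\passwords.xlsx
    return remote_file.tmp_filename.read_bytes()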
man_spider/lib/logger.py
CHANGED
@@ -11,67 +11,62 @@ from logging.handlers import QueueHandler, QueueListener


 class ColoredFormatter(logging.Formatter):
-…
     color_mapping = {
-…
-…
-…
-…
-…
+        "DEBUG": 69,      # blue
+        "INFO": 118,      # green
+        "WARNING": 208,   # orange
+        "ERROR": 196,     # red
+        "CRITICAL": 196,  # red
     }

     char_mapping = {
-…
-…
-…
-…
-…
+        "DEBUG": "*",
+        "INFO": "+",
+        "WARNING": "-",
+        "ERROR": "!",
+        "CRITICAL": "!!!",
     }

-    prefix = …
-    suffix = …
+    prefix = "\033[1;38;5;"
+    suffix = "\033[0m"

     def __init__(self, pattern):

         super().__init__(pattern)

-…
     def format(self, record):

         colored_record = copy(record)
         levelname = colored_record.levelname
-        levelchar = self.char_mapping.get(levelname, …
-        seq = self.color_mapping.get(levelname, 15)
-        colored_levelname = f…
+        levelchar = self.char_mapping.get(levelname, "+")
+        seq = self.color_mapping.get(levelname, 15)  # default white
+        colored_levelname = f"{self.prefix}{seq}m[{levelchar}]{self.suffix}"
         colored_record.levelname = colored_levelname

         return logging.Formatter.format(self, colored_record)

-…
     @classmethod
     def green(cls, s):

         return cls.color(s)

-…
     @classmethod
     def red(cls, s):

-        return cls.color(s, level=…
-…
+        return cls.color(s, level="ERROR")

     @classmethod
-    def color(cls, s, level=…
+    def color(cls, s, level="INFO"):

         color = cls.color_mapping.get(level)
-        return f…
-…
+        return f"{cls.prefix}{color}m{s}{cls.suffix}"


 class CustomQueueListener(QueueListener):
-    …
+    """
     Ignore errors in the monitor thread that result from a race condition when the program exits
-    …
+    """
+
     def _monitor(self):
         try:
             super()._monitor()
@@ -83,18 +78,18 @@ class CustomQueueListener(QueueListener):

 console = logging.StreamHandler(stdout)
 # tell the handler to use this format
-console.setFormatter(ColoredFormatter(…
+console.setFormatter(ColoredFormatter("%(levelname)s %(message)s"))

 ### LOG TO FILE ###

 log_queue = Queue()
 listener = CustomQueueListener(log_queue, console)
 sender = QueueHandler(log_queue)
-logging.getLogger(…
+logging.getLogger("manspider").handlers = [sender]

-logdir = Path.home() / …
+logdir = Path.home() / ".manspider" / "logs"
 logdir.mkdir(parents=True, exist_ok=True)
-logfile = f…
+logfile = f"manspider_{datetime.now().strftime('%m-%d-%Y')}.log"
 handler = logging.FileHandler(str(logdir / logfile))
-handler.setFormatter(logging.Formatter(…
-logging.getLogger(…
+handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
+logging.getLogger("manspider").addHandler(handler)
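A minimal sketch of how other man-spider modules pick this configuration up, assuming the module-level handler wiring above runs at import time as shown; the child logger name is illustrative.

import logging

import man_spider.lib.logger  # imported for its side effects: attaches the queue and file handlers
from man_spider.lib.logger import ColoredFormatter

log = logging.getLogger("manspider.example")  # children of "manspider" inherit the handlers
log.info("log files land in ~/.manspider/logs/manspider_<MM-DD-YYYY>.log")
log.warning(ColoredFormatter.red("highlighted in red on the console"))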
man_spider/lib/parser/__init__.py
CHANGED

@@ -1 +1 @@
-from .parser import *
+from .parser import *
man_spider/lib/parser/parser.py
CHANGED
@@ -1,39 +1,80 @@
 import re
-import magic
 import logging
 from time import sleep
 import subprocess as sp
-from …
+from pathlib import Path
+from kreuzberg import extract_file_sync
+from charset_normalizer import from_path

 from man_spider.lib.util import *
 from man_spider.lib.logger import *

-log = logging.getLogger(…
+log = logging.getLogger("manspider.parser")


-…
+def is_text_file(filepath):
+    """Detect if file is plain text using charset-normalizer."""
+    result = from_path(filepath)
+    best = result.best()
+    # Only consider it a text file if we have high confidence
+    # and the encoding is detected (not binary)
+    return best is not None and best.encoding is not None
+
+
+def extract_text_file(filepath):
+    """Extract text from plain text file, auto-detecting encoding."""
+    result = from_path(filepath)
+    best = result.best()
+    return str(best) if best else None
+
+
+def extract_strings_from_binary(filepath, min_length=4):
+    """
+    Extract printable ASCII strings from a binary file.
+    Similar to the Unix 'strings' command.
+    """
+    import string
+
+    printable = set(string.printable) - set("\x0b\x0c")  # Exclude vertical tab and form feed

-…
-…
-        # PNG, JPEG, etc.
-        # 'image data',
-        # ZIP, GZ, etc.
-        'archive data',
-        # encrypted data
-        'encrypted'
-    ]
+    with open(filepath, "rb") as f:
+        data = f.read()

+    result = []
+    current = []
+    for byte in data:
+        char = chr(byte) if byte < 128 else None
+        if char and char in printable:
+            current.append(char)
+        else:
+            if len(current) >= min_length:
+                result.append("".join(current))
+            current = []
+    if len(current) >= min_length:
+        result.append("".join(current))
+
+    return "\n".join(result)
+
+
+class FileParser:
+    # don't parse files with these extensions
+    extension_blacklist = {
+        # Archive formats
+        '.zip', '.gz', '.tar', '.bz2', '.7z', '.rar', '.xz', '.tgz', '.tbz2',
+        # Encrypted/protected formats
+        '.enc', '.gpg', '.pgp', '.asc',
+        # Compiled/binary formats that are rarely useful to parse
+        '.exe', '.dll', '.so', '.dylib',
+    }

     def __init__(self, filters, quiet=False):
         self.init_content_filters(filters)
-        self.extractor = Extractor()
         self.quiet = quiet

-…
     def init_content_filters(self, file_content):
-        …
+        """
         Get ready to search by file content
-        …
+        """

         # strings to look for in file content
         # if empty, content is ignored
@@ -46,41 +87,36 @@ class FileParser:
             sleep(1)
         if self.content_filters:
             content_filter_str = '"' + '", "'.join([f.pattern for f in self.content_filters]) + '"'
-            log.info(f…
-…
-…
+            log.info(f"Searching by file content: {content_filter_str}")

     def match(self, file_content):
-        …
+        """
         Finds all regex matches in file content
-        …
+        """

         for _filter in self.content_filters:
             for match in _filter.finditer(file_content):
                 # ( filter, (match_start_index, match_end_index) )
                 yield (_filter, match.span())

-…
     def match_magic(self, file):
-        …
+        """
         Returns True if the file isn't of a blacklisted file type
-        …
+        """
+        file_path = Path(file)
+        extension = file_path.suffix.lower()

-…
-…
-…
-            if keyword.lower() in magic_type:
-                log.debug(f'Not parsing {file}: blacklisted magic type: "{keyword}"')
-                return False
+        if extension in self.extension_blacklist:
+            log.debug(f'Not parsing {file}: blacklisted extension: "{extension}"')
+            return False

         return True

-…
     def grep(self, content, pattern):

         if not self.quiet:
             try:
-                …
+                """
                 GREP(1)
                     -E, --extended-regexp
                         Interpret PATTERN as an extended regular expression
@@ -90,11 +126,9 @@ class FileParser:
                         Process a binary file as if it were text
                     -m NUM, --max-count=NUM
                         Stop reading a file after NUM matching lines
-                …
+                """
                 grep_process = sp.Popen(
-                    […
-                    stdin=sp.PIPE,
-                    stdout=sp.PIPE
+                    ["grep", "-Eiam", "5", "--color=always", pattern], stdin=sp.PIPE, stdout=sp.PIPE
                 )
                 grep_output = grep_process.communicate(content)[0]
                 for line in grep_output.splitlines():
@@ -102,52 +136,63 @@ class FileParser:
             except (sp.SubprocessError, OSError, IndexError):
                 pass

-…
     def parse_file(self, file, pretty_filename=None):
-        …
+        """
         Parse a file on the local filesystem
-        …
+        """

         if pretty_filename is None:
             pretty_filename = str(file)

-        log.debug(f…
+        log.debug(f"Parsing file: {pretty_filename}")

         matches = dict()

         try:
-…
-            matches = self.textract(file, pretty_filename=pretty_filename)
+            matches = self.extract_text(file, pretty_filename=pretty_filename)

         except Exception as e:
-            #except (BadZipFile, textract.exceptions.CommandLineError) as e:
             if log.level <= logging.DEBUG:
-                log.warning(f…
+                log.warning(f"Error extracting text from {pretty_filename}: {e}")
             else:
-                log.warning(f…
-…
-        return matches
+                log.warning(f"Error extracting text from {pretty_filename} (-v to debug)")

+        return matches

-    def …
-    …
-        Extracts text from a file
-        Uses …
-    …
+    def extract_text(self, file, pretty_filename):
+        """
+        Extracts text from a file.
+        Uses charset-normalizer for plain text files (handles UTF-16, etc.)
+        Falls back to kreuzberg for binary formats (docx, pdf, xlsx, etc.)
+        """

         matches = dict()

-        suffix = Path(str(file)).suffix.lower()
-
         # blacklist certain mime types
         if not self.match_magic(file):
             return matches

-…
+        # Try charset-normalizer first for text files (handles UTF-16, etc.)
+        if is_text_file(str(file)):
+            text_content = extract_text_file(str(file))
+            log.debug(f"Extracted text from {pretty_filename} using charset-normalizer")
+        else:
+            # Try kreuzberg for document formats (docx, pdf, xlsx, etc.)
+            try:
+                result = extract_file_sync(str(file))
+                text_content = result.content
+            except Exception as e:
+                # Kreuzberg doesn't support this file type, try extracting raw strings
+                log.debug(f"Kreuzberg failed for {pretty_filename}: {e}, trying string extraction")
+                text_content = extract_strings_from_binary(str(file))
+
+        # Guard against None content
+        if text_content is None:
+            return matches

         # try to convert to UTF-8 for grep-friendliness
         try:
-            binary_content = text_content.encode(…
+            binary_content = text_content.encode("utf-8", errors="ignore")
         except Exception:
             pass

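A standalone sketch of the extraction chain these changes introduce: charset-normalizer for plain text, kreuzberg for document formats, and a strings(1)-style scan as the last resort. The printable-character test here is simplified relative to the diff, and the example path in the final comment is illustrative.

from charset_normalizer import from_path
from kreuzberg import extract_file_sync

def extract_any(path):
    best = from_path(path).best()
    if best is not None and best.encoding is not None:
        return str(best)                         # decoded text (handles UTF-16, latin-1, ...)
    try:
        return extract_file_sync(path).content   # document formats: docx, pdf, xlsx, ...
    except Exception:
        # last resort: pull printable ASCII runs out of the raw bytes
        with open(path, "rb") as f:
            data = f.read()
        runs, current = [], []
        for b in data:
            if 32 <= b < 127:
                current.append(chr(b))
            else:
                if len(current) >= 4:
                    runs.append("".join(current))
                current = []
        if len(current) >= 4:
            runs.append("".join(current))
        return "\n".join(runs)

# e.g. text = extract_any("/tmp/suspicious.docx")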
man_spider/lib/processpool.py
CHANGED
@@ -1,28 +1,27 @@
 import logging
-import traceback
 from time import sleep
 import multiprocessing as mp
-from queue import Empty
+from queue import Empty
+from traceback import format_exc


 # set up logging
-log = logging.getLogger(…
-…
+log = logging.getLogger("manspider.processpool")


 class ProcessPool:
-    …
+    """
     usage:
         with ProcessPool(2) as pool:
             for i in pool.map(target, iterable):
                 yield i
-    …
+    """

-    def __init__(self, processes=None, daemon=False, name=…
+    def __init__(self, processes=None, daemon=False, name=""):

-        self.name = …
+        self.name = "ProcessPool"
         if name:
-            self.name += f…
+            self.name += f"-{name}"

         if processes is None:
             processes = mp.cpu_count()
@@ -38,16 +37,12 @@ class ProcessPool:
         # make the result queue
         self.result_queue = mp.Manager().Queue()

-…
     def map(self, func, iterable, args=(), kwargs={}):

         # loop until we're out of work
         for entry in iterable:
-…
             try:
-…
                 while 1:
-…
                     for result in self.results:
                         yield result

@@ -55,11 +50,15 @@ class ProcessPool:
                     for i in range(len(self.pool)):
                         process = self.pool[i]
                         if process is None or not process.is_alive():
-                            self.pool[i] = mp.Process(…
-…
+                            self.pool[i] = mp.Process(
+                                target=self.execute,
+                                args=(func, self.result_queue, (entry,) + args),
+                                kwargs=kwargs,
+                                daemon=self.daemon,
+                            )
                             self.pool[i].start()
                             self.started_counter += 1
-                            log.debug(f…
+                            log.debug(f"{self.name}: {self.started_counter:,} processes started")
                             # success, move on to next
                             assert False

@@ -67,26 +66,24 @@ class ProcessPool:
                         yield result

                     # prevent unnecessary CPU usage
-                    sleep(.1)
+                    sleep(0.1)

             except AssertionError:
                 continue

         # wait for processes to finish
         while 1:
-…
             finished_threads = [p is None or not p.is_alive() for p in self.pool]
             if all(finished_threads):
                 self.finished_counter += len([p for p in self.pool if p is not None and not p.is_alive()])
                 break
             else:
-                log.debug(f…
+                log.debug(f"{self.name}: Waiting for {finished_threads.count(False):,} threads to finish")
                 sleep(1)

         for result in self.results:
             yield result

-…
     @property
     def results(self):

@@ -94,27 +91,25 @@ class ProcessPool:
             try:
                 result = self.result_queue.get_nowait()
                 self.finished_counter += 1
-                log.debug(f…
+                log.debug(f"{self.name}: {self.finished_counter:,} processes finished")
                 yield result
             except Empty:
-                sleep(.1)
+                sleep(0.1)
                 break

-…
     @staticmethod
     def execute(func, result_queue, args=(), kwargs={}):
-        …
+        """
         Executes given function and places return value in result queue
-        …
+        """

         try:
             result_queue.put(func(*args, **kwargs))
         except Exception as e:
             if type(e) not in [FileNotFoundError]:
                 log.critical(format_exc())
-        except KeyboardInterrupt…
-            log.critical(…
-…
+        except KeyboardInterrupt:
+            log.critical("ProcessPool Interrupted")

     @staticmethod
     def _close_queue(q):
@@ -126,15 +121,13 @@ class ProcessPool:
                 break
         q.close()

-…
     def __enter__(self):

         return self

-…
     def __exit__(self, exception_type, exception_value, traceback):

         try:
             self._close_queue(self.result_queue)
         except Exception:
-            pass
+            pass