traffic-taffy 0.8.5__py3-none-any.whl → 0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- traffic_taffy/__init__.py +1 -1
- traffic_taffy/algorithms/__init__.py +14 -7
- traffic_taffy/algorithms/comparecorrelation.py +164 -0
- traffic_taffy/algorithms/comparecorrelationchanges.py +210 -0
- traffic_taffy/algorithms/compareseries.py +117 -0
- traffic_taffy/algorithms/compareslices.py +116 -0
- traffic_taffy/algorithms/statistical.py +9 -9
- traffic_taffy/compare.py +149 -159
- traffic_taffy/comparison.py +18 -4
- traffic_taffy/config.py +133 -0
- traffic_taffy/dissection.py +78 -6
- traffic_taffy/dissectmany.py +26 -16
- traffic_taffy/dissector.py +189 -77
- traffic_taffy/dissector_engine/scapy.py +41 -8
- traffic_taffy/graph.py +54 -53
- traffic_taffy/graphdata.py +13 -2
- traffic_taffy/hooks/ip2asn.py +20 -7
- traffic_taffy/hooks/labels.py +45 -0
- traffic_taffy/hooks/psl.py +21 -3
- traffic_taffy/output/__init__.py +8 -48
- traffic_taffy/output/console.py +37 -25
- traffic_taffy/output/fsdb.py +24 -18
- traffic_taffy/reports/__init__.py +5 -0
- traffic_taffy/reports/compareslicesreport.py +85 -0
- traffic_taffy/reports/correlationchangereport.py +54 -0
- traffic_taffy/reports/correlationreport.py +42 -0
- traffic_taffy/taffy_config.py +44 -0
- traffic_taffy/tests/test_compare_results.py +22 -7
- traffic_taffy/tests/test_config.py +149 -0
- traffic_taffy/tests/test_global_config.py +33 -0
- traffic_taffy/tests/test_normalize.py +1 -0
- traffic_taffy/tests/test_pcap_dissector.py +12 -2
- traffic_taffy/tests/test_pcap_splitter.py +21 -10
- traffic_taffy/tools/cache_info.py +3 -2
- traffic_taffy/tools/compare.py +32 -24
- traffic_taffy/tools/config.py +83 -0
- traffic_taffy/tools/dissect.py +51 -59
- traffic_taffy/tools/explore.py +5 -4
- traffic_taffy/tools/export.py +28 -17
- traffic_taffy/tools/graph.py +25 -27
- {traffic_taffy-0.8.5.dist-info → traffic_taffy-0.9.dist-info}/METADATA +4 -1
- traffic_taffy-0.9.dist-info/RECORD +56 -0
- {traffic_taffy-0.8.5.dist-info → traffic_taffy-0.9.dist-info}/entry_points.txt +1 -0
- traffic_taffy/report.py +0 -12
- traffic_taffy-0.8.5.dist-info/RECORD +0 -43
- {traffic_taffy-0.8.5.dist-info → traffic_taffy-0.9.dist-info}/WHEEL +0 -0
- {traffic_taffy-0.8.5.dist-info → traffic_taffy-0.9.dist-info}/licenses/LICENSE.txt +0 -0
traffic_taffy/compare.py
CHANGED
@@ -1,61 +1,105 @@
|
|
1
1
|
"""The primary statistical packet comparison engine."""
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
|
-
from
|
5
|
-
from
|
6
|
-
from datetime import datetime
|
7
|
-
import datetime as dt
|
8
|
-
import itertools
|
4
|
+
from typing import List, TYPE_CHECKING, Any
|
5
|
+
from logging import error
|
9
6
|
|
10
7
|
if TYPE_CHECKING:
|
11
|
-
from
|
8
|
+
from traffic_taffy.dissection import Dissection
|
9
|
+
from traffic_taffy.comparison import Comparison
|
10
|
+
from argparse_with_config import ArgumentParserWithConfig
|
12
11
|
|
13
|
-
from traffic_taffy.comparison import Comparison
|
14
12
|
from traffic_taffy.dissectmany import PCAPDissectMany
|
15
|
-
from traffic_taffy.dissector import PCAPDissectorLevel
|
16
|
-
from traffic_taffy.dissection import Dissection
|
17
13
|
from traffic_taffy.algorithms.statistical import ComparisonStatistical
|
14
|
+
from traffic_taffy.algorithms.comparecorrelation import CompareCorrelation
|
15
|
+
from traffic_taffy.algorithms.comparecorrelationchanges import CompareCorrelationChanges
|
16
|
+
from traffic_taffy.taffy_config import TaffyConfig, taffy_default
|
17
|
+
from traffic_taffy.dissector import TTD_CFG, TTL_CFG
|
18
|
+
|
19
|
+
|
20
|
+
class TTC_CFG:
|
21
|
+
KEY_COMPARE: str = "compare"
|
22
|
+
ONLY_POSITIVE: str = "only_positive"
|
23
|
+
ONLY_NEGATIVE: str = "only_negative"
|
24
|
+
PRINT_THRESHOLD: str = "print_threshold"
|
25
|
+
TOP_RECORDS: str = "top_records"
|
26
|
+
REVERSE_SORT: str = "reverse_sort"
|
27
|
+
SORT_BY: str = "sort_by"
|
28
|
+
ALGORITHM: str = "algorithm"
|
29
|
+
|
30
|
+
|
31
|
+
def compare_default(name: str, value: Any) -> None:
|
32
|
+
taffy_default(TTC_CFG.KEY_COMPARE + "." + name, value)
|
33
|
+
|
34
|
+
|
35
|
+
compare_default(TTC_CFG.ONLY_POSITIVE, False)
|
36
|
+
compare_default(TTC_CFG.ONLY_NEGATIVE, False)
|
37
|
+
compare_default(TTC_CFG.PRINT_THRESHOLD, 0.0)
|
38
|
+
compare_default(TTC_CFG.TOP_RECORDS, None)
|
39
|
+
compare_default(TTC_CFG.REVERSE_SORT, False)
|
40
|
+
compare_default(TTC_CFG.SORT_BY, "delta%")
|
41
|
+
compare_default(TTC_CFG.ALGORITHM, "statistical")
|
42
|
+
compare_default(TTC_CFG.PRINT_THRESHOLD, 0.0)
|
18
43
|
|
19
44
|
|
20
45
|
class PcapCompare:
|
21
46
|
"""Take a set of PCAPs to then perform various comparisons upon."""
|
22
47
|
|
23
|
-
REPORT_VERSION: int = 2
|
24
|
-
|
25
48
|
def __init__(
|
26
49
|
self,
|
27
50
|
pcap_files: List[str],
|
28
|
-
|
29
|
-
deep: bool = True,
|
30
|
-
pcap_filter: str | None = None,
|
31
|
-
cache_results: bool = False,
|
32
|
-
cache_file_suffix: str = "taffy",
|
33
|
-
bin_size: int | None = None,
|
34
|
-
dissection_level: PCAPDissectorLevel = PCAPDissectorLevel.COUNT_ONLY,
|
35
|
-
between_times: List[int] | None = None,
|
36
|
-
ignore_list: List[str] | None = None,
|
37
|
-
layers: List[str] | None = None,
|
38
|
-
force_load: bool = False,
|
39
|
-
force_overwrite: bool = False,
|
40
|
-
merge_files: bool = False,
|
51
|
+
config: TaffyConfig | None = None,
|
41
52
|
) -> None:
|
42
53
|
"""Create a compare object."""
|
43
|
-
self.
|
44
|
-
self.
|
45
|
-
self.
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
self.
|
51
|
-
self.
|
52
|
-
self.
|
53
|
-
self.
|
54
|
-
self.
|
55
|
-
self.
|
56
|
-
self.
|
57
|
-
|
58
|
-
self.
|
54
|
+
self.config = config
|
55
|
+
self._pcap_files = pcap_files
|
56
|
+
if not self.config:
|
57
|
+
config = TaffyConfig()
|
58
|
+
|
59
|
+
dissector_config = config[TTD_CFG.KEY_DISSECTOR]
|
60
|
+
|
61
|
+
self.maximum_count = dissector_config[TTD_CFG.PACKET_COUNT]
|
62
|
+
self.pcap_filter = dissector_config[TTD_CFG.CACHE_PCAP_RESULTS]
|
63
|
+
self.dissection_level = dissector_config[TTD_CFG.DISSECTION_LEVEL]
|
64
|
+
# self.between_times = config[TTC_CFG.BETWEEN_TIMES]
|
65
|
+
self.bin_size = dissector_config[TTD_CFG.BIN_SIZE]
|
66
|
+
self.cache_file_suffix = dissector_config[TTD_CFG.CACHE_FILE_SUFFIX]
|
67
|
+
if self.cache_file_suffix[0] != ".":
|
68
|
+
self.cache_file_suffix = "." + self.cache_file_suffix
|
69
|
+
self.ignore_list = dissector_config[TTD_CFG.IGNORE_LIST]
|
70
|
+
self.layers = dissector_config[TTD_CFG.LAYERS]
|
71
|
+
self.force_overwrite = dissector_config[TTD_CFG.FORCE_OVERWRITE]
|
72
|
+
self.force_load = dissector_config[TTD_CFG.FORCE_LOAD]
|
73
|
+
self.filter_arguments = dissector_config[TTD_CFG.FILTER_ARGUMENTS]
|
74
|
+
self.merge_files = dissector_config[TTD_CFG.MERGE]
|
75
|
+
|
76
|
+
compare_config = config[TTC_CFG.KEY_COMPARE]
|
77
|
+
algorithm = compare_config[TTC_CFG.ALGORITHM]
|
78
|
+
|
79
|
+
algorithm_arguments = {
|
80
|
+
"timestamps": None,
|
81
|
+
"match_string": self.filter_arguments["match_string"],
|
82
|
+
"match_value": self.filter_arguments["match_value"],
|
83
|
+
"minimum_count": self.filter_arguments["minimum_count"],
|
84
|
+
"make_printable": True,
|
85
|
+
"match_expression": self.filter_arguments["match_expression"],
|
86
|
+
}
|
87
|
+
|
88
|
+
if algorithm == "statistical":
|
89
|
+
self.algorithm = ComparisonStatistical(
|
90
|
+
**algorithm_arguments,
|
91
|
+
)
|
92
|
+
elif algorithm == "correlation":
|
93
|
+
self.algorithm = CompareCorrelation(
|
94
|
+
**algorithm_arguments,
|
95
|
+
)
|
96
|
+
elif algorithm == "correlationchanges":
|
97
|
+
self.algorithm = CompareCorrelationChanges(
|
98
|
+
**algorithm_arguments,
|
99
|
+
)
|
100
|
+
else:
|
101
|
+
error(f"unknown algorithm: {algorithm}")
|
102
|
+
raise ValueError()
|
59
103
|
|
60
104
|
@property
|
61
105
|
def pcap_files(self) -> List[str]:
|
@@ -75,145 +119,76 @@ class PcapCompare:
|
|
75
119
|
def reports(self, newvalue: List[dict]) -> None:
|
76
120
|
self._reports = newvalue
|
77
121
|
|
78
|
-
def load_pcaps(self) -> None:
|
122
|
+
def load_pcaps(self, config: TaffyConfig) -> None:
|
79
123
|
"""Load all pcaps into memory and dissect them."""
|
80
|
-
# load the first as a reference
|
124
|
+
# load the first as a reference pcap
|
81
125
|
pdm = PCAPDissectMany(
|
82
126
|
self.pcap_files,
|
83
|
-
|
84
|
-
maximum_count=self.maximum_count,
|
85
|
-
pcap_filter=self.pcap_filter,
|
86
|
-
cache_results=self.cache_results,
|
87
|
-
cache_file_suffix=self.cache_file_suffix,
|
88
|
-
dissector_level=self.dissection_level,
|
89
|
-
ignore_list=self.ignore_list,
|
90
|
-
layers=self.layers,
|
91
|
-
force_load=self.force_load,
|
92
|
-
force_overwrite=self.force_overwrite,
|
93
|
-
merge_files=self.merge_files,
|
127
|
+
config,
|
94
128
|
)
|
95
129
|
return pdm.load_all()
|
96
130
|
|
97
131
|
def compare(self) -> List[Comparison]:
|
98
132
|
"""Compare each pcap as requested."""
|
99
|
-
dissections = self.load_pcaps()
|
133
|
+
dissections = self.load_pcaps(self.config)
|
100
134
|
self.compare_all(dissections)
|
101
135
|
return self.reports
|
102
136
|
|
103
137
|
def compare_all(self, dissections: List[Dissection]) -> List[Comparison]:
|
104
138
|
"""Compare all loaded pcaps."""
|
105
|
-
reports = []
|
106
|
-
|
107
|
-
# hack to figure out if there is at least two instances of a generator
|
108
|
-
# without actually extracting them all
|
109
|
-
# (since it could be memory expensive)
|
110
|
-
reference = next(dissections)
|
111
|
-
other = None
|
112
|
-
multiple = True
|
113
|
-
try:
|
114
|
-
other = next(dissections)
|
115
|
-
dissections = itertools.chain([other], dissections)
|
116
|
-
except Exception as e:
|
117
|
-
print(e)
|
118
|
-
multiple = False
|
119
|
-
|
120
|
-
if multiple:
|
121
|
-
# multiple file comparison
|
122
|
-
for other in dissections:
|
123
|
-
# compare the two global summaries
|
124
|
-
|
125
|
-
report = self.algorithm.compare_dissections(
|
126
|
-
reference.data[0], other.data[0]
|
127
|
-
)
|
128
|
-
report.title = f"{reference.pcap_file} vs {other.pcap_file}"
|
129
|
-
|
130
|
-
reports.append(report)
|
131
|
-
else:
|
132
|
-
# deal with timestamps within a single file
|
133
|
-
reference = reference.data
|
134
|
-
timestamps = list(reference.keys())
|
135
|
-
if len(timestamps) <= 2: # just 0-summary plus a single stamp
|
136
|
-
error(
|
137
|
-
"the requested pcap data was not long enough to compare against itself"
|
138
|
-
)
|
139
|
-
errorstr: str = "not large enough pcap file"
|
140
|
-
raise ValueError(errorstr)
|
141
|
-
debug(
|
142
|
-
f"found {len(timestamps)} timestamps from {timestamps[2]} to {timestamps[-1]}"
|
143
|
-
)
|
144
|
-
|
145
|
-
for timestamp in range(
|
146
|
-
2, len(timestamps)
|
147
|
-
): # second real non-zero timestamp to last
|
148
|
-
time_left = timestamps[timestamp - 1]
|
149
|
-
time_right = timestamps[timestamp]
|
150
|
-
|
151
|
-
# see if we were asked to only use particular time ranges
|
152
|
-
if self.between_times and (
|
153
|
-
time_left < self.between_times[0]
|
154
|
-
or time_right > self.between_times[1]
|
155
|
-
):
|
156
|
-
# debug(f"skipping timestamps {time_left} and {time_right}")
|
157
|
-
continue
|
158
|
-
|
159
|
-
debug(f"comparing timestamps {time_left} and {time_right}")
|
160
139
|
|
161
|
-
|
162
|
-
|
163
|
-
reference[time_right],
|
164
|
-
)
|
165
|
-
|
166
|
-
title_left = datetime.fromtimestamp(time_left, dt.UTC).strftime(
|
167
|
-
"%Y-%m-%d %H:%M:%S"
|
168
|
-
)
|
169
|
-
title_right = datetime.fromtimestamp(time_right, dt.UTC).strftime(
|
170
|
-
"%Y-%m-%d %H:%M:%S"
|
171
|
-
)
|
172
|
-
|
173
|
-
report.title = f"time {title_left} vs time {title_right}"
|
174
|
-
reports.append(report)
|
175
|
-
|
176
|
-
continue
|
177
|
-
|
178
|
-
# takes way too much memory to do it "right"
|
179
|
-
# reports.append(
|
180
|
-
# {
|
181
|
-
# "report": report,
|
182
|
-
# "title": f"time {time_left} vs time {time_right}",
|
183
|
-
# }
|
184
|
-
# )
|
185
|
-
|
186
|
-
self.reports = reports
|
187
|
-
return reports
|
140
|
+
self.reports = self.algorithm.compare_dissections(dissections)
|
141
|
+
return self.reports
|
188
142
|
|
189
143
|
|
190
144
|
def compare_add_parseargs(
|
191
|
-
compare_parser:
|
192
|
-
|
145
|
+
compare_parser: ArgumentParserWithConfig,
|
146
|
+
config: TaffyConfig | None = None,
|
147
|
+
add_subgroup: bool = True,
|
148
|
+
) -> ArgumentParserWithConfig:
|
193
149
|
"""Add common comparison arguments."""
|
150
|
+
|
151
|
+
if not config:
|
152
|
+
config = TaffyConfig()
|
153
|
+
compare_config = config[TTC_CFG.KEY_COMPARE]
|
154
|
+
|
194
155
|
if add_subgroup:
|
195
|
-
compare_parser = compare_parser.add_argument_group(
|
156
|
+
compare_parser = compare_parser.add_argument_group(
|
157
|
+
"Comparison result options", config_path=TTC_CFG.KEY_COMPARE
|
158
|
+
)
|
196
159
|
|
197
160
|
compare_parser.add_argument(
|
198
161
|
"-t",
|
199
162
|
"--print-threshold",
|
200
|
-
default=
|
163
|
+
default=compare_config[TTC_CFG.PRINT_THRESHOLD],
|
164
|
+
config_path=TTC_CFG.PRINT_THRESHOLD,
|
201
165
|
type=float,
|
202
166
|
help="Don't print results with abs(percent) less than this threshold",
|
203
167
|
)
|
204
168
|
|
205
169
|
compare_parser.add_argument(
|
206
|
-
"-P",
|
170
|
+
"-P",
|
171
|
+
"--only-positive",
|
172
|
+
action="store_true",
|
173
|
+
help="Only show positive entries",
|
174
|
+
default=compare_config[TTC_CFG.ONLY_POSITIVE],
|
175
|
+
config_path=TTC_CFG.ONLY_POSITIVE,
|
207
176
|
)
|
208
177
|
|
209
178
|
compare_parser.add_argument(
|
210
|
-
"-N",
|
179
|
+
"-N",
|
180
|
+
"--only-negative",
|
181
|
+
action="store_true",
|
182
|
+
help="Only show negative entries",
|
183
|
+
default=compare_config[TTC_CFG.ONLY_NEGATIVE],
|
184
|
+
config_path=TTC_CFG.ONLY_NEGATIVE,
|
211
185
|
)
|
212
186
|
|
213
187
|
compare_parser.add_argument(
|
214
188
|
"-R",
|
215
189
|
"--top-records",
|
216
|
-
default=
|
190
|
+
default=compare_config[TTC_CFG.TOP_RECORDS],
|
191
|
+
config_path=TTC_CFG.TOP_RECORDS,
|
217
192
|
type=int,
|
218
193
|
help="Show the top N records from each section.",
|
219
194
|
)
|
@@ -222,39 +197,54 @@ def compare_add_parseargs(
|
|
222
197
|
"-r",
|
223
198
|
"--reverse_sort",
|
224
199
|
action="store_true",
|
200
|
+
default=compare_config[TTC_CFG.REVERSE_SORT],
|
201
|
+
config_path=TTC_CFG.REVERSE_SORT,
|
225
202
|
help="Reverse the sort order of reports",
|
226
203
|
)
|
227
204
|
|
228
205
|
compare_parser.add_argument(
|
229
206
|
"-s",
|
230
207
|
"--sort-by",
|
231
|
-
default=
|
208
|
+
default=compare_config[TTC_CFG.SORT_BY],
|
209
|
+
config_path=TTC_CFG.SORT_BY,
|
232
210
|
type=str,
|
233
211
|
help="Sort report entries by this column",
|
234
212
|
)
|
235
213
|
|
214
|
+
compare_parser.add_argument(
|
215
|
+
"-A",
|
216
|
+
"--algorithm",
|
217
|
+
default=compare_config[TTC_CFG.ALGORITHM],
|
218
|
+
config_path=TTC_CFG.ALGORITHM,
|
219
|
+
type=str,
|
220
|
+
help="The algorithm to apply for data comparison (statistical, correlation)",
|
221
|
+
)
|
222
|
+
|
236
223
|
# compare_parser.add_argument(
|
237
224
|
# "-T",
|
238
225
|
# "--between-times",
|
239
|
-
# nargs=2,
|
240
|
-
# type=int,
|
241
|
-
# help="For single files, only display results between these timestamps",
|
242
|
-
# )
|
243
226
|
|
244
227
|
return compare_parser
|
245
228
|
|
246
229
|
|
247
|
-
def get_comparison_args(
|
230
|
+
def get_comparison_args(config: dict) -> dict:
|
248
231
|
"""Return a dict of comparison parameters from arguments."""
|
232
|
+
dissect_config = config[TTD_CFG.KEY_DISSECTOR]
|
233
|
+
compare_config = config[TTC_CFG.KEY_COMPARE]
|
234
|
+
limitor_config = config[TTL_CFG.KEY_LIMITOR]
|
235
|
+
|
249
236
|
return {
|
250
|
-
"maximum_count":
|
251
|
-
"
|
252
|
-
"
|
253
|
-
"
|
254
|
-
"
|
255
|
-
"
|
256
|
-
"
|
257
|
-
"
|
258
|
-
"
|
259
|
-
"
|
237
|
+
"maximum_count": dissect_config[TTD_CFG.PACKET_COUNT] or 0,
|
238
|
+
"match_string": limitor_config[TTL_CFG.MATCH_STRING],
|
239
|
+
"match_value": limitor_config[TTL_CFG.MATCH_VALUE],
|
240
|
+
"match_expression": limitor_config[TTL_CFG.MATCH_EXPRESSION],
|
241
|
+
"minimum_count": limitor_config[TTL_CFG.MINIMUM_COUNT],
|
242
|
+
"print_threshold": float(compare_config[TTC_CFG.PRINT_THRESHOLD]) / 100.0,
|
243
|
+
"only_positive": compare_config[TTC_CFG.ONLY_POSITIVE],
|
244
|
+
"only_negative": compare_config[TTC_CFG.ONLY_NEGATIVE],
|
245
|
+
"top_records": compare_config[TTC_CFG.TOP_RECORDS],
|
246
|
+
"reverse_sort": compare_config[TTC_CFG.REVERSE_SORT],
|
247
|
+
"sort_by": compare_config[TTC_CFG.SORT_BY],
|
248
|
+
"merge_files": dissect_config[TTD_CFG.MERGE],
|
249
|
+
"algorithm": compare_config[TTC_CFG.ALGORITHM],
|
260
250
|
}
|
traffic_taffy/comparison.py
CHANGED
@@ -3,15 +3,29 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
from typing import Dict, Any
|
5
5
|
|
6
|
+
from traffic_taffy.reports import Report
|
7
|
+
|
8
|
+
# Organized reports are dicts containing a primary key that is being
|
9
|
+
# compared to (left hand side), and a secondary key that is the right
|
10
|
+
# hand thing being compared. Each key/subkey combination should point
|
11
|
+
# to a Report containing the results of that comparison.
|
12
|
+
OrganizedReports = Dict[str, Dict[Any, Report]]
|
13
|
+
|
6
14
|
|
7
15
|
class Comparison:
|
8
16
|
"""A simple data storage class to hold comparison data."""
|
9
17
|
|
10
|
-
def __init__(
|
18
|
+
def __init__(
|
19
|
+
self,
|
20
|
+
contents: OrganizedReports,
|
21
|
+
title: str = "",
|
22
|
+
sort_by: str = "delta_percentage",
|
23
|
+
):
|
11
24
|
"""Create a Comparison class from contents."""
|
12
|
-
self.contents = contents
|
25
|
+
self.contents: OrganizedReports = contents
|
13
26
|
self.title: str = title
|
14
27
|
self.printing_arguments: Dict[str, Any] = {}
|
28
|
+
self.sort_by = sort_by
|
15
29
|
|
16
30
|
# title
|
17
31
|
@property
|
@@ -25,10 +39,10 @@ class Comparison:
|
|
25
39
|
|
26
40
|
# report contents -- actual data
|
27
41
|
@property
|
28
|
-
def contents(self) ->
|
42
|
+
def contents(self) -> OrganizedReports:
|
29
43
|
"""The contents of this comparison."""
|
30
44
|
return self._contents
|
31
45
|
|
32
46
|
@contents.setter
|
33
|
-
def contents(self, new_contents:
|
47
|
+
def contents(self, new_contents: OrganizedReports) -> None:
|
34
48
|
self._contents = new_contents
|
traffic_taffy/config.py
ADDED
@@ -0,0 +1,133 @@
|
|
1
|
+
"""A helper class to store a generic set of configuration as a dict."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
from enum import Enum
|
5
|
+
from typing import TextIO, Dict, List, Any
|
6
|
+
from pathlib import Path
|
7
|
+
from logging import error
|
8
|
+
from argparse import Namespace
|
9
|
+
from dotnest import DotNest
|
10
|
+
|
11
|
+
|
12
|
+
class ConfigStyles(Enum):
|
13
|
+
"""A set of configuration types."""
|
14
|
+
|
15
|
+
YAML = "yaml"
|
16
|
+
TOML = "toml"
|
17
|
+
# TODO(hardaker): support "any" at some point to determine type at run-time
|
18
|
+
|
19
|
+
|
20
|
+
class Config(dict):
|
21
|
+
"""A generic configuration storage class."""
|
22
|
+
|
23
|
+
def __init__(self, *args, **kwargs):
|
24
|
+
"""Create a configuration object to store collected data in."""
|
25
|
+
self._config_option_names = ["--config"]
|
26
|
+
self.dotnest = DotNest(self, allow_creation=True)
|
27
|
+
super().__init__(*args, **kwargs)
|
28
|
+
|
29
|
+
@property
|
30
|
+
def config_option_names(self) -> List[str]:
|
31
|
+
"""The list of configuration file arguments to use/look for."""
|
32
|
+
return self._config_option_names
|
33
|
+
|
34
|
+
@config_option_names.setter
|
35
|
+
def config_option_names(self, newlist: str | List[str]) -> None:
|
36
|
+
if isinstance(newlist, str):
|
37
|
+
newlist = [newlist]
|
38
|
+
|
39
|
+
self._config_option_names = newlist
|
40
|
+
|
41
|
+
def deep_update(self, ref: dict, new_content: dict):
|
42
|
+
for key in new_content:
|
43
|
+
if key in ref and isinstance(ref[key], dict):
|
44
|
+
self.deep_update(ref[key], new_content[key])
|
45
|
+
else:
|
46
|
+
ref[key] = new_content[key]
|
47
|
+
|
48
|
+
def load_stream(
|
49
|
+
self, config_handle: TextIO, style: ConfigStyles = ConfigStyles.YAML
|
50
|
+
) -> None:
|
51
|
+
"""Import a set of configuration from an IO stream."""
|
52
|
+
if style == ConfigStyles.YAML:
|
53
|
+
import yaml
|
54
|
+
|
55
|
+
contents = yaml.safe_load(config_handle)
|
56
|
+
|
57
|
+
# TODO(hardaker): support TOML
|
58
|
+
self.deep_update(self, contents)
|
59
|
+
|
60
|
+
def load_file(
|
61
|
+
self, config_file: str, style: ConfigStyles = ConfigStyles.YAML
|
62
|
+
) -> None:
|
63
|
+
"""Load a configuration file from a filename."""
|
64
|
+
self.load_stream(Path.open(config_file), style=style)
|
65
|
+
|
66
|
+
def load_namespace(
|
67
|
+
self, namespace: Namespace, mapping: Dict[str, Any] | None = None
|
68
|
+
) -> None:
|
69
|
+
"""Load the contents of an argparse Namespace into configuration."""
|
70
|
+
values = vars(namespace)
|
71
|
+
if mapping:
|
72
|
+
values = {mapping.get(key, key): value for key, value in values.items()}
|
73
|
+
self.update(values)
|
74
|
+
|
75
|
+
def read_configfile_from_arguments(
|
76
|
+
self,
|
77
|
+
argv: List[str],
|
78
|
+
) -> None:
|
79
|
+
"""Scan a list of arguments for configuration file(s) and load them."""
|
80
|
+
# TODO(hardaker): convert this to argparse's parse known feature
|
81
|
+
# aka replace using stackoverflow answer to 3609852
|
82
|
+
|
83
|
+
for n, item in enumerate(argv):
|
84
|
+
if item in self.config_option_names:
|
85
|
+
if len(argv) == n:
|
86
|
+
error(f"no argument supplied after '{item}'")
|
87
|
+
raise ValueError
|
88
|
+
|
89
|
+
if argv[n + 1].startswith("-"):
|
90
|
+
error(f"The argument after '{item}' seems to be another argument")
|
91
|
+
raise ValueError
|
92
|
+
|
93
|
+
filename = argv[n + 1]
|
94
|
+
|
95
|
+
if "=" in filename:
|
96
|
+
(left, right) = filename.split("=")
|
97
|
+
left = left.strip()
|
98
|
+
right = right.strip()
|
99
|
+
self.set_dotnest(left, right)
|
100
|
+
continue
|
101
|
+
|
102
|
+
if not Path(filename).is_file():
|
103
|
+
error(
|
104
|
+
f"The filename after '{item}' does not exist or is not a file"
|
105
|
+
)
|
106
|
+
raise ValueError
|
107
|
+
|
108
|
+
self.load_file(filename)
|
109
|
+
|
110
|
+
def as_namespace(self) -> Namespace:
|
111
|
+
"""Convert the configuration (back) into an argparse Namespace."""
|
112
|
+
namespace = Namespace()
|
113
|
+
for item, value in self.items():
|
114
|
+
setattr(namespace, item, value)
|
115
|
+
|
116
|
+
return namespace
|
117
|
+
|
118
|
+
def dump(self):
|
119
|
+
"""Dumps the current configuration into a YAML format."""
|
120
|
+
import yaml
|
121
|
+
|
122
|
+
print(yaml.dump(self))
|
123
|
+
|
124
|
+
def set_dotnest(self, parameter: str, value: Any):
|
125
|
+
self.dotnest.set(parameter, value)
|
126
|
+
|
127
|
+
def get_dotnest(
|
128
|
+
self, parameter: str, default: Any = None, return_none: bool = True
|
129
|
+
):
|
130
|
+
result = self.dotnest.get(parameter, return_none=return_none)
|
131
|
+
if result is not None:
|
132
|
+
return result
|
133
|
+
return default
|