adtl 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
adtl/__init__.py ADDED
@@ -0,0 +1,1085 @@
1
+ import argparse
2
+ import csv
3
+ import hashlib
4
+ import io
5
+ import json
6
+ import logging
7
+ import itertools
8
+ import copy
9
+ import re
10
+ import importlib.metadata
11
+ from collections import defaultdict, Counter
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+ from functools import lru_cache
15
+ from typing import Any, Dict, Iterable, List, Optional, Union, Callable
16
+ from more_itertools import unique_everseen
17
+
18
+ import pint
19
+ import tomli
20
+ import requests
21
+ import fastjsonschema
22
+ from tqdm import tqdm
23
+ import warnings
24
+
25
+ import adtl.transformations as tf
26
+ from adtl.transformations import AdtlTransformationWarning
27
+
28
# Loader for each specification file extension the Parser accepts.
SUPPORTED_FORMATS = dict(json=json.load, toml=tomli.load)
# Date layout used when a specification does not set defaultDateFormat.
DEFAULT_DATE_FORMAT = "%Y-%m-%d"

# Shorthand aliases used in signatures throughout this module.
StrDict = Dict[str, Any]
Rule = Union[str, StrDict]
Context = Optional[Dict[str, Union[bool, int, str, List[str]]]]

# Version recorded in the installed distribution's metadata.
__version__ = importlib.metadata.version("adtl")
36
+
37
+
38
def get_value(row: StrDict, rule: Rule, ctx: Context = None) -> Any:
    """Return the value that ``rule`` extracts from ``row``.

    Delegates to get_value_unhashed() and post-processes the result:

    * a non-None value obtained from a dict rule with a truthy
      ``sensitive`` entry is replaced by its hash_sensitive() digest;
    * any other string is converted to ``int`` or, failing that, to
      ``float`` when it parses as one;
    * everything else is returned untouched.

    Application code should call this instead of get_value_unhashed().
    """
    raw = get_value_unhashed(row, rule, ctx)
    is_sensitive = isinstance(rule, dict) and rule.get("sensitive")
    if is_sensitive and raw is not None:
        return hash_sensitive(raw)
    if not isinstance(raw, str):
        return raw
    # Narrowest numeric type first; fall back to the string itself.
    for convert in (int, float):
        try:
            return convert(raw)
        except ValueError:
            continue
    return raw
57
+
58
+
59
def _resolve_apply_params(row: StrDict, raw_params: list) -> list:
    """Replace ``$field`` references in apply-params with values from row.

    A string parameter starting with ``$`` is looked up in ``row`` (without
    the ``$``). List parameters are resolved element by element. Every other
    parameter is passed through unchanged.
    """

    def resolve(param):
        if isinstance(param, str) and param.startswith("$"):
            return row[param[1:]]
        return param

    resolved = []
    for param in raw_params:
        if isinstance(param, list):
            resolved.append([resolve(item) for item in param])
        else:
            resolved.append(resolve(param))
    return resolved


def get_value_unhashed(row: StrDict, rule: Rule, ctx: Context = None) -> Any:
    """Gets value from row using rule (unhashed)

    Unlike get_value() this function does NOT hash sensitive data
    and should not be called directly, except for debugging. Use
    get_value() instead.

    Args:
        row: Source row, mapping field names to raw values.
        rule: A constant (anything that is not a dict) which is returned
            as is, or a dict rule containing ``field`` or ``combinedType``.
        ctx: Optional per-attribute context; the keys read here are
            ``returnUnmatched``, ``is_date`` and ``defaultDateFormat``.

    Returns:
        The extracted (and possibly transformed, mapped or converted)
        value, or None when the field is skipped, its ``if`` condition is
        not met, or the value is empty.

    Raises:
        AttributeError: If ``apply.function`` does not name a function in
            adtl.transformations.
        ValueError: If the rule has neither ``field`` nor ``combinedType``,
            or a unit conversion fails and ``returnUnmatched`` is not set.
    """
    if not isinstance(rule, dict):  # not a container, is constant
        return rule
    # Check whether field is present if it's allowed to be passed over
    if "field" in rule:
        # do not check for condition if field is missing
        if skip_field(row, rule, ctx):
            return None
        # do not parse field if condition is not met
        if "if" in rule and not parse_if(row, rule["if"]):
            return None
        value = row[rule["field"]]
        if "apply" in rule:
            # apply data transformations.
            transformation = rule["apply"]["function"]
            params = None
            if "params" in rule["apply"]:
                params = _resolve_apply_params(row, rule["apply"]["params"])

            # Look the function up separately from calling it, so that an
            # AttributeError raised *inside* a transformation is not
            # misreported as the transformation being undefined.
            try:
                transform = getattr(tf, transformation)
            except AttributeError as err:
                raise AttributeError(
                    f"Error using a data transformation: Function {transformation} "
                    "has not been defined."
                ) from err

            try:
                with warnings.catch_warnings():
                    # Promote transformation warnings to exceptions so that
                    # they can be handled according to returnUnmatched.
                    warnings.simplefilter("error", category=AdtlTransformationWarning)
                    if params:
                        value = transform(value, *params)
                    else:
                        value = transform(value)
            except AdtlTransformationWarning as e:
                if ctx and ctx.get("returnUnmatched"):
                    warnings.warn(str(e), AdtlTransformationWarning)
                    return value
                else:
                    logging.error(str(e))
                    return None
            return value
        if value == "":
            return None
        if "values" in rule:
            mapping = rule["values"]
            if rule.get("caseInsensitive") and isinstance(value, str):
                value = value.lower().strip(" ")
                # Build a lower-cased lookup locally; the rule (part of the
                # shared specification) is left unmodified.
                mapping = {k.lower(): v for k, v in mapping.items()}

            if rule.get("ignoreMissingKey") or (ctx and ctx.get("returnUnmatched")):
                value = mapping.get(value, value)
            else:
                value = mapping.get(value)

        # recheck if value is empty after mapping (use to map values to None)
        if value == "":
            return None
        # Either source_unit / unit OR source_date / date triggers conversion
        if "source_unit" in rule and "unit" in rule:
            assert "source_date" not in rule and "date" not in rule
            # do not parse units if value is empty
            if value is None:
                return None
            source_unit = get_value(row, rule["source_unit"])
            unit = rule["unit"]
            if not isinstance(source_unit, str):
                logging.debug(
                    f"Error converting source_unit {source_unit} to {unit!r} with "
                    f"rule: {rule}, defaulting to assume source_unit is {unit}"
                )
                return float(value)
            try:
                value = pint.Quantity(float(value), source_unit).to(unit).m
            except ValueError:
                if ctx and ctx.get("returnUnmatched"):
                    logging.debug(f"Could not convert {value} to a floating point")
                    return value
                raise ValueError(f"Could not convert {value} to a floating point")
        if "source_date" in rule or (ctx and ctx.get("is_date")):
            assert "source_unit" not in rule and "unit" not in rule
            target_date = rule.get("date", "%Y-%m-%d")
            source_date = (
                get_value(row, rule["source_date"])
                if "source_date" in rule
                else ctx["defaultDateFormat"]
            )
            if source_date != target_date:
                try:
                    value = datetime.strptime(value, source_date).strftime(target_date)
                except (TypeError, ValueError):
                    logging.info(f"Could not parse date: {value}")
                    if ctx and ctx.get("returnUnmatched"):
                        return value
                    return None
        return value
    elif "combinedType" in rule:
        return get_combined_type(row, rule, ctx)
    else:
        raise ValueError(f"Could not return value for {rule}")
181
+
182
+
183
def matching_fields(fields: List[str], pattern: str) -> List[str]:
    """Return the entries of ``fields`` that ``pattern`` matches at their start."""
    matches_at_start = re.compile(pattern).match
    return list(filter(matches_at_start, fields))
187
+
188
+
189
def parse_if(
    row: Dict[str, Any],
    rule: Dict[str, Any],
    ctx: Optional[Callable[[str], dict]] = None,
    can_skip=False,
) -> bool:
    """Parse conditional statements and return a boolean

    Args:
        row: Source row, mapping field names to raw values.
        rule: Condition with exactly one condition key, optionally
            accompanied by a ``can_skip`` key. The condition key is either
            a combinator (``not`` with a dict, ``any`` / ``all`` with a
            list of sub-conditions) or a field name. A field maps either to
            a value (equality test) or to a one-entry dict
            ``{operator: value}`` where operator is one of
            ``>``, ``>=``, ``<``, ``<=``, ``!=``, ``=``, ``==``, ``=~``.
        ctx: Optional callable returning the context for a field name; used
            to decide whether a missing field may be skipped.
        can_skip: When True a missing field makes the condition False
            instead of raising. Forced to True when ``rule`` has a
            ``can_skip`` key.

    Returns:
        Whether the row satisfies the condition. A row value that cannot be
        cast to the type of the comparison value yields False.

    Raises:
        KeyError: If the field is missing from the row and may not be
            skipped.
        ValueError: For an unknown operator, or when a set was given where
            a dictionary was expected.
    """
    n_keys = len(rule)
    assert n_keys == 1 or n_keys == 2
    if n_keys == 2:
        assert "can_skip" in rule
        can_skip = True
    # Choose the condition key explicitly: relying on "the first key" would
    # treat can_skip as the field whenever it happens to be listed first.
    key = next(k for k in rule if n_keys == 1 or k != "can_skip")
    condition = rule[key]

    if key == "not" and isinstance(condition, dict):
        return not parse_if(row, condition, ctx, can_skip)
    elif key == "any" and isinstance(condition, list):
        return any(parse_if(row, r, ctx, can_skip) for r in condition)
    elif key == "all" and isinstance(condition, list):
        return all(parse_if(row, r, ctx, can_skip) for r in condition)

    try:
        attr_value = row[key]
    except KeyError:
        if can_skip is True:
            return False
        elif ctx:
            if skip_field(row, {"field": key}, ctx(key)):
                return False
        raise

    if isinstance(condition, set):  # common error, missed colon to make it a dict
        raise ValueError(
            f"if-subexpressions should be a dictionary, is a set: {condition}"
        )
    if isinstance(condition, dict):
        cmp = next(iter(condition))
        value = condition[cmp]
    else:
        # A bare value is shorthand for an equality test.
        cmp, value = "==", condition

    # Row values are usually strings; compare them as the rule value's type.
    try:
        cast_value = type(value)(attr_value)
    except ValueError:
        logging.debug(
            f"Error when casting value {attr_value!r} with rule: {rule}, defaulting"
            " to False"
        )
        return False

    if cmp == ">":
        return cast_value > value
    elif cmp == ">=":
        return cast_value >= value
    elif cmp == "<":
        return cast_value < value
    elif cmp == "<=":
        return cast_value <= value
    elif cmp == "!=":
        return cast_value != value
    elif cmp in ["=", "=="]:
        return cast_value == value
    elif cmp == "=~":
        return bool(re.match(value, cast_value, re.IGNORECASE))
    else:
        raise ValueError(f"Unrecognized operand: {cmp}")
258
+
259
+
260
def get_combined_type(row: StrDict, rule: StrDict, ctx: Context = None):
    """Gets value from row for a combinedType rule

    A rule with the combinedType key combines multiple fields in the row
    to get the value. Thus this rule assumes that the combinedType fields
    do NOT have repeated (possibly different) values across the dataset.

    Example of dataset that will be handled correctly, with modliv and
    mildliver being the categorical indicators for moderate and mild
    liver disease respectively:

        subjid,modliv,mildliver,otherfield
        1,0,1,NA
        1,,,

    Example of dataset that will not be handled correctly:

        subjid,modliv,mildliver,otherfield
        1,0,,
        1,,1,

    For a combinedType rule to successfully run, all the field values should
    be present in the same row.

    Supported combinedType values:

    * ``all`` / ``any`` / ``min`` / ``max``: the builtin of that name applied
      to the values that are neither None nor ""; None if none remain.
    * ``firstNonNull``: first value that is not None (nested lists are
      flattened), or None.
    * ``list`` / ``set``: flattened values (de-duplicated for ``set``),
      filtered according to ``excludeWhen`` (``none``, ``false-like`` or a
      list of values to drop).

    Raises:
        ValueError: For an unknown combinedType or invalid ``excludeWhen``.
    """
    assert "combinedType" in rule
    combined_type = rule["combinedType"]
    # Explicit dispatch table instead of eval() on the spec-provided name.
    reducers = {"all": all, "any": any, "min": min, "max": max}

    rules = []
    # expand fieldPattern rules
    for r in rule["fields"]:
        if "fieldPattern" in r:
            for match in matching_fields(list(row.keys()), r.get("fieldPattern")):
                rules.append({"field": match, **r})
        else:
            rules.append(r)

    if isinstance(combined_type, str) and combined_type in reducers:
        values = [get_value(row, r, ctx) for r in rules]
        values = [v for v in values if v not in [None, ""]]
        return reducers[combined_type](values) if values else None
    elif combined_type == "firstNonNull":
        # Every sub-rule is evaluated before searching, as a list, so that
        # errors in later sub-rules are not hidden by an early match.
        candidates = flatten([get_value(row, r, ctx) for r in rules])
        return next((item for item in candidates if item is not None), None)
    elif combined_type == "list" or combined_type == "set":
        excludeWhen = rule.get("excludeWhen")
        if excludeWhen not in [None, "false-like", "none"] and not isinstance(
            excludeWhen, list
        ):
            raise ValueError(
                "excludeWhen rule should be 'none', 'false-like', or a list of values"
            )

        values = flatten([get_value(row, r, ctx) for r in rules])
        if combined_type == "set":
            values = [*set(values)]
        if excludeWhen is None:
            return list(values)
        if excludeWhen == "none":
            return [v for v in values if v is not None]
        elif excludeWhen == "false-like":
            return [v for v in values if v]
        else:
            return [v for v in values if v not in excludeWhen]
    else:
        raise ValueError(f"Unknown {combined_type} in {rule}")
331
+
332
+
333
def flatten(xs):
    """
    Lazily flatten arbitrarily nested iterables mixed with plain items.

    Strings and bytes count as plain items and are never split up, e.g.
    [None, ['Dexamethasone']] -> [None, 'Dexamethasone']
    """
    for item in xs:
        is_leaf = isinstance(item, (str, bytes)) or not isinstance(item, Iterable)
        if is_leaf:
            yield item
            continue
        for leaf in flatten(item):
            yield leaf
344
+
345
+
346
def expand_refs(
    spec_fragment: Dict[str, Any], defs: Dict[str, Any]
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Expand all references (ref) with definitions (defs)

    Dictionaries and lists are rebuilt recursively. A dictionary containing
    a ``ref`` key is replaced by the definition ``defs[ref]`` overlaid with
    the dictionary's other keys, so keys written next to the reference take
    precedence over the definition. Anything else is returned unchanged.

    Neither ``spec_fragment`` nor ``defs`` is modified, so the same
    specification can be expanded more than once.

    Raises:
        KeyError: If a referenced name is not present in ``defs``.
    """
    if isinstance(spec_fragment, dict):
        if "ref" in spec_fragment:
            overrides = {k: v for k, v in spec_fragment.items() if k != "ref"}
            spec_fragment = {**defs[spec_fragment["ref"]], **overrides}
        return {k: expand_refs(v, defs) for k, v in spec_fragment.items()}
    elif isinstance(spec_fragment, list):
        return [expand_refs(m, defs) for m in spec_fragment]
    else:
        return spec_fragment
361
+
362
+
363
def expand_for(spec: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Expands for expressions in oneToMany table blocks

    A block with a ``for`` key is replaced by one copy per combination of
    its loop variables, with ``{variable}`` placeholders substituted (via
    ``str.format``) in every string key and string value, however deeply
    nested. Each loop variable maps either to a list of values or to
    ``{"range": [start, end]}`` with integer ``start < end``, which stands
    for ``start..end`` inclusive. Blocks without ``for`` are passed through
    as they are. The input blocks are not modified.

    Raises:
        ValueError: If ``for`` is not a dictionary, or one of its variables
            is neither a list nor a valid range.
    """

    def replace_val(item: Any, replace: Dict[str, Any]) -> Any:
        # Substitute placeholders recursively; non-string scalars are kept.
        if isinstance(item, str):
            return item.format(**replace)
        if isinstance(item, list):
            return [replace_val(it, replace) for it in item]
        if not isinstance(item, dict):
            return item
        block = {}
        for k, v in item.items():
            if not isinstance(k, str):
                # Entries under non-string keys are copied untouched.
                block[k] = v
                continue
            block[k.format(**replace)] = replace_val(v, replace)
        return block

    def is_range(values: Any) -> bool:
        if not isinstance(values, dict) or "range" not in values:
            return False
        bounds = values["range"]
        return (
            isinstance(bounds, list)
            and len(bounds) == 2
            and isinstance(bounds[0], int)
            and isinstance(bounds[1], int)
            and bounds[1] > bounds[0]
        )

    out = []
    for match in spec:
        if "for" not in match:
            out.append(match)
            continue
        for_expr = match["for"]
        if not isinstance(for_expr, dict):
            raise ValueError(
                f"for expression {for_expr!r} is not a dictionary of variables to list "
                "of values or a range"
            )

        # Expand ranges when available
        loop_values = {}
        for var, values in for_expr.items():
            if is_range(values):
                start, end = values["range"]
                loop_values[var] = range(start, end + 1)  # include end in list
            elif isinstance(values, list):
                loop_values[var] = values
            else:
                raise ValueError(
                    f"for expression {for_expr!r} can only have lists or ranges for "
                    "variables"
                )

        template = {k: v for k, v in match.items() if k != "for"}
        loop_vars = sorted(loop_values)
        for vals in itertools.product(*(loop_values[var] for var in loop_vars)):
            out.append(replace_val(template, dict(zip(loop_vars, vals))))
    return out
429
+
430
+
431
def hash_sensitive(value: str) -> str:
    """Return the hex SHA-256 digest of ``str(value)`` encoded as UTF-8.

    Hashing alone is generally not sufficient for anonymisation, because
    the digest still acts as a unique identifier, but it is preferable to
    storing the value unprocessed.
    """
    digest = hashlib.sha256()
    digest.update(str(value).encode("utf-8"))
    return digest.hexdigest()
436
+
437
+
438
def remove_null_keys(d: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``d`` without the entries whose value is None.

    Only None is dropped; other falsy values such as "", 0, False or []
    are kept.
    """
    kept = {}
    for key, value in d.items():
        if value is None:
            continue
        kept[key] = value
    return kept
441
+
442
+
443
def get_date_fields(schema: Dict[str, Any]) -> List[str]:
    """Return the sorted names of the date properties in ``schema``.

    A property counts as a date when it is named ``date``, when its name
    contains ``date_`` or ``_date``, or when its definition has
    ``"format": "date"``.
    """
    properties = schema["properties"]
    found = set()
    for name in properties:
        named_like_date = name == "date" or "date_" in name or "_date" in name
        declared_as_date = properties[name].get("format") == "date"
        if named_like_date or declared_as_date:
            found.add(name)
    return sorted(found)
456
+
457
+
458
def make_fields_optional(
    schema: Dict[str, Any], optional_fields: Optional[List[str]]
) -> Dict[str, Any]:
    """Returns JSON schema with required fields modified to drop optional fields

    Args:
        schema: JSON schema; it is never modified.
        optional_fields: Field names that must no longer be required. When
            None, ``schema`` itself is returned unchanged.

    Returns:
        A deep copy of ``schema`` in which ``optional_fields`` have been
        removed from the top-level ``required`` list (which is returned
        sorted) and from the ``required`` list of every ``oneOf`` / ``anyOf``
        alternative that has one (declaration order is kept). If that
        leaves every alternative of a ``oneOf`` / ``anyOf`` empty, the key
        is dropped; otherwise duplicate alternatives are removed.
    """
    if optional_fields is None:
        return schema
    dropped = set(optional_fields)
    _schema = copy.deepcopy(schema)
    # Schemas without a top-level "required" list are left without one.
    if "required" in _schema:
        _schema["required"] = sorted(set(_schema["required"]) - dropped)
    for opt in ["oneOf", "anyOf"]:
        if opt not in _schema:
            continue
        alternatives = _schema[opt]
        if not any("required" in alternative for alternative in alternatives):
            continue
        for alternative in alternatives:
            # Alternatives that state no "required" list stay as they are.
            if "required" in alternative:
                alternative["required"] = [
                    name for name in alternative["required"] if name not in dropped
                ]
        if all(
            all(bool(v) is False for v in alternative.values())
            for alternative in alternatives
        ):
            # Nothing is constrained any more, so the keyword is meaningless.
            _schema.pop(opt)
        else:
            _schema[opt] = list(unique_everseen(alternatives))
    return _schema
481
+
482
+
483
def relative_path(source_file, target_file):
    """Return ``target_file`` placed in the directory that holds ``source_file``."""
    containing_dir = Path(source_file).parent
    return containing_dir.joinpath(target_file)
485
+
486
+
487
def read_definition(file: Union[str, Path]) -> Dict[str, Any]:
    """Reads definition from file into a dictionary

    Args:
        file: Path of a ``.json`` or ``.toml`` definition file.

    Returns:
        The parsed contents of the file.

    Raises:
        ValueError: If the file suffix is neither ``.json`` nor ``.toml``.
    """
    if isinstance(file, str):
        file = Path(file)
    if file.suffix == ".json":
        # Binary mode lets the json module detect the encoding (UTF-8/16/32)
        # instead of depending on the platform's default text encoding; it
        # also matches how Parser opens specification files.
        with file.open("rb") as fp:
            return json.load(fp)
    elif file.suffix == ".toml":
        with file.open("rb") as fp:
            return tomli.load(fp)
    else:
        raise ValueError(f"Unsupported file format: {file}")
499
+
500
+
501
def skip_field(
    row: Dict[str, Any], rule: Dict[str, Any], ctx: Optional[Dict[str, Any]] = None
):
    """Return True if the rule's field is absent from row and may be skipped.

    A field may be skipped when the rule has a truthy ``can_skip`` entry,
    or when ``ctx`` carries a compiled ``skip_pattern`` that matches the
    field name. In every other case the result is False.
    """
    pattern = ctx.get("skip_pattern") if ctx else None
    # rule["field"] is only read once one of the two conditions calls for it.
    skippable = bool(rule.get("can_skip")) or bool(
        pattern and pattern.match(rule["field"])
    )
    return skippable and rule["field"] not in row
508
+
509
+
510
class Parser:
    """Main parser class that loads a specification

    Typical use of this within Python code::

        import adtl

        parser = adtl.Parser(specification)
        print(parser.tables) # list of tables created

        for row in parser.parse(file).read_table(table):
            print(row)
    """

    # Seconds to wait for a remote schema before giving up on validation.
    SCHEMA_TIMEOUT = 30

    def __init__(
        self,
        spec: Union[str, Path, Dict[str, Any]],
        include_defs: Optional[List[str]] = None,
        quiet: bool = False,
    ):
        """Loads specification from spec in format (default json)

        Args:
            spec: Either the specification file to read (as Path or str), or
                the specification loaded into a dictionary
            include_defs: Definition files to include. These are spliced
                directly into the adtl.defs section of the :ref:`specification`.
            quiet: Boolean that switches on the verbosity of the parser, default False

        Raises:
            ValueError: If the specification file format is not supported or
                the specification is invalid.
        """

        self.data: Dict[str, Any] = {}
        self.defs: Dict[str, Any] = {}
        self.fieldnames: Dict[str, List[str]] = {}
        self.specfile = None
        self.include_defs = list(include_defs) if include_defs else []
        self.validators: Dict[str, Any] = {}
        self.schemas: Dict[str, Any] = {}
        self.quiet = quiet
        self.date_fields = []
        self.report = {
            "validation_errors": defaultdict(Counter),
            "total_valid": defaultdict(int),
            "total": defaultdict(int),
        }
        self.report_available = False
        self._ctx_cache: Dict[str, dict] = {}
        if isinstance(spec, str):
            spec = Path(spec)
        if isinstance(spec, Path):
            self.specfile = spec
            fmt = spec.suffix[1:]
            if fmt not in SUPPORTED_FORMATS:
                raise ValueError(f"adtl specification format not supported: {fmt}")
            with spec.open("rb") as fp:
                self.spec = SUPPORTED_FORMATS[fmt](fp)
        else:
            self.spec = spec
        self.header = self.spec.get("adtl", {})
        if self.specfile:
            # Definition files named in the header are relative to the spec.
            self.include_defs = [
                relative_path(self.specfile, definition_file)
                for definition_file in self.header.get("include-def", [])
            ] + self.include_defs
        self.defs = self.header.get("defs", {})
        if self.include_defs:
            for definition_file in self.include_defs:
                self.defs.update(read_definition(definition_file))
        self.spec = expand_refs(self.spec, self.defs)

        self.validate_spec()
        for table in (t for t in self.tables if self.tables[t]["kind"] == "oneToMany"):
            self.spec[table] = expand_for(self.spec[table])
        for table in self.tables:
            self.data[table] = self._new_table_store(table)
            if schema := self.tables[table].get("schema"):
                optional_fields = self.tables[table].get("optional-fields")
                loaded = self._load_schema(table, schema)
                if loaded is None:
                    # Schema unavailable: the table is parsed but not validated.
                    continue
                self.schemas[table] = make_fields_optional(loaded, optional_fields)
                self.date_fields.extend(get_date_fields(self.schemas[table]))
                self.validators[table] = fastjsonschema.compile(self.schemas[table])

        self._set_field_names()

    def _new_table_store(self, table: str):
        "Returns an empty container for table: keyed by group for groupBy tables"
        if self.tables[table].get("groupBy"):
            return defaultdict(dict)
        return []

    def _load_schema(self, table: str, schema: str) -> Optional[Dict[str, Any]]:
        """Returns the JSON schema for table, or None if it could not be fetched

        Schemas starting with ``http`` are downloaded; anything else is read
        as a file relative to the specification file (or to the current
        directory when the specification was passed as a dictionary).
        """
        if schema.startswith("http"):
            try:
                res = requests.get(schema, timeout=self.SCHEMA_TIMEOUT)
            except (requests.exceptions.RequestException, ConnectionError):
                # requests raises its own exception hierarchy, which does not
                # derive from the builtin ConnectionError.
                res = None
            if res is None or res.status_code != 200:
                logging.warning(
                    f"Could not fetch schema for table {table!r}, will not "
                    "validate"
                )
                return None
            return res.json()
        base_dir = self.specfile.parent if self.specfile else Path()
        with (base_dir / schema).open() as fp:
            return json.load(fp)

    def ctx(self, attribute: str):
        """Returns the (cached) parsing context for attribute

        The context tells get_value() whether the attribute is a date, which
        date format to assume, which missing fields may be skipped and
        whether unmatched values should be returned as they are.
        """
        # Cached per instance; the fallback covers instances whose __init__
        # was bypassed.
        cache = self.__dict__.setdefault("_ctx_cache", {})
        if attribute not in cache:
            skip_pattern = self.header.get("skipFieldPattern")
            cache[attribute] = {
                "is_date": attribute in self.date_fields,
                "defaultDateFormat": self.header.get(
                    "defaultDateFormat", DEFAULT_DATE_FORMAT
                ),
                "skip_pattern": re.compile(skip_pattern) if skip_pattern else False,
                "returnUnmatched": self.header.get("returnUnmatched", False),
            }
        return cache[attribute]

    def validate_spec(self):
        "Raises exceptions if specification is invalid"
        for required in ["tables", "name", "description"]:
            if required not in self.header:
                raise ValueError(f"Specification header requires key: {required}")
        self.tables = self.header["tables"]
        self.name = self.header["name"]
        self.description = self.header["description"]

        for table in self.tables:
            aggregation = self.tables[table].get("aggregation")
            group_field = self.tables[table].get("groupBy")
            kind = self.tables[table].get("kind")
            if kind is None:
                raise ValueError(
                    f"Required 'kind' attribute within 'tables' not present for {table}"
                )
            if group_field is not None and aggregation != "lastNotNull":
                raise ValueError(
                    "groupBy needs aggregation=lastNotNull to be set for table: "
                    f"{table}"
                )

    def _set_field_names(self):
        "Determines the output columns of every table"
        for table in self.tables:
            if table not in self.spec:
                raise ValueError(
                    f"Parser specification missing required '{table}' element"
                )
            if self.tables[table].get("kind") != "oneToMany":
                self.fieldnames[table] = sorted(list(self.spec[table].keys()))
                continue
            common = self.tables[table].get("common", {})
            if table not in self.schemas:
                print(
                    f"Warning: no schema found for {table!r}, field names may be "
                    "incomplete!"
                )
                mapped = set(itertools.chain.from_iterable(self.spec[table]))
                self.fieldnames[table] = list(common.keys()) + sorted(mapped)
            else:
                self.fieldnames[table] = sorted(self.schemas[table]["properties"])
            # Common mappings are shared by every block of a oneToMany table.
            if common:
                for match in self.spec[table]:
                    match.update(common)

    def _default_if(self, table: str, rule: Dict[str, Any]):
        """
        Default behaviour for oneToMany table, row not displayed if there's an empty
        string or values not mapped in the rule.

        Adds the derived condition to ``rule`` under ``if`` and returns it.
        """

        data_options = [
            option["required"][0] for option in self.schemas[table]["oneOf"]
        ]

        option = set(data_options).intersection(rule.keys()).pop()
        target = rule[option]

        if "combinedType" not in target:
            field = target["field"]
            skip = {"can_skip": True} if "can_skip" in target else {}
            if "values" in target:
                # Show the row only for source values the rule can map.
                if_rule = {"any": [{field: v, **skip} for v in target["values"]]}
            else:
                if_rule = {field: {"!=": ""}, **skip}
        else:
            assert target["combinedType"] in [
                "any",
                "all",
                "firstNonNull",
                "set",
                "list",
                "min",
                "max",
            ], f"Invalid combinedType: {target['combinedType']}"

            def create_if_rule(sub_rule):
                field = sub_rule["field"]
                values = sub_rule.get("values", [])
                skip = {"can_skip": True} if sub_rule.get("can_skip", False) else {}
                if values:
                    return [{field: v, **skip} for v in values]
                return [{field: {"!=": ""}, **skip}]

            if_rule = {
                "any": [
                    condition
                    for sub_rule in target["fields"]
                    for condition in create_if_rule(sub_rule)
                ]
            }

        rule["if"] = if_rule
        return rule

    def update_table(self, table: str, row: Dict[str, Any]):
        """Updates table with a new row

        Args:
            table: Table to update
            row: Dictionary with keys as field names and values as field values
        """

        group_field = self.tables[table].get("groupBy")
        kind = self.tables[table].get("kind")
        if group_field:
            # Explicit dispatch table instead of eval() on the spec-provided name.
            reducers = {"all": all, "any": any, "min": min, "max": max}
            group_key = get_value(row, self.spec[table][group_field])
            for attr in self.spec[table]:
                value = get_value(row, self.spec[table][attr], self.ctx(attr))
                # Check against all null elements, for combinedType=set/list, null is []
                if value is None or value == []:
                    continue
                # Only touch the group once there is something to store, so
                # that rows without data do not create empty groups.
                record = self.data[table][group_key]
                if attr not in record:
                    # if data for this field hasn't already been captured
                    record[attr] = value
                elif "combinedType" in self.spec[table][attr]:
                    combined_type = self.spec[table][attr]["combinedType"]
                    existing_value = record[attr]

                    if combined_type in ["all", "any", "min", "max"]:
                        record[attr] = reducers[combined_type]([existing_value, value])
                    elif combined_type in ["list", "set"]:
                        if combined_type == "set":
                            record[attr] = list(set(existing_value + value))
                        else:
                            record[attr] = existing_value + value
                    elif combined_type == "firstNonNull":
                        # only use the first value found
                        pass
                else:
                    # otherwise overwrite?
                    logging.debug(
                        f"Multiple rows of data found for {attr} without a"
                        " combinedType listed. Data being overwritten."
                    )
                    record[attr] = value

        elif kind == "oneToMany":
            for match in self.spec[table]:
                if "if" not in match:
                    match = self._default_if(table, match)
                if parse_if(row, match["if"], self.ctx):
                    self.data[table].append(
                        remove_null_keys(
                            {
                                attr: get_value(row, match[attr], self.ctx(attr))
                                for attr in match
                                if attr != "if"
                            }
                        )
                    )
        elif kind == "constant":  # only one row
            self.data[table] = [self.spec[table]]
        else:
            self.data[table].append(
                remove_null_keys(
                    {
                        attr: get_value(row, self.spec[table][attr], self.ctx(attr))
                        for attr in self.spec[table]
                    }
                )
            )

    def parse(self, file: str, encoding: str = "utf-8", skip_validation=False):
        """Transform file according to specification

        Args:
            file: Source file to transform
            encoding: Source file encoding
            skip_validation: Whether to skip validation, default off

        Returns:
            adtl.Parser: Returns an instance of itself, updated with the parsed tables
        """
        with open(file, encoding=encoding) as fp:
            reader = csv.DictReader(fp)
            return self.parse_rows(
                (
                    tqdm(
                        reader,
                        desc=f"[{self.name}] parsing {Path(file).name}",
                    )
                    if not self.quiet
                    else reader
                ),
                skip_validation=skip_validation,
            )

    def parse_rows(self, rows: Iterable[Dict[str, Any]], skip_validation=False):
        """Transform rows from an iterable according to specification

        Args:
            rows: Iterable of rows, specified as a dictionary of
                (field name, field value) pairs
            skip_validation: Whether to skip validation, default off

        Returns:
            adtl.Parser: Returns an instance of itself, updated with the parsed tables
        """
        for row in rows:
            for table in self.tables:
                try:
                    self.update_table(table, row)
                except ValueError:  # pragma: no cover
                    # Show the offending row before propagating the error.
                    print(
                        "\n".join(
                            [
                                f"{key} = {value}"
                                for key, value in row.items()
                                if value not in ["", None]
                            ]
                        )
                    )
                    raise
        self.report_available = not skip_validation
        if not skip_validation:
            for table in self.validators:
                for row in self.read_table(table):
                    self.report["total"][table] += 1
                    try:
                        self.validators[table](row)
                        row["adtl_valid"] = True
                        self.report["total_valid"][table] += 1
                    except fastjsonschema.exceptions.JsonSchemaValueException as e:
                        row["adtl_valid"] = False
                        row["adtl_error"] = e.message
                        self.report["validation_errors"][table].update([e.message])
        return self

    def clear(self):
        "Clears parser state, leaving an empty container for every table"
        self.data = {table: self._new_table_store(table) for table in self.tables}

    def read_table(self, table: str) -> Iterable[Dict[str, Any]]:
        """Returns parsed table

        Args:
            table: Table to read

        Returns:
            Iterable of transformed rows in table

        Raises:
            ValueError: If table is not defined in the specification
        """
        if table not in self.tables:
            raise ValueError(f"Invalid table: {table}")
        # Same test as used when the containers are created: a groupBy that
        # is present but empty means the rows are stored in a plain list.
        if self.tables[table].get("groupBy"):
            yield from self.data[table].values()
        else:
            yield from self.data[table]

    def write_csv(
        self,
        table: str,
        output: Optional[str] = None,
    ) -> Optional[str]:
        """Writes to output as CSV a particular table

        Args:
            table: Table that should be written to CSV
            output: (optional) Output file name. If not specified, nothing
                is written to disk and the CSV text is returned instead.

        Returns:
            The CSV text when output is not given, otherwise None.
        """

        def writerows(fp, table):
            writer = csv.DictWriter(
                fp,
                fieldnames=(
                    ["adtl_valid", "adtl_error"] if table in self.validators else []
                )
                + self.fieldnames[table],
            )
            writer.writeheader()
            for row in self.read_table(table):
                writer.writerow(row)
            return fp

        if output:
            # newline="" as required by the csv module, otherwise platforms
            # that translate line endings write an extra blank line per row.
            with open(output, "w", newline="") as fp:
                writerows(fp, table)
            return None
        else:
            buf = io.StringIO()
            return writerows(buf, table).getvalue()

    def write_parquet(
        self,
        table: str,
        output: Optional[str] = None,
    ) -> Optional[bytes]:
        """Writes to output as parquet a particular table

        Args:
            table: Table that should be written to parquet
            output: (optional) Output file name. If not specified, nothing
                is written to disk and the parquet bytes are returned instead.

        Returns:
            The parquet file contents when output is not given, otherwise None.

        Raises:
            ImportError: If the polars library is not installed.
        """

        try:
            import polars as pl
        except ImportError:
            raise ImportError(
                "Parquet output requires the polars library. "
                "Install with 'pip install polars'"
            )

        # Read the table data
        data = list(self.read_table(table))

        # Convert data to Polars DataFrame
        df = pl.DataFrame(data, infer_schema_length=len(data))

        if table in self.validators:
            # Validation columns first, then all other columns in their
            # original order.
            status_cols = ["adtl_valid", "adtl_error"]
            valid_cols = [c for c in status_cols if c in df.columns]
            other_cols = [c for c in df.columns if c not in status_cols]
            df_validated = df.select(valid_cols + other_cols)
        else:
            df_validated = df

        if output:
            df_validated.write_parquet(output)
            return None
        else:
            buf = io.BytesIO()
            df_validated.write_parquet(buf)
            return buf.getvalue()

    def show_report(self):
        "Shows report with validation errors"
        if self.report_available:
            print("\n|table       \t|valid\t|total\t|percentage_valid|")
            print("|---------------|-------|-------|----------------|")
            for table in self.report["total"]:
                valid = self.report["total_valid"][table]
                total = self.report["total"][table]
                print(
                    f"|{table:14s}\t|{valid}\t"
                    f"|{total}\t"
                    f"|{valid/total:%} |"
                )
            print()
            for table in self.report["validation_errors"]:
                print(f"## {table}\n")
                for message, count in self.report["validation_errors"][
                    table
                ].most_common():
                    print(f"* {count}: {message}")
                print()

    def save(self, output: Optional[str] = None, parquet=False):
        """Saves all tables to CSV, or to parquet

        Args:
            output: (optional) Filename prefix that is used for all tables,
                defaults to the name of the specification
            parquet: Write parquet files instead of CSV files, default off
        """
        prefix = self.name if output is None else output

        if parquet:
            for table in self.tables:
                self.write_parquet(table, f"{prefix}-{table}.parquet")

        else:
            for table in self.tables:
                self.write_csv(table, f"{prefix}-{table}.csv")
1024
+
1025
+
1026
def main(argv=None):
    """Command line entry point: transform a file and save or show the report.

    Args:
        argv: Argument list to parse instead of the process arguments.

    Raises:
        ValueError: If parquet output is requested for a specification that
            sets returnUnmatched.
    """
    cli = argparse.ArgumentParser(
        prog="adtl",
        description="Transforms and validates data into CSV given a specification",
    )
    cli.add_argument("spec", help="specification file to use")
    cli.add_argument("file", help="file to read in")
    cli.add_argument(
        "-o", "--output", help="output file, if blank, writes to standard output"
    )
    cli.add_argument("--encoding", help="encoding input file is in", default="utf-8-sig")
    cli.add_argument(
        "--parquet", help="output file is in parquet format", action="store_true"
    )
    cli.add_argument(
        "-q",
        "--quiet",
        help="quiet mode - decrease verbosity, disable progress bar",
        action="store_true",
    )
    cli.add_argument("--save-report", help="save report in JSON format")
    cli.add_argument(
        "--include-def",
        action="append",
        help="include external definition (TOML or JSON)",
    )
    cli.add_argument("--version", action="version", version="%(prog)s " + __version__)
    options = cli.parse_args(argv)

    extra_defs = options.include_def if options.include_def else []
    parser = Parser(options.spec, include_defs=extra_defs, quiet=options.quiet)

    # check for incompatible options
    if options.parquet and parser.header.get("returnUnmatched"):
        raise ValueError("returnUnmatched and parquet options are incompatible")

    # run adtl
    result = parser.parse(options.file, encoding=options.encoding)
    result.save(options.output or parser.name, options.parquet)

    if not options.save_report:
        result.show_report()
        return

    # Record how the report was produced alongside the validation counts.
    result.report.update(
        {
            "encoding": options.encoding,
            "include_defs": extra_defs,
            "file": options.file,
            "parser": options.spec,
        }
    )
    with open(options.save_report, "w") as fp:
        json.dump(result.report, fp, sort_keys=True, indent=2)


if __name__ == "__main__":
    main()