ato 2.0.4__py3-none-any.whl → 2.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ato/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '1.12.0'
1
+ __version__ = '2.1.1'
ato/scope.py CHANGED
@@ -1,8 +1,10 @@
1
1
  import argparse
2
+ import ast
2
3
  import hashlib
3
4
  import inspect
4
5
  import pickle
5
6
  import sys
7
+ import textwrap
6
8
  import uuid
7
9
  import warnings
8
10
  from contextlib import contextmanager
@@ -12,6 +14,7 @@ from ato.adict import ADict
12
14
  from inspect import currentframe, getframeinfo
13
15
 
14
16
  from ato.parser import parse_command
17
+ from ato.trace import Canon
15
18
 
16
19
 
17
20
  # safe compile
@@ -72,6 +75,7 @@ def add_func_to_scope(scope, field=None, priority=0, lazy=False, default=False,
72
75
  def decorator(func):
73
76
  _add_func_to_scope(scope, func, field, priority, lazy, default, chain_with)
74
77
  return func
78
+
75
79
  return decorator
76
80
 
77
81
 
@@ -99,6 +103,7 @@ def add_func_to_multi_scope(scopes, field=None, priority=0, lazy=False, default=
99
103
  for scope in scopes:
100
104
  _add_func_to_scope(scope, func, field, priority, lazy, default, chain_with)
101
105
  return func
106
+
102
107
  return decorator
103
108
 
104
109
 
@@ -177,19 +182,37 @@ def _get_func_trace_id(func):
177
182
  return f'{func.__module__}.{func.__qualname__}'
178
183
 
179
184
 
185
+ def _fine_line_numbers(frame_info, filename, line):
186
+ with open(filename, 'r', encoding='utf-8') as f:
187
+ src = f.read()
188
+ tree = ast.parse(src, filename=filename)
189
+ target_node = None
190
+ target_start_line = None
191
+ target_end_line = None
192
+ for node in ast.walk(tree):
193
+ if isinstance(node, (ast.With, ast.AsyncWith)):
194
+ start_line = getattr(node, 'lineno', None)
195
+ end_line = getattr(node, 'end_lineno', None)
196
+ if start_line is not None and end_line is not None and start_line <= line <= end_line:
197
+ if target_node is None or (end_line-start_line < target_end_line-target_start_line):
198
+ target_node = node
199
+ target_start_line = start_line
200
+ target_end_line = end_line
201
+ if target_node is None:
202
+ target_start_line = getattr(frame_info.positions, 'lineno', line)
203
+ target_end_line = getattr(frame_info.positions, 'end_lineno', line)
204
+ return target_start_line, target_end_line
205
+
206
+
180
207
  def _generate_func_fingerprint(func, trace_id=None):
181
- code_obj = func.__code__
182
- code_info = (
183
- code_obj.co_code,
184
- code_obj.co_consts,
185
- code_obj.co_names,
186
- code_obj.co_varnames,
187
- code_obj.co_filename,
188
- code_obj.co_name,
189
- code_obj.co_firstlineno
190
- )
208
+ src = inspect.getsource(func)
209
+ src = textwrap.dedent(src)
210
+ tree = ast.parse(src)
211
+ tree = Canon().visit(tree)
212
+ ast.fix_missing_locations(tree)
213
+ canonical = ast.dump(tree, annotate_fields=True, include_attributes=False)
191
214
  trace_id = _get_func_trace_id(func) if trace_id is None else trace_id
192
- code_hash = hashlib.sha256(repr(code_info).encode()).hexdigest()
215
+ code_hash = hashlib.sha256(canonical.encode('utf-8')).hexdigest()
193
216
  return ADict(**{trace_id: code_hash})
194
217
 
195
218
 
@@ -240,7 +263,8 @@ class Scope:
240
263
  def trace(self, trace_id=None):
241
264
  def decorator(func):
242
265
  self._traced_data.fingerprints.update(_generate_func_fingerprint(func, trace_id=trace_id))
243
- return self(func)
266
+ return func
267
+
244
268
  return decorator
245
269
 
246
270
  def runtime_trace(self, init_fn=None, inspect_fn=None, trace_id=None):
@@ -249,13 +273,15 @@ class Scope:
249
273
  nonlocal trace_id
250
274
  if init_fn is not None:
251
275
  init_fn()
252
- results = inspect_results = self(func)(*args, **kwargs)
276
+ results = inspect_results = func(*args, **kwargs)
253
277
  if inspect_fn is not None:
254
278
  inspect_results = inspect_fn(results)
255
- trace_id = _get_func_trace_id(func) if trace_id is not None else trace_id
279
+ trace_id = _get_func_trace_id(func) if trace_id is None else trace_id
256
280
  inspect_hash = hashlib.sha256(pickle.dumps(inspect_results)).hexdigest()
257
281
  self._traced_data.fingerprints.update({trace_id: inspect_hash})
282
+
258
283
  return inner
284
+
259
285
  return decorator
260
286
 
261
287
  def register(self):
@@ -420,6 +446,7 @@ class Scope:
420
446
  if self.mode == 'ON':
421
447
  args, kwargs = self.get_config_updated_arguments(func, *args, **kwargs)
422
448
  return func(*args, **kwargs)
449
+
423
450
  return inner
424
451
 
425
452
  def __call__(self, func):
@@ -431,10 +458,12 @@ class Scope:
431
458
  def convert_argparse_to_scope(self):
432
459
  args = self.views['_argparse'].config
433
460
  code = f"def argparse({self.name}):\n"
434
- code += '\n'.join([
435
- f' {self.name}.{key} = '+(f"'{value}'" if isinstance(value, str) else f'{value}')
436
- for key, value in args.items()
437
- ])
461
+ code += '\n'.join(
462
+ [
463
+ f' {self.name}.{key} = '+(f"'{value}'" if isinstance(value, str) else f'{value}')
464
+ for key, value in args.items()
465
+ ]
466
+ )
438
467
  return code
439
468
 
440
469
  @classmethod
@@ -456,12 +485,12 @@ class Scope:
456
485
  frame = currentframe().f_back.f_back
457
486
  frame_info = getframeinfo(frame)
458
487
  file_name = frame_info.filename
459
- start = frame_info.positions.lineno
460
- end = frame_info.positions.end_lineno
461
- with open(file_name, 'r') as f:
462
- inner_ctx_lines = list(f.readlines())[start:end]
488
+ line = frame.f_lineno
489
+ start_line, end_line = _fine_line_numbers(frame_info, file_name, line)
490
+ with open(file_name, 'r', encoding='utf-8') as f:
491
+ inner_ctx_lines = list(f.readlines())[start_line:end_line]
463
492
  ctx_name = f"_lazy_context_{str(uuid.uuid4()).replace('-', '_')}"
464
- inner_ctx_lines = [f'def {ctx_name}({scope.name}):']+inner_ctx_lines
493
+ inner_ctx_lines = [f'def {ctx_name}({scope.name}):\n']+inner_ctx_lines
465
494
  global_vars = frame.f_globals
466
495
  local_vars = frame.f_locals
467
496
  exec(compile('\n'.join(inner_ctx_lines), '<string>', 'exec'), global_vars, local_vars)
@@ -489,4 +518,5 @@ class MultiScope:
489
518
  args, kwargs = scope.get_config_updated_arguments(func, *args, **kwargs)
490
519
  scope.apply()
491
520
  return func(*args, **kwargs)
521
+
492
522
  return decorator
ato/trace.py ADDED
@@ -0,0 +1,27 @@
1
+ import ast
2
+
3
+
4
class Canon(ast.NodeTransformer):
    """Normalize a Python AST so that purely cosmetic differences compare equal.

    Normalizations applied:

    * ``from m import b, a`` -> imported names sorted by ``(name, asname)``.
    * Dict literals whose keys are all constants -> entries sorted by key
      value (stable sort, so duplicate keys keep their relative order).
      Dicts containing ``**`` unpacking, or whose constant keys cannot be
      ordered against each other, are left in source order.
    * Numeric constants are rebuilt as fresh ``ast.Constant`` nodes carrying
      only their value (source location is copied over).
    """

    def visit_Constant(self, node):
        """Rebuild numeric constants from their value; keep all others as-is."""
        if isinstance(node.value, (int, float, complex)):
            return ast.copy_location(ast.Constant(node.value), node)
        return node

    def visit_Attribute(self, node):
        """Recurse into the attribute's children; the node itself is kept."""
        self.generic_visit(node)
        return node

    def visit_ImportFrom(self, node):
        """Sort ``from ... import`` names so their listed order does not matter."""
        if node.names:
            node.names = sorted(node.names, key=lambda alias: (alias.name, alias.asname or ''))
        return node

    def visit_Dict(self, node):
        """Sort constant-keyed dict literal entries by key, when that is safe."""
        self.generic_visit(node)
        keys = node.keys
        # A None key marks `**mapping` unpacking. Reordering would change
        # which entries override which, and None has no `.value` to sort by.
        if any(key is None for key in keys):
            return node
        if not all(isinstance(key, ast.Constant) for key in keys):
            return node
        pairs = list(zip(keys, node.values))
        try:
            # sorted() copies first, so `pairs` is untouched if this raises.
            pairs = sorted(pairs, key=lambda kv: kv[0].value)
        except TypeError:
            # Mutually unorderable key types (e.g. {1: ..., 'a': ...}):
            # keep source order instead of crashing the fingerprint.
            return node
        node.keys = [key for key, _ in pairs]
        node.values = [value for _, value in pairs]
        return node