retracesoftware-proxy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ import retracesoftware.functional as functional
4
+ import retracesoftware_utils as utils
5
+
6
+ # from retracesoftware_functional import mapargs, walker, first_arg, if_then_else, compose, observer, anyargs, memoize_one_arg, side_effect, partial, threadwatcher, firstof, notinstance_test, instance_test,typeof, isinstanceof, always
7
+ # from retracesoftware.proxy.proxy import ProxyFactory, InternalProxy, ProxySpec, WrappingProxySpec, ExtendingProxySpec, ExtendingProxy
8
+ from retracesoftware_proxy import thread_id
9
+ # from retracesoftware_stream import ObjectWriter, ObjectReader
10
+ import retracesoftware.stream as stream
11
+
12
+ # from retracesoftware_utils import visitor
13
+ from retracesoftware.install.tracer import Tracer
14
+ from retracesoftware.proxy.record import RecordProxySystem
15
+ from retracesoftware.proxy.replay import ReplayProxySystem
16
+ from datetime import datetime
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import enum
22
+ import _thread
23
+ import pickle
24
+ import weakref
25
+ import types
26
+ from pathlib import Path
27
+ import glob, os
28
+ import re
29
+
30
+ # from retracesoftware_proxy import *
31
+ # from retracesoftware_utils import *
32
+ # from retracesoftware.proxy import references
33
+ from retracesoftware.install import patcher
34
+ from retracesoftware.install.config import load_config
35
+ # from retracesoftware.proxy.immutabletypes import ImmutableTypes
36
+ # from retracesoftware.proxy import edgecases
37
+ from retracesoftware.install import globals
38
+ # from retracesoftware.proxy.record import RecordProxyFactory
39
+
40
+
41
class DebugWriter:
    """Report call/result/error events to a ``checkpoint`` callable.

    Every event is passed to ``checkpoint`` as a dict with a ``'type'`` key
    (``'call'``, ``'result'`` or ``'error'``).
    """

    # Previously misspelled ``__slot__``, which Python treats as an ordinary
    # class attribute, so instances still carried a ``__dict__``.
    __slots__ = ['checkpoint']

    def __init__(self, checkpoint):
        """Store the callable that receives every event dict."""
        self.checkpoint = checkpoint

    def write_call(self, func, *args, **kwargs):
        """Report a call event.

        The payload carries the literal ``Ellipsis`` under ``'func'``; the
        callable and its arguments are not included.
        NOTE(review): looks like a placeholder - confirm this is intended.
        """
        self.checkpoint({'type': 'call', 'func': ...})

    def write_result(self, res):
        """Report the value returned by a call."""
        self.checkpoint({'type': 'result', 'result': res})

    def write_error(self, *args):
        """Report an error; the positional arguments are stored as a tuple."""
        self.checkpoint({'type': 'error', 'error': tuple(args)})
55
+
56
class MethodDescriptor:
    """Plain record of a method descriptor: its owning class and its name."""

    def __init__(self, descriptor):
        """Capture ``descriptor.__objclass__`` and the descriptor's name.

        Descriptors that expose ``__objclass__`` (e.g. ``str.upper``) publish
        their name as ``__name__``; the original ``descriptor.name`` lookup
        raised AttributeError for them.  ``name`` is still consulted as a
        fallback for objects that only provide that attribute.
        """
        self.cls = descriptor.__objclass__
        name = getattr(descriptor, '__name__', None)
        self.name = name if name is not None else descriptor.name
60
+
61
def once(*args):
    """Compose ``args`` with ``functional.compose`` and wrap the result in
    ``functional.memoize_one_arg``."""
    composed = functional.compose(*args)
    return functional.memoize_one_arg(composed)


# NOTE: this rebinds the module-level name ``any``, hiding the builtin ``any``
# for all code in this module.
any = functional.firstof
64
+
65
+
66
def compose(*args):
    """Compose the non-``None`` callables in ``args``.

    ``None`` entries are dropped.  A single remaining callable is returned
    unchanged; two or more are handed to ``functional.compose``.

    Raises:
        Exception: if no callable remains after dropping ``None`` entries.
    """
    funcs = [fn for fn in args if fn is not None]

    if not funcs:
        raise Exception('TODO')
    if len(funcs) == 1:
        return funcs[0]
    return functional.compose(*funcs)
74
+
75
class SerializedWrappedFunction:
    """Record of where a function lives, captured from its own attributes."""

    def __init__(self, func):
        """Copy identifying attributes from ``func``.

        ``cls`` is set from ``__objclass__`` when present; otherwise ``module``
        is set from ``__module__`` when present.  ``name`` is set from
        ``__name__`` when present.  Attributes ``func`` lacks are not set.
        """
        try:
            self.cls = func.__objclass__
        except AttributeError:
            # No owning class: fall back to the defining module, if any.
            if hasattr(func, '__module__'):
                self.module = func.__module__

        try:
            self.name = func.__name__
        except AttributeError:
            pass
84
+
85
+
86
def replaying_proxy_factory(thread_state, is_immutable_type, tracer, next, bind, checkpoint):
    """Build a ``ProxyFactory`` wired for replay.

    Internal calls run the real handler under tracing.  External calls ignore
    their handler: the result comes from ``next`` and any plain class found in
    it is replaced by a stub object.

    ``checkpoint`` is accepted but not used in this function.

    NOTE(review): ``ProxyFactory`` is not imported in the visible part of this
    module (its import is commented out), so calling this raises NameError
    unless the name is provided elsewhere.
    """

    def wrap_int_call(handler):
        tracers = dict(
            on_call = tracer('proxy.int.call'),
            on_result = tracer('proxy.int.result'),
            on_error = tracer('proxy.int.error'))
        return functional.observer(function = handler, **tracers)

    def make_stub(cls):
        # The prints are debugging output kept from the original code.
        print(f"IN FOO {cls}")
        stub = utils.create_stub_object(cls)
        print(f'FOO {cls} {type(stub)}')
        return stub

    def is_stub_type(obj):
        # Only objects whose exact type is `type` qualify.
        return type(obj) == type

    create_stubs = functional.walker(functional.when(is_stub_type, make_stub))

    def wrap_ext_call(handler):
        # `handler` is intentionally unused here: results are taken from `next`.
        tracers = dict(
            on_call = tracer('proxy.ext.call'),
            on_result = tracer('proxy.ext.result'),
            on_error = tracer('proxy.ext.error'))
        replayed = functional.compose(functional.always(next), create_stubs)
        return functional.observer(function = replayed, **tracers)

    return ProxyFactory(
        thread_state = thread_state,
        is_immutable_type = is_immutable_type,
        tracer = tracer,
        on_new_int_proxy = bind,
        wrap_int_call = wrap_int_call,
        wrap_ext_call = wrap_ext_call)
133
+
134
+ def latest_from_pattern(pattern: str) -> str | None:
135
+ """
136
+ Given a strftime-style filename pattern (e.g. "recordings/%Y%m%d_%H%M%S_%f"),
137
+ return the path to the most recent matching file, or None if no files exist.
138
+ """
139
+ # Turn strftime placeholders into '*' for globbing
140
+ # (very simple replacement: %... -> *)
141
+ glob_pattern = re.sub(r"%[a-zA-Z]", "*", pattern)
142
+
143
+ # Find all matching files
144
+ candidates = glob.glob(glob_pattern)
145
+ if not candidates:
146
+ return None
147
+
148
+ # Derive the datetime format from the pattern (basename only)
149
+ base_pattern = os.path.basename(pattern)
150
+
151
+ def parse_time(path: str):
152
+ name = os.path.basename(path)
153
+ return datetime.strptime(name, base_pattern)
154
+
155
+ # Find the latest by parsed timestamp
156
+ latest = max(candidates, key=parse_time)
157
+ return latest
158
+
159
+ # class Reader:
160
+
161
+ # def __init__(self, objectreader):
162
+ # self.objectreader = objectreader
163
+
164
+ # self.objectreader()
165
+
166
def tracing_config(config):
    """Return the tracing settings for the active tracing level.

    The level is the ``RETRACE_DEBUG`` environment variable when set,
    otherwise ``config['default_tracing_level']``.  An unknown level yields an
    empty dict.
    """
    default_level = config['default_tracing_level']
    active_level = os.environ.get('RETRACE_DEBUG', default_level)
    levels = config['tracing_levels']
    return levels.get(active_level, {})
169
+
170
+
171
def install(create_system):
    """Load ``config.json`` and delegate to ``patcher.install`` with it and
    the given ``create_system`` callable."""
    config = load_config('config.json')
    return patcher.install(config = config, create_system = create_system)
@@ -0,0 +1,494 @@
1
+ import inspect
2
+ # from retrace_utils import _intercept
3
+ # import _intercept
4
+ import types
5
+ import re
6
+ import sys
7
+ import pdb
8
+ import builtins
9
+ import types
10
+ import functools
11
+ import traceback
12
+ import functools
13
+ import importlib
14
+ import gc
15
+ import os
16
+ import atexit
17
+ import threading
18
+
19
+ from retracesoftware.install.typeutils import modify
20
+ # from retracesoftware.proxy.immutabletypes import ImmutableTypes
21
+ from retracesoftware.install.record import record_system
22
+ from retracesoftware.install.replay import replay_system
23
+ from retracesoftware.install.config import load_config
24
+
25
+ from functools import wraps
26
+
27
+ import retracesoftware_functional as functional
28
+ import retracesoftware_utils as utils
29
+
30
+ from retracesoftware.install import edgecases
31
+ from functools import partial
32
+ # from retrace_utils.intercept import proxy
33
+
34
+ # from retrace_utils.intercept.typeutils import *
35
+ # from retrace_utils.intercept.proxytype import DynamicProxyFactory, proxytype
36
+
37
+ from retracesoftware.install.predicate import PredicateBuilder
38
+
39
def find_attr(mro, name):
    """Return the raw ``__dict__`` entry for ``name`` from the first class in
    ``mro`` that defines it, or ``None`` when no class does."""
    owner = next((cls for cls in mro if name in cls.__dict__), None)
    if owner is None:
        return None
    return owner.__dict__[name]
43
+
44
def is_descriptor(obj):
    """Tell whether ``obj`` exposes any of ``__get__``, ``__set__`` or
    ``__delete__``."""
    hooks = ('__get__', '__set__', '__delete__')
    return any(hasattr(obj, hook) for hook in hooks)
46
+
47
+ # default_exclude = ['__class__', '__getattribute__', '__init_subclass__', '__dict__', '__del__', '__new__']
48
+
49
def is_function_type(cls):
    """Tell whether the class ``cls`` is (a subclass of) the builtin-function
    type or the Python function type."""
    function_types = (types.BuiltinFunctionType, types.FunctionType)
    return issubclass(cls, function_types)
51
+
52
def select_keys(keys, dict):
    """Return a new dict holding only those entries of ``dict`` whose key is
    listed in ``keys``; keys missing from ``dict`` are ignored."""
    selected = {}
    for key in keys:
        if key in dict:
            selected[key] = dict[key]
    return selected
54
+
55
def map_values(f, dict):
    """Return a new dict with the same keys as ``dict`` and ``f`` applied to
    every value."""
    mapped = {}
    for key, value in dict.items():
        mapped[key] = f(value)
    return mapped
57
+
58
def common_keys(dict, *dicts):
    """Return the keys present in ``dict`` and in every mapping of ``dicts``,
    as a ``utils.set``."""
    shared = utils.set(dict)

    for other in dicts:
        shared &= other.keys()

    # The in-place intersection must not have turned the result into another
    # set type.
    assert isinstance(shared, utils.set)

    return shared
66
+
67
def intersection(*dicts):
    """Map every key shared by all ``dicts`` to the tuple of its values, one
    per dict, in argument order."""
    combined = {}
    for key in common_keys(*dicts):
        combined[key] = tuple(d[key] for d in dicts)
    return combined
69
+
70
def intersection_apply(f, *dicts):
    """For every key shared by all ``dicts``, call ``f`` with that key's
    values (one positional argument per dict) and map the key to the result."""
    def spread(vals):
        return f(*vals)

    return map_values(spread, intersection(*dicts))
72
+
73
def resolve(path):
    """Import and return the object named by a dotted ``path``.

    Everything before the last dot is the module to import and the remainder
    is the attribute fetched from it.  A path without a dot is looked up in
    ``builtins``.

    ``str.rpartition`` yields an empty string (never ``None``) when there is
    no separator, so the original ``module == None`` test never triggered and
    ``resolve('len')`` failed with "Empty module name".

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    module, _, name = path.rpartition('.')
    if not module:
        module = 'builtins'

    return getattr(importlib.import_module(module), name)
78
+
79
+ # def sync(spec, module_dict):
80
+ # for name,properties in spec.items():
81
+
82
+ # cls = module_dict[name]
83
+
84
+ # orig_init = cls.__init__
85
+
86
+ # def __init__(inst, *args, **kwargs):
87
+ # orig_init(inst, *args, **kwargs)
88
+ # for prop in properties:
89
+ # self.system.add_sync(getattr(inst, prop))
90
+
91
def container_replace(container, old, new):
    """Replace references to ``old`` with ``new`` inside ``container``, in place.

    * dict: a key equal to ``old`` is re-inserted under ``new``, and every
      value that *is* ``old`` is replaced.
    * list: every item that *is* ``old`` is replaced.
    * set: ``old`` is swapped for ``new`` when it is a member.
    * anything else is left untouched.

    Fixes over the previous version: all matching dict values / list items are
    replaced (not only the first one), an unhashable ``old`` no longer raises
    TypeError, a set that does not contain ``old`` no longer raises KeyError,
    and the return value is always a bool.

    Returns:
        ``True`` if at least one reference was replaced, ``False`` otherwise.
    """
    def contains(collection):
        # Hash-based membership; unhashable objects can never be members.
        try:
            return old in collection
        except TypeError:
            return False

    if isinstance(container, dict):
        replaced = False

        if contains(container):
            container[new] = container.pop(old)
            replaced = True

        # Assigning to existing keys does not resize the dict, so this is
        # safe while iterating.
        for key, value in container.items():
            if value is old:
                container[key] = new
                replaced = True

        return replaced

    if isinstance(container, list):
        replaced = False
        for i, value in enumerate(container):
            if value is old:
                container[i] = new
                replaced = True
        return replaced

    if isinstance(container, set):
        if not contains(container):
            return False
        container.remove(old)
        container.add(new)
        return True

    return False
113
+
114
def phase(func):
    """Decorator: mark ``func`` as a phase by setting ``func.is_phase = True``
    and return it unchanged otherwise."""
    setattr(func, 'is_phase', True)
    return func
117
+
118
def patch(func):
    """Decorator turning a per-value method into a phase over a module dict.

    The wrapped method is called as ``wrapper(self, spec, mod_dict)``:

    * ``spec`` is a str: treated as a one-element list;
    * ``spec`` is a list of names: ``func(self, mod_dict[name])`` per name;
    * ``spec`` is a dict: ``func(self, mod_dict[name], value)`` per item.

    Names missing from ``mod_dict`` are skipped.  The result maps each handled
    name to what ``func`` returned.  The wrapper is marked ``is_phase = True``.

    Raises:
        Exception: if ``spec`` is none of the supported types.
    """
    @wraps(func)
    def wrapper(self, spec, mod_dict):
        if isinstance(spec, str):
            spec = [spec]

        if isinstance(spec, list):
            patched = {}
            for name in spec:
                if name in mod_dict:
                    patched[name] = func(self, mod_dict[name])
            return patched

        if isinstance(spec, dict):
            patched = {}
            for name, value in spec.items():
                if name in mod_dict:
                    patched[name] = func(self, mod_dict[name], value)
            return patched

        raise Exception('TODO')

    wrapper.is_phase = True
    return wrapper
132
+
133
def superdict(cls):
    """Merge the ``__dict__`` of every class in ``cls.__mro__`` except the
    last entry, applying them from the far end of the MRO towards ``cls`` so
    that entries nearer to ``cls`` win."""
    merged = {}
    # Start at the second-to-last MRO entry and walk back to `cls` itself.
    for klass in cls.__mro__[-2::-1]:
        merged.update(klass.__dict__)

    return merged
139
+
140
def is_method_descriptor(obj):
    """Tell whether ``obj`` is a Python function, or a wrapper/method
    descriptor whose ``__objclass__`` is not ``object``."""
    if isinstance(obj, types.FunctionType):
        return True

    descriptor_types = (types.WrapperDescriptorType, types.MethodDescriptorType)
    return isinstance(obj, descriptor_types) and obj.__objclass__ != object
143
+
144
def wrap_method_descriptors(wrapper, prefix, base):
    """Create a slot-less subclass of ``base`` whose methods are wrapped.

    The subclass is named ``{prefix}.{base.__module__}.{base.__name__}``.
    Every entry of ``superdict(base)`` accepted by ``is_method_descriptor`` is
    replaced on the subclass by ``wrapper(entry)``, except
    ``__getattribute__``, ``__hash__`` and ``__del__``.
    """
    namespace = {"__slots__": ()}
    extended = type(f'{prefix}.{base.__module__}.{base.__name__}', (base,), namespace)

    skipped = ['__getattribute__', '__hash__', '__del__']

    for name, value in superdict(base).items():
        if name in skipped or not is_method_descriptor(value):
            continue
        setattr(extended, name, wrapper(value))

    return extended
157
+
158
class Patcher:
    """Apply the configured patch phases to modules and swap in the results.

    Every method marked with ``@phase`` or ``@patch`` is a *phase*.  A phase
    is called with the spec stored under its own name in a module's config
    entry plus the module dict, and may return a ``{name: replacement}``
    mapping.  Phases run in the order the methods are defined on this class
    (see ``self.phases``), so the method order below is significant.
    """

    def __init__(self, thread_state, config, system,
                 immutable_types,
                 on_function_proxy = None,
                 debug_level = 0,
                 post_commit = None):
        """Store collaborators and collect the phase methods.

        Args:
            thread_state: object providing ``select``, ``dispatch``, ``wrap``
                and ``value``, as used by the methods below.
            config: mapping read for ``'modules'``, ``'type_attribute_filter'``
                and the optional ``'exclude_paths'`` (regex strings).
            system: callable used to proxy values; also provides ``sync`` and
                ``tracer.log``.
            immutable_types: mutable set of types that are never proxied.
            on_function_proxy: stored as-is; not used inside this class.
            debug_level: stored as-is; not used inside this class.
            post_commit: optional ``callable(mod_name, updates)`` invoked
                after a module has been patched.
        """
        utils.set_thread_id(0)
        self.thread_counter = system.sync(utils.counter(1))

        self.thread_state = thread_state
        self.debug_level = debug_level
        self.on_function_proxy = on_function_proxy
        self.modules = config['modules']
        self.immutable_types_set = immutable_types
        self.predicate = PredicateBuilder()
        self.system = system
        self.type_attribute_filter = self.predicate(config['type_attribute_filter'])
        self.post_commit = post_commit
        self.exclude_paths = [re.compile(s) for s in config.get('exclude_paths', [])]

        def is_phase(name): return getattr(getattr(self, name, None), "is_phase", False)

        # Class-dict order == definition order == phase execution order.
        self.phases = [(name, getattr(self, name)) for name in Patcher.__dict__.keys() if is_phase(name)]

    def log(self, *args):
        """Forward ``args`` to ``self.system.tracer.log``."""
        self.system.tracer.log(*args)

    def path_predicate(self, path):
        """Return ``False`` if ``str(path)`` matches any exclude pattern
        (anchored at the start, via ``re.match``), ``True`` otherwise."""
        for exclude in self.exclude_paths:
            if exclude.match(str(path)) is not None:
                return False
        return True

    @property
    def disable(self):
        """Result of ``thread_state.select('disabled')``; used as a context
        manager by ``patch_module_with_name``."""
        return self.thread_state.select('disabled')

    def proxyable(self, name, obj):
        """Decide whether module attribute ``name`` with value ``obj`` may be
        proxied.

        Dunder names and str/int/dict/list/tuple values are never proxied.
        Types are proxyable unless they are exceptions or registered as
        immutable; other objects are proxyable unless their type is
        registered as immutable.
        """
        if name.startswith('__') and name.endswith('__'):
            return False

        if isinstance(obj, (str, int, dict, list, tuple)):
            return False

        if isinstance(obj, type):
            return not issubclass(obj, BaseException) and obj not in self.immutable_types_set
        else:
            return type(obj) not in self.immutable_types_set

    @phase
    def immutable_types(self, spec, mod_dict):
        """Register the named module attributes as immutable types.

        ``spec`` is a name or an iterable of names; names missing from
        ``mod_dict`` are ignored.

        Raises:
            Exception: if a named attribute exists but is not a type.
        """
        if isinstance(spec, str):
            spec = [spec]

        for name in spec:
            if name in mod_dict:
                if isinstance(mod_dict[name], type):
                    self.immutable_types_set.add(mod_dict[name])
                else:
                    raise Exception(f'Tried to add "{name}" - {mod_dict[name]} which isn\'t a type to immutable')

    @patch
    def proxy(self, value):
        """Replace the value by ``self.system(value)``."""
        return self.system(value)

    @phase
    def proxy_all_except(self, spec, mod_dict):
        """Proxy every proxyable module attribute whose name is not listed in
        ``spec``."""
        all_except = set(spec)

        def proxyable(name, value):
            return name not in all_except and self.proxyable(name, value)

        return {key: self.system(value) for key,value in mod_dict.items() if proxyable(key, value)}

    @phase
    def proxy_type_attributes(self, spec, mod_dict):
        """Proxy selected attributes directly on classes.

        ``spec`` maps class names to attribute names.  Each attribute is
        looked up along the class MRO and, when it is callable or a
        descriptor, replaced on the class by ``self.system(attr)``.

        Raises:
            Exception: if a named module attribute is not a type.
        """
        for classname, attributes in spec.items():
            if classname in mod_dict:
                cls = mod_dict[classname]
                if isinstance(cls, type):
                    for name in attributes:
                        attr = find_attr(cls.__mro__, name)
                        if attr is not None and (callable(attr) or is_descriptor(attr)):
                            proxied = self.system(attr)

                            with modify(cls):
                                setattr(cls, name, proxied)
                else:
                    # Report via the config name: a non-type object need not
                    # have __module__/__name__, and reading them here used to
                    # mask this error with an AttributeError.
                    raise Exception(f"Cannot patch attributes for {classname} as object is: {cls} and not a type")

    @phase
    def replace(self, spec, mod_dict):
        """Map every name in ``spec`` to the object its dotted path resolves
        to (``mod_dict`` is not consulted)."""
        return {key: resolve(value) for key,value in spec.items()}

    @patch
    def patch_start_new_thread(self, value):
        """Wrap a thread-start function so threads started while the thread
        state is ``internal`` get a synchronized thread id and run their
        target under the ``internal`` state."""

        # This used to be declared as ``start_new_thread(self, function, *args)``.
        # It is invoked in place of ``value`` (i.e. as ``(function, *args)``),
        # so the extra ``self`` swallowed the target function and shadowed the
        # Patcher instance.  The call ``value(threadrunner, *args)`` below
        # mirrors the intended signature.
        def start_new_thread(function, *args):
            # synchronized, replay should yield correct number
            thread_id = self.thread_counter()

            def threadrunner(*args, **kwargs):
                utils.set_thread_id(thread_id)

                with self.thread_state.select('internal'):
                    return function(*args, **kwargs)

            return value(threadrunner, *args)

        return self.thread_state.dispatch(value, internal = start_new_thread)

    @phase
    def wrappers(self, spec, mod_dict):
        """For every name present in both ``spec`` and ``mod_dict``, resolve
        the wrapper path from ``spec`` and apply it to the module value."""
        return intersection_apply(lambda path, value: resolve(path)(value), spec, mod_dict)

    @patch
    def patch_exec(self, exec):
        """Wrap an ``exec``-like function so that a module whose code was just
        executed gets patched by this Patcher afterwards."""

        def is_module_exec(source):
            return isinstance(source, types.CodeType) and \
                source.co_name == '<module>' and inspect.getmodule(source)

        def exec_wrapper(source, *args,**kwargs):
            # NOTE(review): 'boostrap' looks like a typo for 'bootstrap', but
            # state names come from config['states'], so it is left unchanged
            # - confirm against the config.
            if self.thread_state.value != 'boostrap' and is_module_exec(source):
                with self.thread_state.select('disabled'):
                    module = inspect.getmodule(source)
                    with self.thread_state.select('internal'):
                        result = exec(source, *args, **kwargs)

                    self(module)

                    return result
            else:
                return exec(source, *args,**kwargs)

        return exec_wrapper

    @patch
    def sync_types(self, value):
        """Replace a class by a subclass whose methods are wrapped with
        ``self.system.sync``."""
        return wrap_method_descriptors(self.system.sync, "retrace", value)

    @phase
    def with_state(self, spec, mod_dict):
        """Wrap module attributes so they run under a given thread state.

        ``spec`` maps a state name to the attribute names to wrap with
        ``thread_state.wrap(desired_state = state, ...)``.
        """
        updates = {}

        for state,elems in spec.items():

            # `wrap` is used immediately within this iteration, so reading
            # the loop variable `state` from the closure is safe.
            def wrap(obj): return self.thread_state.wrap(desired_state = state, function = obj)

            updates.update(map_values(wrap, select_keys(elems, mod_dict)))

        return updates

    @patch
    def patch_extension_exec(self, exec):
        """Wrap a one-argument module-exec function so the module is patched
        right after it has been executed."""

        def wrapper(module):
            res = exec(module)
            self(module)
            return res

        return wrapper

    @patch
    def path_predicates(self, func, param):
        """Wrap ``func`` so that calls whose ``param`` argument is an excluded
        path run under the ``disabled`` thread state.

        Raises:
            ValueError: if ``func`` has no parameter named ``param``.
        """
        signature = inspect.signature(func).parameters

        try:
            index = list(signature.keys()).index(param)
        except ValueError:
            print(f'parameter {param} not in: {signature.keys()} {type(func)} {func}')
            raise

        param = functional.param(name = param, index = index)

        assert callable(param)

        return functional.if_then_else(
            test = functional.compose(param, self.path_predicate),
            then = func,
            otherwise = self.thread_state.wrap('disabled', func))

    @phase
    def wrap(self, spec, mod_dict):
        """Apply wrappers named by dotted path to module attributes.

        ``spec`` maps ``"name"`` or ``"name.member"`` to a wrapper path.  A
        top-level name is returned as an update; a member is replaced in place
        on the object stored under ``name``.

        Raises:
            Exception: if a path has more than two components.
        """
        updates = {}

        for path, wrapper_name in spec.items():

            parts = path.split('.')
            name = parts[0]
            if name in mod_dict:
                value = mod_dict[name]

                if len(parts) == 1:
                    updates[name] = resolve(wrapper_name)(value)
                elif len(parts) == 2:
                    member = getattr(value, parts[1], None)
                    # Test for absence, not truthiness: a falsy member (e.g.
                    # an object defining __bool__/__len__) must still be
                    # wrapped.
                    if member is not None:
                        setattr(value, parts[1], resolve(wrapper_name)(member))
                else:
                    raise Exception(f'Unsupported wrap path "{path}": at most one member level is supported')

        return updates

    def updates(self, spec, mod_dict):
        """Run every phase named in ``spec`` and return the merged updates.

        Later phases see the module dict overlaid with the updates produced
        by earlier phases.
        """
        updates = {}

        for phase,func in self.phases:
            if phase in spec:
                self.log('install.module.phase', phase)

                phase_updates = func(spec[phase], mod_dict | updates)

                if phase_updates:
                    self.log('install.module.phase.results', list(phase_updates.keys()))
                    updates |= phase_updates

        return updates

    def configs(self, module):
        """Yield every configured module name under which ``module`` is
        registered in ``sys.modules``."""
        # Iterate over a snapshot: sys.modules can change (imports on this or
        # another thread) while the generator is suspended.
        for name,value in list(sys.modules.items()):
            if value is module and name in self.modules:
                yield name

    def patch_module_with_name(self, mod_name, module):
        """Patch ``module`` using the config entry ``mod_name``.

        Computes the updates, installs them into the module dict, rewrites
        other containers that still reference the replaced originals, records
        the originals on ``module.__retrace__`` and finally calls
        ``post_commit`` if one was given.
        """
        with self.disable:
            self.log('install.module', mod_name)

            spec = self.modules.get(mod_name)

            updates = self.updates(spec = spec, mod_dict = module.__dict__)

            originals = select_keys(updates.keys(), module.__dict__)

            module.__dict__.update(updates)

            # Chase down other dicts/lists/sets still holding an original
            # (e.g. "from module import name" copies) and swap in the update.
            for name, value in originals.items():
                for ref in gc.get_referrers(value):
                    if ref is not originals:
                        container_replace(container = ref, old = value, new = updates[name])

            # Also serves as the "already patched" marker checked in __call__.
            module.__retrace__ = originals

            if self.post_commit:
                self.post_commit(mod_name, updates)

    def __call__(self, module):
        """Patch ``module`` if it is configured and not patched yet.

        Returns:
            ``module`` itself, patched or not.

        Raises:
            Exception: if the module matches more than one config entry, or
                if patching fails (the original error is chained).
        """
        if not hasattr(module, '__retrace__'):
            configs = list(self.configs(module))

            if len(configs) > 0:
                if len(configs) > 1:
                    raise Exception(f'Module matches multiple config entries, which is unsupported: {configs}')
                else:
                    try:
                        self.patch_module_with_name(configs[0], module)
                    except Exception as error:
                        raise Exception(f'Error patching module: {configs[0]}') from error

        return module
430
+
431
def env_truthy(key, default=False):
    """Interpret environment variable ``key`` as a boolean.

    Returns ``default`` when the variable is unset.  Otherwise the value is
    true iff, stripped and lower-cased, it is one of ``1``, ``true``, ``yes``
    or ``on``.
    """
    raw = os.getenv(key)
    if raw is not None:
        return raw.strip().lower() in ("1", "true", "yes", "on")
    return default
436
+
437
def install(mode):
    """Create the record/replay system, patch all loaded modules and return it.

    Steps: load ``config.json``, apply the ``RETRACE_RECORDING_PATH`` and
    ``RETRACE_VERBOSE`` environment overrides, build the thread state and the
    system, patch every module currently in ``sys.modules``, import the
    modules listed under ``config['preload']``, then leave the thread state at
    ``'internal'`` and attach the system to the current thread as
    ``__retrace__``.

    Args:
        mode: ``'record'`` or ``'replay'``.

    Returns:
        The system object produced by ``record_system`` / ``replay_system``.

    Raises:
        Exception: if ``mode`` is not supported.
    """
    if mode == 'record':
        create_system = record_system
    elif mode == 'replay':
        create_system = replay_system
    else:
        raise Exception(f'mode: {mode} unsupported')

    config = load_config('config.json')

    # Only plain string entries of config['states'] are used as state names.
    states = [x for x in config['states'] if isinstance(x, str)]

    thread_state = utils.ThreadState(*states)

    # Shared between the system and the patcher; filled by the
    # `immutable_types` phase while modules are patched.
    immutable_types = set()

    if 'RETRACE_RECORDING_PATH' in os.environ:
        config['recording_path'] = os.environ['RETRACE_RECORDING_PATH']

    config['verbose'] = env_truthy('RETRACE_VERBOSE')

    system = create_system(thread_state = thread_state,
                           immutable_types = immutable_types,
                           config = config)

    patcher = Patcher(thread_state = thread_state,
                      config = config,
                      system = system,
                      post_commit = getattr(system, 'on_patched', None),
                      immutable_types = immutable_types)

    def at_exit(): thread_state.value = 'disabled'

    with thread_state.select('internal'):

        # Register the hook itself.  The previous ``lambda: at_exit`` only
        # returned the function without calling it, so the state was never
        # switched to 'disabled' at interpreter exit.
        atexit.register(at_exit)

        # Snapshot sys.modules: patching can import further modules (wrapper
        # paths are resolved via importlib), and growing the dict while
        # iterating over it raises RuntimeError.
        for module in list(sys.modules.values()):
            patcher(module)

        for library in config.get('preload', []):
            importlib.import_module(library)

    thread_state.value = 'internal'

    threading.current_thread().__retrace__ = system

    return system