codesuture 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,607 @@
1
+ """
2
+ Synthesises guard + original bytecode for all deterministic strategies.
3
+ """
4
+ from bytecode import Bytecode, Instr, Label, Compare
5
+ from codesuture.pattern_matcher import PatchSpec
6
+
7
class PatchValidationError(Exception):
    """Raised by ``validate_patch`` when patched bytecode reads a local slot that
    the original code object never declared; the patch is not applied."""
9
+
10
class PatchRejectedError(Exception):
    """Raised by ``synthesize_guarded_code`` when the semantic diff of the
    synthesised bytecode against the original marks the patch as rejected."""
12
+
13
def validate_patch(original_code, patched_code):
    """Reject patched bytecode that reads a local the original never declared.

    Every ``LOAD_FAST`` in *patched_code* must name either an entry of
    ``original_code.co_varnames`` or one of the scratch locals the guard
    builders in this module introduce (``_codesuture_cont``,
    ``_codesuture_key``, ``_lp_chain``).

    Raises:
        PatchValidationError: for the first ``LOAD_FAST`` naming any other slot.
    """
    import dis

    scratch_locals = frozenset(('_codesuture_cont', '_codesuture_key', '_lp_chain'))
    permitted = scratch_locals.union(original_code.co_varnames)
    for ins in dis.get_instructions(patched_code):
        if ins.opname != 'LOAD_FAST' or ins.argval in permitted:
            continue
        name = ins.argval
        raise PatchValidationError(f"Patch rejected: LOAD_FAST '{name}' not in co_varnames — bytecode would corrupt frame. Patch was not applied.")
25
+
26
def propagate_patch(original_func, patched_code) -> int:
    """Install *patched_code* on *original_func* and on other live holders of its code.

    The garbage collector is asked for every object referring to the
    function's current code object. Two kinds of holder are rebound to
    *patched_code*:

    * objects with a ``__func__`` whose function uses that code object;
    * objects whose own ``__code__`` is that code object.

    *original_func* itself is updated last and is not counted.

    Returns:
        The number of additional references that were rebound.
    """
    import gc

    old_code = original_func.__code__
    rebound = 0

    for holder in gc.get_referrers(old_code):
        if holder is original_func:
            continue
        if hasattr(holder, '__func__') and hasattr(holder.__func__, '__code__'):
            target = holder.__func__
        elif hasattr(holder, '__code__'):
            target = holder
        else:
            continue
        if target.__code__ is old_code:
            target.__code__ = patched_code
            rebound += 1

    original_func.__code__ = patched_code

    if rebound:
        print(f"[CodeSuture] Propagated patch to {rebound} additional "
              f"live reference(s) of {original_func.__qualname__}.")
    return rebound
50
+
51
def synthesize_guarded_code(original_code, spec: PatchSpec) -> Bytecode:
    """Build, validate and return guarded bytecode for *original_code*.

    ``spec.strategy`` selects the builder. The other ``spec`` fields consulted
    (``var_name``, ``key_name``, ``list_len_var``, ``default_value``) depend
    on the strategy.

    For ``autonomous_rule``, ``spec.default_value`` is Python source text. It
    is compiled (not executed), and the code constant whose ``co_name``
    equals ``original_code.co_name`` becomes the replacement.

    When ``spec.is_async`` is truthy, the result is passed through
    ``_ensure_resume_first``. The assembled code object is then checked by
    ``validate_patch`` and by ``codesuture.diff_guard.semantic_diff``.

    Returns:
        The patched ``Bytecode`` (the caller converts it to a code object).

    Raises:
        ValueError: for an unknown strategy, or when an autonomous rule does
            not define a function named like the original.
        PatchValidationError: propagated from ``validate_patch``.
        PatchRejectedError: when the semantic diff rejects the patch.
    """
    import types

    strategy = spec.strategy
    if strategy in ('subscript_guard', 'key_guard', 'dict_get_guard'):
        res = _build_subscript_guarded_code(original_code, spec.var_name, spec.key_name, spec.default_value)
    elif strategy == 'chain_subscript_guard':
        res = _build_chain_subscript_guarded_code(original_code, spec.var_name, spec.key_name, spec.default_value)
    elif strategy == 'division_guard':
        res = _build_division_guarded_code(original_code, spec.var_name, spec.default_value)
    elif strategy == 'null_guard':
        # key_name carries an attribute chain for attribute-access crashes.
        if spec.key_name is not None:
            res = _build_attr_null_guarded_code(original_code, spec.var_name, spec.key_name, spec.default_value)
        else:
            res = _build_null_guarded_code(original_code, spec.var_name, spec.default_value)
    elif strategy in ('index_guard', 'list_bound_guard'):
        res = _build_index_guarded_code(original_code, spec.var_name, spec.list_len_var, spec.default_value)
    elif strategy == 'file_guard':
        res = _build_file_guarded_code(original_code, spec.var_name, spec.default_value)
    elif strategy == 'str_coerce_guard':
        res = _build_str_coerce_guarded_code(original_code, spec.var_name)
    elif strategy == 'callable_guard':
        res = _build_callable_guarded_code(original_code, spec.var_name, spec.default_value)
    elif strategy == 'type_coercion_guard':
        res = _build_type_coercion_guarded_code(original_code, spec.var_name, spec.default_value)
    elif strategy == 'return_guard':
        res = _build_return_guarded_code(original_code, spec.default_value)
    elif strategy == 'autonomous_rule':
        # NOTE(review): default_value is source text supplied by the rule
        # engine; it is only compiled here, never executed.
        new_module_code = compile(spec.default_value, "<autonomous>", "exec")
        for const in new_module_code.co_consts:
            if isinstance(const, types.CodeType) and const.co_name == original_code.co_name:
                res = Bytecode.from_code(const)
                break
        else:
            raise ValueError("Could not find replacement function code in autonomous rule.")
    else:
        raise ValueError(f"Unknown strategy: {spec.strategy}")

    if getattr(spec, 'is_async', False):
        _ensure_resume_first(res)

    patched_code = res.to_code()
    validate_patch(original_code, patched_code)

    from codesuture.diff_guard import semantic_diff
    diff = semantic_diff(original_code, patched_code, spec.strategy)
    if diff.rejected:
        print(f"[CodeSuture] {diff.reason}")
        raise PatchRejectedError(diff.reason)

    return res
101
+
102
def _ensure_resume_first(bc: Bytecode):
    """Make a ``RESUME 0`` instruction the very first entry of *bc*, in place.

    There are three cases:

    * *bc* has no ``RESUME 0``: one is prepended.
    * The first ``RESUME 0`` sits later in the stream: it is moved to the front.
    * It is already first: *bc* is left untouched.
    """
    # NOTE(review): CPython 3.11 normally emits RETURN_GENERATOR/POP_TOP
    # (generators, coroutines) and MAKE_CELL/COPY_FREE_VARS ahead of RESUME.
    # Moving RESUME in front of those changes that prologue; confirm intended.
    snapshot = list(bc)
    position = next(
        (i for i, item in enumerate(snapshot)
         if isinstance(item, Instr) and item.name == 'RESUME' and item.arg == 0),
        None,
    )
    if position is None:
        bc.insert(0, Instr('RESUME', 0))
    elif position > 0:
        reordered = [snapshot[position], *snapshot[:position], *snapshot[position + 1:]]
        bc.clear()
        bc.extend(reordered)
125
+
126
def _build_null_guarded_code(original_code, var_name, default):
    """Return bytecode for *original_code* in which local *var_name* is not left ``None``.

    Two rewrites are tried in order:

    1. If the code contains ``var_name = None`` (``LOAD_CONST None`` directly
       followed by ``STORE_FAST var_name``), the first such constant becomes
       *default* and the result is returned immediately.
    2. Otherwise the guard ``if var_name is None: var_name = default`` is
       inserted. Its position is:

       * right after the last ``STORE_FAST var_name`` preceding the first
         ``LOAD_FAST var_name`` + ``LOAD_ATTR``/``LOAD_METHOD`` pair;
       * or after the last such store anywhere, when no pair exists;
       * with no store at all, right after ``RESUME`` (index 0 if absent).
    """
    bc = Bytecode.from_code(original_code)
    snapshot = list(bc)
    adjacent = list(enumerate(zip(snapshot, snapshot[1:])))

    # Rewrite 1: turn an explicit ``var_name = None`` into ``var_name = default``.
    for pos, (cur, nxt) in adjacent:
        if (isinstance(cur, Instr) and isinstance(nxt, Instr)
                and cur.name == 'LOAD_CONST' and cur.arg is None
                and nxt.name == 'STORE_FAST' and nxt.arg == var_name):
            bc[pos] = Instr('LOAD_CONST', default, lineno=cur.lineno)
            return bc

    # Rewrite 2: locate the first dereference of var_name ...
    crash_pos = next(
        (pos for pos, (cur, nxt) in adjacent
         if isinstance(cur, Instr) and cur.name == 'LOAD_FAST' and cur.arg == var_name
         and isinstance(nxt, Instr) and nxt.name in ('LOAD_ATTR', 'LOAD_METHOD')),
        None,
    )
    # ... and the last assignment to it before that point.
    limit = len(snapshot) if crash_pos is None else crash_pos
    store_pos = next(
        (pos for pos in range(limit - 1, -1, -1)
         if isinstance(snapshot[pos], Instr)
         and snapshot[pos].name == 'STORE_FAST'
         and snapshot[pos].arg == var_name),
        None,
    )

    not_none = Label()
    guard = [
        Instr('LOAD_FAST', var_name),
        Instr('LOAD_CONST', None),
        Instr('IS_OP', 0),
        Instr('POP_JUMP_FORWARD_IF_FALSE', not_none),
        Instr('LOAD_CONST', default),
        Instr('STORE_FAST', var_name),
        not_none,
    ]

    if store_pos is not None:
        insert_at = store_pos + 1
    else:
        insert_at = next(
            (pos + 1 for pos, item in enumerate(bc)
             if isinstance(item, Instr) and item.name == 'RESUME'),
            0,
        )
    for offset, item in enumerate(guard):
        bc.insert(insert_at + offset, item)
    return bc
186
+
187
def _build_attr_null_guarded_code(original_code, local_var, attr_chain, default):
    """Return bytecode for *original_code* that returns *default* instead of
    dereferencing ``None`` via *local_var* or its attribute chain.

    * If the code assigns *local_var* (has a ``STORE_FAST``), the guard
      ``if local_var is None: return default`` is placed after the last store
      that precedes the first ``LOAD_FAST local_var`` + ``LOAD_ATTR``/
      ``LOAD_METHOD`` pair. When no such store exists, it goes right after
      ``RESUME``. Only *local_var* itself is checked on this path.
    * Otherwise (e.g. *local_var* is a parameter), a guard right after
      ``RESUME`` walks ``local_var.<attr_chain...>``. It returns *default* as
      soon as any link, including the final attribute, is ``None``.

    In both cases, the insertion index is 0 when the code has no ``RESUME``.
    """
    bc = Bytecode.from_code(original_code)
    instrs = list(bc)

    # Index just past the first RESUME (0 if there is none). Both paths use it
    # so a guard is never placed ahead of RESUME.
    after_resume = 0
    for i, instr in enumerate(instrs):
        if isinstance(instr, Instr) and instr.name == 'RESUME':
            after_resume = i + 1
            break

    has_store = any(
        isinstance(instr, Instr) and instr.name == 'STORE_FAST' and instr.arg == local_var
        for instr in instrs
    )

    if has_store:
        # First dereference of local_var (LOAD_FAST followed by an attribute load).
        crash_idx = None
        for idx in range(len(instrs) - 1):
            instr = instrs[idx]
            next_instr = instrs[idx + 1]
            if (isinstance(instr, Instr) and instr.name == 'LOAD_FAST' and instr.arg == local_var
                    and isinstance(next_instr, Instr) and next_instr.name in ('LOAD_ATTR', 'LOAD_METHOD')):
                crash_idx = idx
                break

        # Last assignment to local_var before that dereference.
        insert_after_idx = None
        search_end = crash_idx if crash_idx is not None else len(instrs)
        for idx in range(search_end - 1, -1, -1):
            instr = instrs[idx]
            if isinstance(instr, Instr) and instr.name == 'STORE_FAST' and instr.arg == local_var:
                insert_after_idx = idx
                break

        # NOTE(review): this path ignores attr_chain, so a None deeper in the
        # chain is not intercepted; confirm that is intended.
        skip = Label()
        patch = [
            Instr('LOAD_FAST', local_var),
            Instr('LOAD_CONST', None),
            Instr('IS_OP', 0),
            Instr('POP_JUMP_FORWARD_IF_FALSE', skip),
            Instr('LOAD_CONST', default),
            Instr('RETURN_VALUE'),
            skip
        ]

        # NOTE(review): if no store precedes the dereference and local_var is
        # not a parameter, the LOAD_FAST above reads an unbound local.
        pos = (insert_after_idx + 1) if insert_after_idx is not None else after_resume
        for instr in reversed(patch):
            bc.insert(pos, instr)
        return bc

    return_default = Label()
    end_guard = Label()

    # Walk the chain, keeping the current link on the stack; any None link
    # jumps to return_default with exactly that one value on the stack.
    patch = [Instr('LOAD_FAST', local_var)]
    for attr in attr_chain:
        patch.extend([
            Instr('COPY', 1),
            Instr('LOAD_CONST', None),
            Instr('IS_OP', 0),
            Instr('POP_JUMP_FORWARD_IF_TRUE', return_default),
            Instr('LOAD_ATTR', attr)
        ])

    patch.extend([
        Instr('COPY', 1),
        Instr('LOAD_CONST', None),
        Instr('IS_OP', 0),
        Instr('POP_JUMP_FORWARD_IF_TRUE', return_default),

        Instr('POP_TOP'),
        Instr('JUMP_FORWARD', end_guard),

        return_default,

        Instr('POP_TOP'),
        Instr('LOAD_CONST', default),
        Instr('RETURN_VALUE'),

        end_guard
    ])

    for instr in reversed(patch):
        bc.insert(after_resume, instr)
    return bc
271
+
272
def _build_division_guarded_code(original_code, var_name, default):
    """Return bytecode for *original_code* in which a zero divisor is replaced by *default*.

    Before every true-division instruction a check is inserted. The covered
    instructions are ``BINARY_TRUE_DIVIDE`` and ``BINARY_OP`` with argument
    11 (``/``) or 24 (``/=``). The check compares the divisor on top of the
    stack with ``0`` and swaps it for *default* only when it equals zero.
    Negative and positive divisors are left alone.

    *var_name* is accepted for interface compatibility but not consulted:
    every true division in the function is guarded.
    """
    bc = Bytecode.from_code(original_code)
    # BINARY_OP argument values for NB_TRUE_DIVIDE and NB_INPLACE_TRUE_DIVIDE.
    true_divide_args = (11, 24)
    new_instrs = []
    replaced_count = 0
    for instr in bc:
        if isinstance(instr, Instr) and (
            instr.name == 'BINARY_TRUE_DIVIDE'
            or (instr.name == 'BINARY_OP' and instr.arg in true_divide_args)
        ):
            keep_divisor = Label()
            new_instrs.extend([
                # Stack: [..., dividend, divisor]; test a copy of the divisor.
                Instr('COPY', 1),
                Instr('LOAD_CONST', 0),
                # Replace only a divisor equal to zero; a "> 0" test would also
                # throw away valid negative divisors.
                Instr('COMPARE_OP', Compare.NE),
                Instr('POP_JUMP_FORWARD_IF_TRUE', keep_divisor),
                Instr('POP_TOP'),
                Instr('LOAD_CONST', default),
                keep_divisor,
            ])
            replaced_count += 1
        new_instrs.append(instr)
    if replaced_count > 0:
        print(f"[CodeSuture] Patched {replaced_count} occurrences of the failing expression pattern in {original_code.co_name}.")
    bc.clear()
    bc.extend(new_instrs)
    return bc
295
+
296
def _build_subscript_guarded_code(original_code, container_var, key_name_or_var, default):
    """Return bytecode for *original_code* whose first subscript read cannot fail.

    The first ``BINARY_SUBSCR`` (``container[key]``) is replaced by, in effect::

        default if container is None else container.get(key, default)

    The two stack operands are parked in the scratch locals
    ``_codesuture_cont`` and ``_codesuture_key``.
    """
    # NOTE(review): container_var and key_name_or_var are not consulted. The
    # first BINARY_SUBSCR in the function is rewritten whatever it reads, and
    # that container must provide .get(); confirm the pattern matcher only
    # picks this strategy when that holds.
    bc = Bytecode.from_code(original_code)
    new_instrs = []
    replaced_count = 0
    for instr in bc:
        if replaced_count == 0 and isinstance(instr, Instr) and instr.name == 'BINARY_SUBSCR':
            use_get = Label()
            end = Label()
            new_instrs.extend([
                # Stack: [..., container, key] -> move both into scratch locals.
                Instr('STORE_FAST', '_codesuture_key'),
                Instr('STORE_FAST', '_codesuture_cont'),
                Instr('LOAD_FAST', '_codesuture_cont'),
                Instr('LOAD_CONST', None),
                # Identity test ("is None"), as in the null guards, so a
                # container with a custom __eq__ cannot make the check ambiguous.
                Instr('IS_OP', 0),
                Instr('POP_JUMP_FORWARD_IF_FALSE', use_get),
                Instr('LOAD_CONST', default),
                Instr('JUMP_FORWARD', end),
                use_get,
                Instr('LOAD_FAST', '_codesuture_cont'),
                Instr('LOAD_METHOD', 'get'),
                Instr('LOAD_FAST', '_codesuture_key'),
                Instr('LOAD_CONST', default),
                Instr('PRECALL', 2),
                Instr('CALL', 2),
                end,
            ])
            replaced_count += 1
        else:
            new_instrs.append(instr)
    if replaced_count > 0:
        print(f"[CodeSuture] Patched {replaced_count} occurrences of the failing expression pattern in {original_code.co_name}.")
    bc.clear()
    bc.extend(new_instrs)
    return bc
328
+
329
def _build_chain_subscript_guarded_code(original_code, root_var, keys, default):
    """Replace every ``root_var[k1]...[kn]`` chain matching *keys* with a
    None-safe ``.get`` chain that yields *default* when a link is missing.

    Matching is delegated to ``_match_chain`` and code generation to
    ``_gen_chain_get``. Non-matching instructions are copied through
    unchanged.
    """
    bc = Bytecode.from_code(original_code)
    source = list(bc)
    # LOAD_FAST root, then one (key load, BINARY_SUBSCR) pair per key.
    span = 2 * len(keys) + 1

    rewritten = []
    hits = 0
    cursor = 0
    while cursor < len(source):
        if not _match_chain(source, cursor, root_var, keys):
            rewritten.append(source[cursor])
            cursor += 1
            continue
        rewritten.extend(_gen_chain_get(root_var, keys, default))
        cursor += span
        hits += 1

    if hits:
        print(f"[CodeSuture] Patched {hits} occurrences of the failing expression pattern in {original_code.co_name}.")
    bc.clear()
    bc.extend(rewritten)
    return bc
353
+
354
def _match_chain(instrs, start, root_var, keys):
    """Return True when ``instrs`` at index *start* begins ``root_var[k1]...[kn]``.

    The expected shape is ``LOAD_FAST root_var``, followed for every key in
    *keys* by a load of that key and a ``BINARY_SUBSCR``. The key load may be
    ``LOAD_CONST key`` or ``LOAD_FAST key``.
    """
    # NOTE(review): a key is accepted both as a constant and as a local-variable
    # name, while _gen_chain_get re-emits keys with LOAD_CONST; confirm the
    # pattern matcher's keys make this unambiguous.
    if start >= len(instrs):
        return False
    head = instrs[start]
    if not (isinstance(head, Instr) and head.name == 'LOAD_FAST' and head.arg == root_var):
        return False

    cursor = start + 1
    for key in keys:
        if cursor + 1 >= len(instrs):
            return False
        key_load, subscr = instrs[cursor], instrs[cursor + 1]
        if not isinstance(key_load, Instr):
            return False
        if key_load.name not in ('LOAD_CONST', 'LOAD_FAST') or not (key_load.arg == key):
            return False
        if not (isinstance(subscr, Instr) and subscr.name == 'BINARY_SUBSCR'):
            return False
        cursor += 2
    return True
378
+
379
def _gen_chain_get(root_var, keys, default):
    """Emit instructions computing a None-safe lookup of ``root_var[k1]...[kn]``.

    The running value lives in the scratch local ``_lp_chain``:

    * every key except the last is resolved with ``.get(key, None)``;
    * the last key is resolved with ``.get(last, default)``;
    * once the running value is ``None``, the remaining lookups are skipped
      and *default* is produced.

    Keys are emitted as constants. The sequence leaves exactly one value on
    the stack. *keys* must be non-empty.
    """
    out = [
        Instr('LOAD_FAST', root_var),
        Instr('STORE_FAST', '_lp_chain'),
    ]

    for key in keys[:-1]:
        skip = Label()
        out.extend([
            Instr('LOAD_FAST', '_lp_chain'),
            Instr('LOAD_CONST', None),
            # Identity test ("is None"), consistent with the null guards.
            Instr('IS_OP', 0),
            Instr('POP_JUMP_FORWARD_IF_TRUE', skip),
            Instr('LOAD_FAST', '_lp_chain'),
            Instr('LOAD_METHOD', 'get'),
            Instr('LOAD_CONST', key),
            Instr('LOAD_CONST', None),
            Instr('PRECALL', 2),
            Instr('CALL', 2),
            Instr('STORE_FAST', '_lp_chain'),
            skip,
        ])

    last = keys[-1]
    skip_last = Label()
    end = Label()
    out.extend([
        Instr('LOAD_FAST', '_lp_chain'),
        Instr('LOAD_CONST', None),
        Instr('IS_OP', 0),
        Instr('POP_JUMP_FORWARD_IF_TRUE', skip_last),
        Instr('LOAD_FAST', '_lp_chain'),
        Instr('LOAD_METHOD', 'get'),
        Instr('LOAD_CONST', last),
        Instr('LOAD_CONST', default),
        Instr('PRECALL', 2),
        Instr('CALL', 2),
        Instr('JUMP_FORWARD', end),
        skip_last,
        Instr('LOAD_CONST', default),
        end,
    ])
    return out
418
+
419
def _build_index_guarded_code(original_code, idx_var, list_var, default):
    """Return bytecode for *original_code* that resets an out-of-range index.

    A prologue inserted right after ``RESUME`` (index 0 if absent) does, in
    effect::

        if not (-len(list_var) <= idx_var < len(list_var)):
            idx_var = 0

    Negative indices are checked as well, since ``seq[-n-1]`` raises
    IndexError just like ``seq[n]``. *default* is accepted for interface
    compatibility but not used.
    """
    # NOTE(review): for an empty list_var the reset index 0 is still out of range.
    bc = Bytecode.from_code(original_code)
    reset = Label()
    skip = Label()
    patch = [
        # idx_var >= len(list_var): too large, reset.
        Instr('LOAD_FAST', idx_var),
        Instr('LOAD_GLOBAL', (True, 'len')),
        Instr('LOAD_FAST', list_var),
        Instr('PRECALL', 1),
        Instr('CALL', 1),
        Instr('COMPARE_OP', Compare.GE),
        Instr('POP_JUMP_FORWARD_IF_TRUE', reset),
        # idx_var >= -len(list_var): in range, keep.
        Instr('LOAD_FAST', idx_var),
        Instr('LOAD_GLOBAL', (True, 'len')),
        Instr('LOAD_FAST', list_var),
        Instr('PRECALL', 1),
        Instr('CALL', 1),
        Instr('UNARY_NEGATIVE'),
        Instr('COMPARE_OP', Compare.GE),
        Instr('POP_JUMP_FORWARD_IF_TRUE', skip),
        reset,
        Instr('LOAD_CONST', 0),
        Instr('STORE_FAST', idx_var),
        skip,
    ]
    idx = 0
    for i, instr in enumerate(bc):
        if isinstance(instr, Instr) and instr.name == 'RESUME':
            idx = i + 1
            break
    for instr in reversed(patch):
        bc.insert(idx, instr)
    return bc
442
+
443
def _build_file_guarded_code(original_code, path_var, default):
    """Return bytecode for *original_code* that returns *default* when the path
    held in local *path_var* does not exist.

    A prologue inserted right after ``RESUME`` (index 0 if absent) does, in
    effect::

        if not __import__('os').path.exists(path_var):
            return default
    """
    bc = Bytecode.from_code(original_code)
    skip = Label()
    patch = [
        # Resolve ``os`` through the builtin __import__ (as the callable guard
        # does for ``sys``) so the guard works even when the patched function's
        # module never imported ``os``.
        Instr('LOAD_GLOBAL', (True, '__import__')),
        Instr('LOAD_CONST', 'os'),
        Instr('PRECALL', 1),
        Instr('CALL', 1),
        Instr('LOAD_ATTR', 'path'),
        Instr('LOAD_METHOD', 'exists'),
        Instr('LOAD_FAST', path_var),
        Instr('PRECALL', 1),
        Instr('CALL', 1),
        Instr('POP_JUMP_FORWARD_IF_TRUE', skip),

        Instr('LOAD_CONST', default),
        Instr('RETURN_VALUE'),
        skip
    ]
    idx = 0
    for i, instr in enumerate(bc):
        if isinstance(instr, Instr) and instr.name == 'RESUME':
            idx = i + 1
            break
    for instr in reversed(patch):
        bc.insert(idx, instr)
    return bc
467
+
468
def _build_str_coerce_guarded_code(original_code, var_name):
    """Return bytecode for *original_code* with a string-coercion prologue.

    Right after ``RESUME`` (index 0 if absent) the function gains, in effect::

        if not isinstance(var_name, str):
            var_name = str(var_name)
    """
    bc = Bytecode.from_code(original_code)
    already_str = Label()
    prologue = [
        Instr('LOAD_GLOBAL', (True, 'isinstance')),
        Instr('LOAD_FAST', var_name),
        Instr('LOAD_GLOBAL', (False, 'str')),
        Instr('PRECALL', 2),
        Instr('CALL', 2),
        Instr('POP_JUMP_FORWARD_IF_TRUE', already_str),
        Instr('LOAD_GLOBAL', (True, 'str')),
        Instr('LOAD_FAST', var_name),
        Instr('PRECALL', 1),
        Instr('CALL', 1),
        Instr('STORE_FAST', var_name),
        already_str,
    ]
    anchor = next(
        (pos + 1 for pos, item in enumerate(bc)
         if isinstance(item, Instr) and item.name == 'RESUME'),
        0,
    )
    for offset, item in enumerate(prologue):
        bc.insert(anchor + offset, item)
    return bc
493
+
494
def _build_callable_guarded_code(original_code, var_name, replacement_func):
    """Return bytecode for *original_code* that repairs a ``None`` global callable.

    Right after ``RESUME`` (index 0 if absent) the function gains, in effect::

        if <global var_name> is None:
            <global var_name> = sys.modules['codesuture.pattern_matcher']._ORIGINAL_INFER_DEFAULT

    ``sys`` is obtained through the builtin ``__import__`` so the patched
    module does not need to import it.
    """
    # NOTE(review): replacement_func is never used; the replacement is fixed to
    # codesuture.pattern_matcher._ORIGINAL_INFER_DEFAULT. Confirm intent.
    bc = Bytecode.from_code(original_code)
    skip = Label()
    patch = [
        Instr('LOAD_GLOBAL', (False, var_name)),
        Instr('LOAD_CONST', None),
        # Identity test ("is None"), consistent with the null guards.
        Instr('IS_OP', 0),
        Instr('POP_JUMP_FORWARD_IF_FALSE', skip),

        Instr('LOAD_GLOBAL', (True, '__import__')),
        Instr('LOAD_CONST', 'sys'),
        Instr('PRECALL', 1),
        Instr('CALL', 1),
        Instr('LOAD_ATTR', 'modules'),
        Instr('LOAD_CONST', 'codesuture.pattern_matcher'),
        Instr('BINARY_SUBSCR'),
        Instr('LOAD_ATTR', '_ORIGINAL_INFER_DEFAULT'),
        Instr('STORE_GLOBAL', var_name),
        skip
    ]
    idx = 0
    for i, instr in enumerate(bc):
        if isinstance(instr, Instr) and instr.name == 'RESUME':
            idx = i + 1
            break
    for instr in reversed(patch):
        bc.insert(idx, instr)
    return bc
523
+
524
def _build_type_coercion_guarded_code(original_code, var_name, default):
    """Return bytecode for *original_code* with a coercion-safety prologue on *var_name*.

    The prologue is inserted right after ``RESUME`` (index 0 if absent). Its
    form depends on the type of *default*:

    * ``int`` (but not ``bool``): a ``str`` value that is not an optional
      single ``-`` followed by decimal digits is replaced by *default*, in
      effect::

          if isinstance(var_name, str) and not var_name.removeprefix('-').isdecimal():
              var_name = default

    * anything else (including ``float``)::

          if var_name is None:
              var_name = default
    """
    bc = Bytecode.from_code(original_code)
    skip = Label()

    if isinstance(default, int) and not isinstance(default, bool):
        patch = [
            Instr('LOAD_GLOBAL', (True, 'isinstance')),
            Instr('LOAD_FAST', var_name),
            Instr('LOAD_GLOBAL', (False, 'str')),
            Instr('PRECALL', 2),
            Instr('CALL', 2),
            Instr('POP_JUMP_FORWARD_IF_FALSE', skip),
            # removeprefix (not lstrip) so '--5' is rejected, and isdecimal
            # (not isdigit) so characters like '²' that int() refuses are
            # rejected as well.
            Instr('LOAD_FAST', var_name),
            Instr('LOAD_METHOD', 'removeprefix'),
            Instr('LOAD_CONST', '-'),
            Instr('PRECALL', 1),
            Instr('CALL', 1),
            Instr('LOAD_METHOD', 'isdecimal'),
            Instr('PRECALL', 0),
            Instr('CALL', 0),
            Instr('POP_JUMP_FORWARD_IF_TRUE', skip),

            Instr('LOAD_CONST', default),
            Instr('STORE_FAST', var_name),
            skip,
        ]
    else:
        patch = [
            Instr('LOAD_FAST', var_name),
            Instr('LOAD_CONST', None),
            Instr('IS_OP', 0),
            Instr('POP_JUMP_FORWARD_IF_FALSE', skip),
            Instr('LOAD_CONST', default),
            Instr('STORE_FAST', var_name),
            skip,
        ]

    idx = 0
    for i, instr in enumerate(bc):
        if isinstance(instr, Instr) and instr.name == 'RESUME':
            idx = i + 1
            break
    for instr in reversed(patch):
        bc.insert(idx, instr)
    return bc
587
+
588
def _build_return_guarded_code(original_code, default):
    """Return bytecode for *original_code* in which a ``None`` return value becomes *default*.

    Every ``RETURN_VALUE`` is preceded by a check of the value on top of the
    stack. If that value ``is None``, it is popped and *default* is pushed in
    its place.
    """
    bc = Bytecode.from_code(original_code)
    rewritten = []
    for item in bc:
        if isinstance(item, Instr) and item.name == 'RETURN_VALUE':
            keep_value = Label()
            rewritten += [
                Instr('COPY', 1),
                Instr('LOAD_CONST', None),
                Instr('IS_OP', 0),
                Instr('POP_JUMP_FORWARD_IF_FALSE', keep_value),
                Instr('POP_TOP'),
                Instr('LOAD_CONST', default),
                keep_value,
            ]
        rewritten.append(item)
    bc.clear()
    bc.extend(rewritten)
    return bc
@@ -0,0 +1,35 @@
1
+ import os
2
+ import json
3
+
4
# Directory that stores learned rules. The path is relative, so it resolves
# against the process's current working directory.
KNOWLEDGE_DIR = ".codesuture_knowledge"
# JSON file inside KNOWLEDGE_DIR holding the list of learned-rule dicts.
KNOWLEDGE_FILE = os.path.join(KNOWLEDGE_DIR, "learned_rules.json")
6
+
7
def load_learned_rules(path=None):
    """Return the list of learned rules stored on disk.

    Args:
        path: JSON file to read; defaults to ``KNOWLEDGE_FILE``.

    Returns:
        The decoded rule list. ``[]`` is returned instead when the file is
        missing, cannot be read or decoded, or does not hold a JSON array.
    """
    if path is None:
        path = KNOWLEDGE_FILE
    try:
        with open(path, "r", encoding="utf-8") as f:
            rules = json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Report rather than swallow silently: a later save would replace the
        # unreadable file with only the newly saved rule.
        print(f"[CodeSuture] Ignoring unreadable knowledge file {path!r}: {exc}")
        return []
    return rules if isinstance(rules, list) else []
15
+
16
def save_learned_rule(exc_type_name, exc_message, func_name, new_source):
    """Insert or update the learned rule for ``(func_name, exc_type_name)``.

    If a stored rule has the same function name and exception type, its
    ``new_source`` and ``exc_message`` are overwritten. Otherwise a new rule
    is appended. The whole list is written back to ``KNOWLEDGE_FILE`` as
    indented JSON, and ``KNOWLEDGE_DIR`` is created if needed.

    The JSON is written to a temporary sibling file first and then moved into
    place with ``os.replace``. An interrupted write therefore cannot leave a
    truncated file behind. ``load_learned_rules`` would read such a file as an
    empty list, and the next save would then drop every stored rule.
    """
    os.makedirs(KNOWLEDGE_DIR, exist_ok=True)
    rules = load_learned_rules()

    for rule in rules:
        # .get tolerates hand-edited entries that lack one of the keys.
        if (isinstance(rule, dict)
                and rule.get("func_name") == func_name
                and rule.get("exc_type_name") == exc_type_name):
            rule["new_source"] = new_source
            rule["exc_message"] = exc_message
            break
    else:
        rules.append({
            "func_name": func_name,
            "exc_type_name": exc_type_name,
            "exc_message": exc_message,
            "new_source": new_source
        })

    tmp_path = KNOWLEDGE_FILE + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(rules, f, indent=2)
        os.replace(tmp_path, KNOWLEDGE_FILE)
    except BaseException:
        # Do not leave a half-written temp file around on failure.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise