phantomrt 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas/__init__.py +3 -0
- atlas/agents/__init__.py +8 -0
- atlas/agents/command_space.py +227 -0
- atlas/analysis/__init__.py +3 -0
- atlas/analysis/binary_agent.py +488 -0
- atlas/analysis/binary_fuzz.py +389 -0
- atlas/analysis/frida_live.py +261 -0
- atlas/analysis/graph_annotator.py +147 -0
- atlas/analysis/spectrida_bridge.py +84 -0
- atlas/analysis/unicorn_harness.py +337 -0
- atlas/core/__init__.py +14 -0
- atlas/core/decoder.py +65 -0
- atlas/core/dynamics.py +217 -0
- atlas/core/encoder.py +120 -0
- atlas/core/surprise.py +145 -0
- atlas/core/world_model.py +334 -0
- atlas/environments/__init__.py +5 -0
- atlas/environments/base.py +51 -0
- atlas/environments/grid_world.py +219 -0
- atlas/environments/physics_2d.py +283 -0
- atlas/environments/vm_world.py +168 -0
- atlas/knowledge/__init__.py +3 -0
- atlas/knowledge/instruction_vocab.py +534 -0
- atlas/monitor/__init__.py +5 -0
- atlas/monitor/execution_monitor.py +518 -0
- atlas/optimization/__init__.py +6 -0
- atlas/optimization/speed.py +457 -0
- atlas/planning/__init__.py +4 -0
- atlas/planning/goal.py +100 -0
- atlas/planning/mcts.py +228 -0
- atlas/training/__init__.py +4 -0
- atlas/training/continual.py +392 -0
- atlas/training/growth.py +213 -0
- atlas/training/loop.py +306 -0
- atlas/training/losses.py +101 -0
- atlas/training/self_train.py +307 -0
- atlas/utils/__init__.py +4 -0
- atlas/utils/logging.py +33 -0
- atlas/utils/math_helpers.py +30 -0
- atlas/utils/viz.py +136 -0
- atlas/vm/__init__.py +4 -0
- atlas/vm/wsl_vm.py +249 -0
- phantomrt-0.1.0.dist-info/METADATA +75 -0
- phantomrt-0.1.0.dist-info/RECORD +48 -0
- phantomrt-0.1.0.dist-info/WHEEL +5 -0
- phantomrt-0.1.0.dist-info/entry_points.txt +3 -0
- phantomrt-0.1.0.dist-info/licenses/LICENSE +21 -0
- phantomrt-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,534 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Instruction Vocabulary — The Model's Dictionary
|
|
3
|
+
|
|
4
|
+
Gives the world model a BASE understanding of:
|
|
5
|
+
- What each CPU instruction does
|
|
6
|
+
- What registers are for
|
|
7
|
+
- What memory regions mean
|
|
8
|
+
- Common code patterns
|
|
9
|
+
- Vulnerability signatures
|
|
10
|
+
|
|
11
|
+
This is the "learning to read" foundation.
|
|
12
|
+
Without this, the model stares at raw bytes.
|
|
13
|
+
With this, it understands what it's seeing.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
import torch.nn as nn
|
|
18
|
+
import numpy as np
|
|
19
|
+
from typing import Optional
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# ═══════════════════════════════════════════════════════════════
|
|
24
|
+
# INSTRUCTION SEMANTICS — What each instruction DOES
|
|
25
|
+
# ═══════════════════════════════════════════════════════════════
|
|
26
|
+
|
|
27
|
+
# Per-mnemonic knowledge table.
#
# Maps an x86 mnemonic to {"category", "risk", "description"}.  The table is
# written as (category, rows) groups, each row being
# (mnemonic, risk, description); the comprehension expands it into one dict
# entry per mnemonic.  Group order and row order determine the key order of
# the resulting dict, so append new entries rather than reordering.
INSTRUCTION_SEMANTICS = {
    mnemonic: {"category": category, "risk": risk, "description": description}
    for category, rows in (
        # Data Movement
        ("data_move", (
            ("mov", 0.1, "Copy data from source to destination"),
            ("push", 0.1, "Push value onto stack"),
            ("pop", 0.1, "Pop value from stack"),
            ("lea", 0.05, "Load effective address (compute, don't access)"),
            ("xchg", 0.05, "Swap two values"),
            ("cmov", 0.05, "Conditional move"),
        )),
        # Arithmetic
        ("arithmetic", (
            ("add", 0.15, "Add source to destination"),
            ("sub", 0.15, "Subtract source from destination"),
            ("inc", 0.05, "Increment by 1"),
            ("dec", 0.05, "Decrement by 1"),
            ("imul", 0.25, "Signed multiply (can overflow)"),
            ("mul", 0.25, "Unsigned multiply (can overflow)"),
            ("idiv", 0.3, "Signed divide (can crash on zero)"),
            ("div", 0.3, "Unsigned divide (can crash on zero)"),
            ("neg", 0.1, "Negate value (two's complement)"),
        )),
        # Logic
        ("logic", (
            ("and", 0.05, "Bitwise AND"),
            ("or", 0.05, "Bitwise OR"),
            ("xor", 0.05, "Bitwise XOR"),
            ("not", 0.05, "Bitwise NOT"),
            ("shl", 0.1, "Shift left (multiply by 2^n)"),
            ("shr", 0.1, "Shift right (divide by 2^n)"),
            ("sar", 0.1, "Arithmetic shift right (preserves sign)"),
            ("rol", 0.05, "Rotate left"),
            ("ror", 0.05, "Rotate right"),
        )),
        # Comparison
        ("compare", (
            ("cmp", 0.0, "Compare two values (sets flags)"),
            ("test", 0.0, "Test two values (AND, sets flags)"),
        )),
        # Control Flow
        ("control_flow", (
            ("jmp", 0.05, "Unconditional jump"),
            ("je", 0.05, "Jump if equal"),
            ("jne", 0.05, "Jump if not equal"),
            ("jg", 0.05, "Jump if greater"),
            ("jge", 0.05, "Jump if greater or equal"),
            ("jl", 0.05, "Jump if less"),
            ("jle", 0.05, "Jump if less or equal"),
            ("ja", 0.05, "Jump if above (unsigned)"),
            ("jb", 0.05, "Jump if below (unsigned)"),
            ("call", 0.1, "Call function (pushes return address)"),
            ("ret", 0.15, "Return from function (pops return address)"),
            ("loop", 0.05, "Loop (dec ECX, jump if not zero)"),
        )),
        # String Operations (HIGH RISK for buffer overflows)
        ("string", (
            ("rep", 0.2, "Repeat next instruction ECX times"),
            ("movsb", 0.2, "Move byte string (DS:RSI -> ES:RDI)"),
            ("movsw", 0.2, "Move word string"),
            ("movsd", 0.2, "Move dword string"),
            ("movsq", 0.2, "Move qword string"),
            ("stosb", 0.15, "Store byte to string"),
            ("stosw", 0.15, "Store word to string"),
            ("stosd", 0.15, "Store dword to string"),
            ("stosq", 0.15, "Store qword to string"),
            ("cmpsb", 0.05, "Compare byte strings"),
            ("scasb", 0.05, "Scan byte string"),
        )),
        # Stack Operations
        ("stack", (
            ("enter", 0.05, "Create stack frame"),
            ("leave", 0.05, "Destroy stack frame"),
        )),
        # System
        ("system", (
            ("syscall", 0.3, "System call (kernel transition)"),
            ("int", 0.3, "Software interrupt"),
            ("sysenter", 0.3, "Fast system call"),
            ("hlt", 0.0, "Halt processor"),
        )),
        # SIMD (can be used for overflows)
        ("simd", (
            ("movdqu", 0.2, "Move unaligned 128-bit data"),
            ("movdqa", 0.15, "Move aligned 128-bit data"),
            ("paddd", 0.1, "Add packed 32-bit integers"),
        )),
    )
    for mnemonic, risk, description in rows
}
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# ═══════════════════════════════════════════════════════════════
|
|
107
|
+
# REGISTER SEMANTICS — What each register is USED FOR
|
|
108
|
+
# ═══════════════════════════════════════════════════════════════
|
|
109
|
+
|
|
110
|
+
# Per-register knowledge table.
#
# Maps a register name to {"role", "usage", "volatility"}.  Each row below is
# (name, role, usage, volatility); row order determines the key order of the
# resulting dict.
REGISTER_SEMANTICS = {
    name: {"role": role, "usage": usage, "volatility": volatility}
    for name, role, usage, volatility in (
        # General Purpose
        ("rax", "accumulator", "Return value, arithmetic", "caller_saved"),
        ("rbx", "base", "General purpose, often preserved", "callee_saved"),
        ("rcx", "counter", "Loop counter, 4th argument", "caller_saved"),
        ("rdx", "data", "I/O operations, 3rd argument", "caller_saved"),
        ("rsi", "source_index", "Source for string ops, 2nd argument", "caller_saved"),
        ("rdi", "dest_index", "Destination for string ops, 1st argument", "caller_saved"),
        ("rbp", "base_pointer", "Stack frame base", "callee_saved"),
        ("rsp", "stack_pointer", "Top of stack", "callee_saved"),
        ("r8", "general", "5th argument", "caller_saved"),
        ("r9", "general", "6th argument", "caller_saved"),
        ("r10", "general", "Temporary", "caller_saved"),
        ("r11", "general", "Temporary, trashed by syscall", "caller_saved"),
        ("r12", "general", "General purpose", "callee_saved"),
        ("r13", "general", "General purpose", "callee_saved"),
        ("r14", "general", "General purpose", "callee_saved"),
        ("r15", "general", "General purpose", "callee_saved"),
        # Special-purpose registers
        ("rip", "instruction_pointer", "Next instruction to execute", "special"),
        ("rflags", "flags", "Condition codes, control flags", "special"),
    )
}
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# ═══════════════════════════════════════════════════════════════
|
|
134
|
+
# MEMORY REGIONS — What different addresses MEAN
|
|
135
|
+
# ═══════════════════════════════════════════════════════════════
|
|
136
|
+
|
|
137
|
+
# Coarse memory map used to attach meaning to an address.
#
# Each entry holds:
#   "address_range": (start, end) bounds of the region
#   "properties":    descriptive tags for the region
#   "risk":          comma-separated names of attack classes tied to the region
#
# The ranges are kept pairwise disjoint so an address belongs to at most one
# region: "code" stops one byte short of "data", and "mmap" stops one byte
# short of "stack" (it previously contained the whole stack range).
# NOTE(review): the concrete addresses look like illustrative values rather
# than a real process layout - confirm against whatever produces the traces.
MEMORY_REGIONS = {
    "stack": {
        "address_range": (0x7fff0000, 0x7fffffff),
        "properties": ["grows_down", "local_variables", "function_returns"],
        "risk": "buffer_overflow, stack_smash, return_oriented_programming",
    },
    "heap": {
        "address_range": (0x60000000, 0x6fffffff),
        "properties": ["grows_up", "dynamic_allocation", "free_list"],
        "risk": "heap_overflow, use_after_free, double_free, heap_spray",
    },
    "code": {
        "address_range": (0x00400000, 0x005fffff),
        "properties": ["read_only", "executable", "instructions"],
        "risk": "code_injection, shellcode",
    },
    "data": {
        "address_range": (0x00600000, 0x00700000),
        "properties": ["read_write", "global_variables", "constants"],
        "risk": "data_corruption",
    },
    "mmap": {
        "address_range": (0x7f000000, 0x7ffeffff),
        "properties": ["dynamic", "libraries", "shared_memory"],
        "risk": "mmap_exploitation",
    },
}
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# ═══════════════════════════════════════════════════════════════
|
|
167
|
+
# VULNERABILITY PATTERNS — Known dangerous patterns
|
|
168
|
+
# ═══════════════════════════════════════════════════════════════
|
|
169
|
+
|
|
170
|
+
VULNERABILITY_PATTERNS = {
|
|
171
|
+
"stack_buffer_overflow": {
|
|
172
|
+
"signatures": [
|
|
173
|
+
"rep movsb with large ECX and stack destination",
|
|
174
|
+
"mov with stack write past frame size",
|
|
175
|
+
"gets() call (no bounds checking)",
|
|
176
|
+
"strcpy() call (no bounds checking)",
|
|
177
|
+
"sprintf() call (no bounds checking)",
|
|
178
|
+
],
|
|
179
|
+
"indicators": [
|
|
180
|
+
"excessive stack growth",
|
|
181
|
+
"return address modification",
|
|
182
|
+
"saved RBP modification",
|
|
183
|
+
],
|
|
184
|
+
"severity": "critical",
|
|
185
|
+
},
|
|
186
|
+
"heap_buffer_overflow": {
|
|
187
|
+
"signatures": [
|
|
188
|
+
"write past allocated heap chunk",
|
|
189
|
+
"heap metadata corruption",
|
|
190
|
+
],
|
|
191
|
+
"indicators": [
|
|
192
|
+
"heap chunk header modification",
|
|
193
|
+
"adjacent chunk corruption",
|
|
194
|
+
],
|
|
195
|
+
"severity": "critical",
|
|
196
|
+
},
|
|
197
|
+
"use_after_free": {
|
|
198
|
+
"signatures": [
|
|
199
|
+
"accessing freed pointer",
|
|
200
|
+
"use after free() call",
|
|
201
|
+
],
|
|
202
|
+
"indicators": [
|
|
203
|
+
"pointer used after deallocation",
|
|
204
|
+
"double free detection",
|
|
205
|
+
],
|
|
206
|
+
"severity": "high",
|
|
207
|
+
},
|
|
208
|
+
"format_string": {
|
|
209
|
+
"signatures": [
|
|
210
|
+
"printf with user-controlled format",
|
|
211
|
+
"sprintf with user-controlled format",
|
|
212
|
+
"fprintf with user-controlled format",
|
|
213
|
+
],
|
|
214
|
+
"indicators": [
|
|
215
|
+
"format string without format specifier",
|
|
216
|
+
"user input directly in format position",
|
|
217
|
+
],
|
|
218
|
+
"severity": "high",
|
|
219
|
+
},
|
|
220
|
+
"integer_overflow": {
|
|
221
|
+
"signatures": [
|
|
222
|
+
"multiply without overflow check",
|
|
223
|
+
"add without bounds check before allocation",
|
|
224
|
+
"signed/unsigned confusion",
|
|
225
|
+
],
|
|
226
|
+
"indicators": [
|
|
227
|
+
"arithmetic before memory allocation",
|
|
228
|
+
"size calculation overflow",
|
|
229
|
+
],
|
|
230
|
+
"severity": "medium",
|
|
231
|
+
},
|
|
232
|
+
"race_condition": {
|
|
233
|
+
"signatures": [
|
|
234
|
+
"TOCTOU (time-of-check-to-time-of-use)",
|
|
235
|
+
"unsynchronized shared access",
|
|
236
|
+
],
|
|
237
|
+
"indicators": [
|
|
238
|
+
"check-then-act pattern without lock",
|
|
239
|
+
"shared memory access without synchronization",
|
|
240
|
+
],
|
|
241
|
+
"severity": "medium",
|
|
242
|
+
},
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
# ═══════════════════════════════════════════════════════════════
|
|
247
|
+
# INSTRUCTION ENCODER — Convert instructions to vectors
|
|
248
|
+
# ═══════════════════════════════════════════════════════════════
|
|
249
|
+
|
|
250
|
+
class InstructionEncoder(nn.Module):
    """
    Encodes x86 instructions into dense vectors that the world model can process.

    Converts:
        "mov eax, dword [rbp-0x10]" → 128-dimensional vector

    The vector is a learned projection of:
    - a mnemonic embedding
    - an instruction-category embedding
    - a 16-wide zero block reserved for register embeddings (not wired in yet;
      ``register_embed`` is created but unused)
    - hand-crafted flags (risk level, stack/memory access, call/return/branch,
      syscall, string op, operand width)

    Lookup tables built at construction time:
        mnemonic_to_idx: known mnemonic -> row of ``mnemonic_embed``
        register_to_idx: known register -> row of ``register_embed``
        category_to_idx: known category -> row of ``category_embed``
    The last row of the mnemonic and category tables is reserved for
    "unknown" so unseen instructions do not alias a real entry.
    """

    VOCAB_SIZE = 200    # rows in the mnemonic embedding table
    EMBED_DIM = 32      # dimension per mnemonic
    FEATURE_DIM = 16    # hand-crafted feature slots (slots 0-9 are used)

    def __init__(self, output_dim: int = 128):
        """
        Args:
            output_dim: Size of the vector produced for each instruction.
        """
        super().__init__()
        self.output_dim = output_dim

        # Mnemonic embedding
        self.mnemonic_embed = nn.Embedding(self.VOCAB_SIZE, self.EMBED_DIM)

        # Category embedding (9 known categories + 1 "unknown" row)
        self.category_embed = nn.Embedding(10, 16)

        # Register embeddings (declared for the planned register features)
        self.num_registers = 20
        self.register_embed = nn.Embedding(self.num_registers, 8)

        # Feature projection: mnemonic + category + register placeholder + flags
        self.feature_proj = nn.Sequential(
            nn.Linear(self.EMBED_DIM + 16 + 16 + self.FEATURE_DIM, 256),
            nn.SiLU(),
            nn.Linear(256, output_dim),
            nn.LayerNorm(output_dim),
        )

        # Build vocabulary
        self.mnemonic_to_idx = {}
        self.register_to_idx = {}
        self.category_to_idx = {}
        self._build_vocab()

    def _build_vocab(self):
        """Build the mnemonic, register and category lookup tables."""
        # Reserve the last embedding row for anything outside the vocabulary.
        self.unknown_mnemonic_idx = self.mnemonic_embed.num_embeddings - 1
        self.unknown_category_idx = self.category_embed.num_embeddings - 1

        self.mnemonic_to_idx = {
            mnemonic: i
            for i, mnemonic in enumerate(INSTRUCTION_SEMANTICS)
            if i < self.unknown_mnemonic_idx
        }

        self.register_to_idx = {
            reg: i
            for i, reg in enumerate(REGISTER_SEMANTICS)
            if i < self.num_registers
        }

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # category -> index mapping is identical in every process.  (A set
        # would be reordered by string hash randomization between runs.)
        ordered_categories = dict.fromkeys(
            info["category"] for info in INSTRUCTION_SEMANTICS.values()
        )
        self.category_to_idx = {
            category: i
            for i, category in enumerate(ordered_categories)
            if i < self.unknown_category_idx
        }

    @staticmethod
    def _instruction_fields(inst) -> tuple:
        """Return (mnemonic, operands) for a dict-style or attribute-style instruction."""
        if isinstance(inst, dict):
            return inst.get("mnemonic", ""), inst.get("operands", "")
        return inst.mnemonic, inst.operands

    def encode_instruction(self, mnemonic: str, operands: str) -> torch.Tensor:
        """
        Encode a single instruction.

        Args:
            mnemonic: Instruction mnemonic, e.g. ``"mov"``. Case and
                surrounding whitespace are ignored.
            operands: Operand text, e.g. ``"eax, dword [rbp-0x10]"``.

        Returns:
            A 1-D tensor of length ``output_dim`` on the module's device.
        """
        mnemonic = (mnemonic or "").strip().lower()
        operands = (operands or "").lower()
        # Create every tensor on the module's device so the encoder keeps
        # working after .to(device) / .cuda().
        device = self.mnemonic_embed.weight.device

        # Mnemonic index
        m_idx = self.mnemonic_to_idx.get(mnemonic, self.unknown_mnemonic_idx)
        m_emb = self.mnemonic_embed(torch.tensor(m_idx, device=device))

        # Category
        sem = INSTRUCTION_SEMANTICS.get(mnemonic, {"category": "unknown", "risk": 0.0})
        category = sem["category"]
        cat_idx = self.category_to_idx.get(category, self.unknown_category_idx)
        c_emb = self.category_embed(torch.tensor(cat_idx, device=device))

        # Features (slot order is part of the model's input contract)
        flags = [
            float(sem.get("risk", 0.0)),                                  # 0: risk level
            1.0 if "rsp" in operands or "rbp" in operands else 0.0,       # 1: stack related
            1.0 if "[" in operands else 0.0,                              # 2: memory access
            # startswith, not "in": "syscall" must not count as a call
            1.0 if mnemonic.startswith("call") else 0.0,                  # 3: function call
            1.0 if "ret" in mnemonic else 0.0,                            # 4: function return
            # every x86 mnemonic starting with "j" is a jump (jmp, jcc, jrcxz)
            1.0 if mnemonic.startswith("j") else 0.0,                     # 5: branch
            1.0 if mnemonic in ("syscall", "int", "sysenter") else 0.0,   # 6: syscall
            1.0 if category == "string" else 0.0,                         # 7: string op
            1.0 if "dword" in operands else 0.0,                          # 8: 32-bit access
            1.0 if "qword" in operands else 0.0,                          # 9: 64-bit access
        ]
        features = torch.zeros(self.FEATURE_DIM, device=device)
        features[: len(flags)] = torch.tensor(flags, dtype=features.dtype, device=device)

        # Concatenate and project; the 16 zeros stand in for register embeddings
        register_placeholder = torch.zeros(16, device=device)
        combined = torch.cat([m_emb, c_emb, register_placeholder, features])
        return self.feature_proj(combined)

    def encode_trace(self, trace: list) -> torch.Tensor:
        """
        Encode an execution trace into a sequence of vectors.

        Args:
            trace: Sequence of instructions. Each item is either an object
                with ``mnemonic`` / ``operands`` attributes or a dict with
                those keys. A dict carrying an ``"instructions"`` list is
                unwrapped first.

        Returns:
            Tensor of shape ``[seq_len, output_dim]``; an empty trace yields
            a single all-zero row of shape ``[1, output_dim]``.
        """
        if isinstance(trace, dict):
            trace = trace.get("instructions", [])

        encoded = [
            self.encode_instruction(*self._instruction_fields(inst))
            for inst in trace
        ]

        if encoded:
            return torch.stack(encoded)  # [seq_len, output_dim]
        return torch.zeros(1, self.output_dim, device=self.mnemonic_embed.weight.device)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
# ═══════════════════════════════════════════════════════════════
|
|
347
|
+
# KNOWLEDGE BASE — Pre-trained understanding
|
|
348
|
+
# ═══════════════════════════════════════════════════════════════
|
|
349
|
+
|
|
350
|
+
class BinaryKnowledgeBase:
    """
    Pre-loaded knowledge about binary analysis.

    Bundles the instruction encoder with the module's static tables
    (vulnerability patterns, memory regions, register info) and offers:
    - per-instruction lookup and contextual risk scoring
    - matching of execution traces against vulnerability signatures
    - generation of synthetic labelled traces for pre-training

    Traces may be given as a sequence of instructions (objects with
    ``mnemonic`` / ``operands`` attributes, or dicts with those keys) or as
    the ``{"instructions": [...]}`` dict this class generates itself.
    """

    def __init__(self):
        self.instruction_encoder = InstructionEncoder()
        self.vulnerability_patterns = VULNERABILITY_PATTERNS
        self.memory_regions = MEMORY_REGIONS
        self.register_info = REGISTER_SEMANTICS

    def get_instruction_info(self, mnemonic: str) -> dict:
        """
        Get full information about an instruction.

        Returns the ``INSTRUCTION_SEMANTICS`` entry for ``mnemonic`` (the
        shared dict, not a copy), or a placeholder entry with category
        ``"unknown"`` and risk 0.5 when the mnemonic is not in the table.
        """
        return INSTRUCTION_SEMANTICS.get(mnemonic, {
            "category": "unknown",
            "risk": 0.5,
            "description": f"Unknown instruction: {mnemonic}"
        })

    def assess_risk(self, mnemonic: str, operands: str) -> float:
        """
        Assess the risk level of an instruction in context.

        Starts from the table risk (0.5 for unknown mnemonics), scales it by
        the contextual factors below, and caps the result at 1.0.
        """
        base_risk = INSTRUCTION_SEMANTICS.get(mnemonic, {}).get("risk", 0.5)

        # Increase risk for memory operations near stack
        if "rsp" in operands or "rbp" in operands:
            base_risk *= 1.5

        # Increase risk for string operations
        if mnemonic in ("rep", "movsb", "stosb", "movsd"):
            base_risk *= 2.0

        # Decrease risk for comparison/branch
        if mnemonic in ("cmp", "test", "je", "jne"):
            base_risk *= 0.3

        return min(base_risk, 1.0)

    def get_pattern_matches(self, trace: list) -> list:
        """
        Check execution trace against known vulnerability patterns.

        Returns:
            One ``{"pattern", "signature", "severity"}`` dict per signature
            that matched, in table order.
        """
        return [
            {
                "pattern": pattern_name,
                "signature": signature,
                "severity": pattern_info["severity"],
            }
            for pattern_name, pattern_info in self.vulnerability_patterns.items()
            for signature in pattern_info["signatures"]
            if self._check_signature(trace, signature)
        ]

    @staticmethod
    def _iter_instructions(trace):
        """Yield (mnemonic, operands) pairs from any supported trace shape."""
        if isinstance(trace, dict):
            trace = trace.get("instructions", [])
        for inst in trace:
            if isinstance(inst, dict):
                yield inst.get("mnemonic", ""), inst.get("operands", "")
            else:
                yield inst.mnemonic, inst.operands

    def _check_signature(self, trace: list, signature: str) -> bool:
        """
        Check if a trace matches a vulnerability signature.

        Simplified pattern matching. Only two signature shapes are handled:
        - signatures mentioning "rep movsb": a ``rep`` / ``movsb``
          instruction whose operands reference ``rsp`` or ``rbp``
        - signatures naming a function as ``name()`` (e.g. ``gets()``,
          ``strcpy()``, ``sprintf()``, ``free()``): a ``call`` whose operand
          contains that name as a whole identifier
        Any other signature never matches.
        """
        import re

        sig_lower = signature.lower()
        wants_rep_movsb = "rep movsb" in sig_lower

        # Whole-identifier match, so "gets" does not fire on "fgets" and
        # "sprintf" does not fire on "vsprintf".
        call_patterns = [
            re.compile(rf"(?<![A-Za-z0-9_]){re.escape(name)}(?![A-Za-z0-9_])")
            for name in re.findall(r"([a-z_][a-z0-9_]*)\(\)", sig_lower)
        ]

        if not wants_rep_movsb and not call_patterns:
            return False

        for mnemonic, operands in self._iter_instructions(trace):
            # split() also accepts a prefixed mnemonic given as one string,
            # e.g. "rep movsb"
            if wants_rep_movsb and {"rep", "movsb"} & set(mnemonic.split()):
                if "rsp" in operands or "rbp" in operands:
                    return True

            if mnemonic == "call" and any(p.search(operands) for p in call_patterns):
                return True

        return False

    def generate_training_examples(self, num_examples: int = 1000) -> list:
        """
        Generate synthetic training examples for pre-training.

        Creates labeled examples of:
        - Normal execution patterns (label 0)
        - Vulnerable execution patterns (label 1, roughly 30% of examples)

        Returns:
            List of ``{"trace", "label", "vulnerability_type"}`` dicts.
        """
        import random

        examples = []

        for _ in range(num_examples):
            # Decide if this example is vulnerable
            is_vulnerable = random.random() < 0.3  # 30% vulnerable

            if is_vulnerable:
                trace = self._generate_vulnerable_trace()
                label = 1
            else:
                trace = self._generate_normal_trace()
                label = 0

            examples.append({
                "trace": trace,
                "label": label,
                "vulnerability_type": trace.get("vuln_type", "none"),
            })

        return examples

    def _generate_normal_trace(self) -> dict:
        """Generate a normal (non-vulnerable) execution trace."""
        return {
            "instructions": [
                {"mnemonic": "push", "operands": "rbp", "risk": 0.1},
                {"mnemonic": "mov", "operands": "rbp, rsp", "risk": 0.1},
                {"mnemonic": "sub", "operands": "rsp, 0x20", "risk": 0.1},
                {"mnemonic": "mov", "operands": "dword [rbp-0x14], edi", "risk": 0.15},
                {"mnemonic": "mov", "operands": "dword [rbp-0x8], 0", "risk": 0.1},
                {"mnemonic": "cmp", "operands": "dword [rbp-0x8], 10", "risk": 0.0},
                {"mnemonic": "jge", "operands": ".end", "risk": 0.05},
                {"mnemonic": "add", "operands": "dword [rbp-0x8], 1", "risk": 0.1},
                {"mnemonic": "jmp", "operands": ".loop", "risk": 0.05},
                {"mnemonic": ".end:", "operands": "", "risk": 0.0},
                {"mnemonic": "mov", "operands": "eax, dword [rbp-0x8]", "risk": 0.1},
                {"mnemonic": "add", "operands": "rsp, 0x20", "risk": 0.1},
                {"mnemonic": "pop", "operands": "rbp", "risk": 0.1},
                {"mnemonic": "ret", "operands": "", "risk": 0.15},
            ],
            "vuln_type": "none",
        }

    def _generate_vulnerable_trace(self) -> dict:
        """
        Generate a vulnerable execution trace.

        Picks one of four vulnerability types uniformly at random and
        returns a fixed synthetic trace for it, tagged via ``"vuln_type"``.
        """
        # Local import: this method is also callable on its own, and the
        # module does not import random at top level.
        import random

        vulns = ["stack_overflow", "format_string", "integer_overflow", "use_after_free"]
        vuln_type = random.choice(vulns)

        if vuln_type == "stack_overflow":
            return {
                "instructions": [
                    {"mnemonic": "push", "operands": "rbp", "risk": 0.1},
                    {"mnemonic": "mov", "operands": "rbp, rsp", "risk": 0.1},
                    {"mnemonic": "sub", "operands": "rsp, 0x40", "risk": 0.1},
                    {"mnemonic": "mov", "operands": "edi, dword [rbp-0x34]", "risk": 0.15},
                    {"mnemonic": "lea", "operands": "rax, [rbp-0x30]", "risk": 0.05},
                    {"mnemonic": "mov", "operands": "esi, eax", "risk": 0.1},
                    {"mnemonic": "call", "operands": "gets", "risk": 0.9},  # DANGEROUS
                    {"mnemonic": "nop", "operands": "", "risk": 0.0},
                    {"mnemonic": "leave", "operands": "", "risk": 0.05},
                    {"mnemonic": "ret", "operands": "", "risk": 0.15},
                ],
                "vuln_type": "stack_overflow",
            }
        elif vuln_type == "format_string":
            return {
                "instructions": [
                    {"mnemonic": "push", "operands": "rbp", "risk": 0.1},
                    {"mnemonic": "mov", "operands": "rbp, rsp", "risk": 0.1},
                    {"mnemonic": "sub", "operands": "rsp, 0x10", "risk": 0.1},
                    {"mnemonic": "mov", "operands": "dword [rbp-0x8], edi", "risk": 0.15},
                    {"mnemonic": "mov", "operands": "eax, dword [rbp-0x8]", "risk": 0.1},
                    {"mnemonic": "mov", "operands": "esi, eax", "risk": 0.1},
                    {"mnemonic": "lea", "operands": "rdi, [rbp-0x4]", "risk": 0.05},
                    {"mnemonic": "call", "operands": "printf", "risk": 0.8},  # DANGEROUS
                    {"mnemonic": "nop", "operands": "", "risk": 0.0},
                    {"mnemonic": "leave", "operands": "", "risk": 0.05},
                    {"mnemonic": "ret", "operands": "", "risk": 0.15},
                ],
                "vuln_type": "format_string",
            }
        elif vuln_type == "use_after_free":
            # Allocate, free, then dereference the stale pointer.
            return {
                "instructions": [
                    {"mnemonic": "push", "operands": "rbp", "risk": 0.1},
                    {"mnemonic": "mov", "operands": "rbp, rsp", "risk": 0.1},
                    {"mnemonic": "sub", "operands": "rsp, 0x10", "risk": 0.1},
                    {"mnemonic": "mov", "operands": "edi, 0x20", "risk": 0.1},
                    {"mnemonic": "call", "operands": "malloc", "risk": 0.15},
                    {"mnemonic": "mov", "operands": "qword [rbp-0x8], rax", "risk": 0.15},
                    {"mnemonic": "mov", "operands": "rdi, qword [rbp-0x8]", "risk": 0.1},
                    {"mnemonic": "call", "operands": "free", "risk": 0.15},
                    {"mnemonic": "mov", "operands": "rax, qword [rbp-0x8]", "risk": 0.1},
                    {"mnemonic": "mov", "operands": "eax, dword [rax]", "risk": 0.7},  # DANGEROUS
                    {"mnemonic": "leave", "operands": "", "risk": 0.05},
                    {"mnemonic": "ret", "operands": "", "risk": 0.15},
                ],
                "vuln_type": "use_after_free",
            }
        else:
            return {
                "instructions": [
                    {"mnemonic": "push", "operands": "rbp", "risk": 0.1},
                    {"mnemonic": "mov", "operands": "rbp, rsp", "risk": 0.1},
                    {"mnemonic": "imul", "operands": "edi, esi", "risk": 0.4},  # DANGEROUS
                    {"mnemonic": "cdqe", "operands": "", "risk": 0.15},
                    {"mnemonic": "mov", "operands": "edi, eax", "risk": 0.1},
                    {"mnemonic": "call", "operands": "malloc", "risk": 0.15},
                    {"mnemonic": "leave", "operands": "", "risk": 0.05},
                    {"mnemonic": "ret", "operands": "", "risk": 0.15},
                ],
                "vuln_type": "integer_overflow",
            }
|