genlayer-test 0.10.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gltest/direct/vm.py ADDED
@@ -0,0 +1,432 @@
1
+ """
2
+ VMContext - Foundry-style test VM for GenLayer contracts.
3
+
4
+ Provides cheatcodes for:
5
+ - Setting sender/value (vm.sender, vm.value)
6
+ - Snapshots and reverts (vm.snapshot(), vm.revert())
7
+ - Mocking nondet operations (vm.mock_web(), vm.mock_llm())
8
+ - Expecting reverts (vm.expect_revert())
9
+ - Pranking (vm.prank(), vm.startPrank(), vm.stopPrank())
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import re
15
+ import sys
16
+ import hashlib
17
+ from contextlib import contextmanager, ExitStack
18
+ from dataclasses import dataclass, field
19
+ from typing import Any, Optional, Pattern, List, Tuple, Dict
20
+ from unittest.mock import patch
21
+
22
+ from ..types import MockedWebResponseData
23
+
24
+
25
@dataclass
class Snapshot:
    """Point-in-time copy of VM storage and balances, used by VMContext.revert()."""

    id: int  # identifier handed out by VMContext.snapshot()
    storage_data: Dict[bytes, bytes]  # slot id -> raw slot contents
    balances: Dict[bytes, int]  # raw address bytes -> balance
31
+
32
+
33
class InmemManager:
    """
    In-memory storage manager compatible with genlayer.py.storage.

    Each slot id maps to a ``(Slot, bytearray)`` pair that is created lazily
    on first access. Reads and writes beyond the current end of a slot's
    buffer first zero-extend it, so never-written bytes read back as zeros.
    """

    def __init__(self):
        self._parts: Dict[bytes, Tuple["Slot", bytearray]] = {}

    def _entry(self, slot_id: bytes) -> Tuple["Slot", bytearray]:
        """Return the (slot, buffer) pair for ``slot_id``, creating an empty one if absent."""
        entry = self._parts.get(slot_id)
        if entry is None:
            entry = (Slot(slot_id, self), bytearray())
            self._parts[slot_id] = entry
        return entry

    @staticmethod
    def _ensure_len(mem: bytearray, needed: int) -> None:
        """Zero-extend ``mem`` in place so that it holds at least ``needed`` bytes."""
        if len(mem) < needed:
            mem.extend(b'\x00' * (needed - len(mem)))

    def get_store_slot(self, slot_id: bytes) -> "Slot":
        """Return the Slot object for ``slot_id``, creating it on first use."""
        return self._entry(slot_id)[0]

    def do_read(self, slot_id: bytes, off: int, length: int) -> bytes:
        """Return ``length`` bytes of slot ``slot_id`` starting at ``off``.

        The slot buffer is grown (zero-filled) to cover the requested range.
        """
        _, mem = self._entry(slot_id)
        self._ensure_len(mem, off + length)
        return bytes(memoryview(mem)[off:off + length])

    def do_write(self, slot_id: bytes, off: int, what: bytes) -> None:
        """Copy ``what`` into slot ``slot_id`` at offset ``off``, growing the slot as needed."""
        _, mem = self._entry(slot_id)

        what_view = memoryview(what)
        length = len(what_view)

        self._ensure_len(mem, off + length)
        memoryview(mem)[off:off + length] = what_view

    def snapshot(self) -> Dict[bytes, bytes]:
        """Return an immutable copy of every slot's contents, keyed by slot id."""
        return {
            slot_id: bytes(mem)
            for slot_id, (_, mem) in self._parts.items()
        }

    def restore(self, data: Dict[bytes, bytes]) -> None:
        """Replace all storage with ``data`` (as produced by ``snapshot``).

        Fresh Slot objects and mutable buffers are created for every slot.
        """
        self._parts.clear()
        for slot_id, mem_data in data.items():
            self._parts[slot_id] = (Slot(slot_id, self), bytearray(mem_data))
93
+
94
+
95
class Slot:
    """
    Storage slot compatible with genlayer.py.storage.

    A thin handle identified by ``id``: reads and writes are forwarded to the
    owning manager, and indirect child slots are keyed by a SHA3-256 digest of
    this slot's id followed by a 4-byte little-endian offset.
    """

    __slots__ = ('id', 'manager', '_indir_cache')

    def __init__(self, slot_id: bytes, manager: InmemManager):
        self.id = slot_id
        self.manager = manager
        # Hasher pre-fed with the slot id; indirect() copies it instead of
        # re-hashing the id every time.
        self._indir_cache = hashlib.sha3_256(slot_id)

    def read(self, off: int, length: int) -> bytes:
        """Read ``length`` bytes at ``off`` through the manager."""
        mgr = self.manager
        return mgr.do_read(self.id, off, length)

    def write(self, off: int, what: bytes) -> None:
        """Write ``what`` at ``off`` through the manager."""
        mgr = self.manager
        mgr.do_write(self.id, off, what)

    def indirect(self, off: int) -> "Slot":
        """Return the child slot derived from this slot's id and ``off``."""
        digest_source = self._indir_cache.copy()
        digest_source.update(off.to_bytes(4, 'little'))
        child_id = digest_source.digest()
        return self.manager.get_store_slot(child_id)
115
+
116
+
117
# Id of the root storage slot: 32 zero bytes.
ROOT_SLOT_ID = bytes(32)
118
+
119
+
120
@dataclass
class VMContext:
    """
    Test VM context providing Foundry-style cheatcodes.

    Usage:
        vm = VMContext()
        vm.sender = Address("0x" + "a" * 40)
        vm.mock_web("api.example.com", {"status": 200, "body": "{}"})

        with vm.activate():
            contract = deploy_contract("Token.py", vm, owner)
            contract.transfer(bob, 100)

    Assigning ``sender``, ``origin`` or ``value``, and starting/stopping a
    prank, re-syncs ``gl.message`` when ``genlayer.gl`` is already imported.
    """

    # Message context
    _sender: Optional[Any] = None
    _origin: Optional[Any] = None  # falls back to _sender when unset/falsy
    _contract_address: Optional[Any] = None
    _value: int = 0
    _chain_id: int = 1
    _datetime: str = "2024-01-01T00:00:00Z"

    # Storage: slot contents plus balances keyed by raw address bytes
    _storage: InmemManager = field(default_factory=InmemManager)
    _balances: Dict[bytes, int] = field(default_factory=dict)

    # Snapshots, keyed by the id returned from snapshot()
    _snapshots: Dict[int, Snapshot] = field(default_factory=dict)
    _snapshot_counter: int = 0

    # Mocks: (compiled pattern, response) pairs, searched in registration order
    _web_mocks: List[Tuple[Pattern, MockedWebResponseData]] = field(default_factory=list)
    _llm_mocks: List[Tuple[Pattern, str]] = field(default_factory=list)

    # Expect revert (set only while inside expect_revert())
    _expect_revert: Optional[str] = None
    _expect_revert_any: bool = False

    # Prank stack: the last entry overrides _sender
    _prank_stack: List[Any] = field(default_factory=list)

    # Return value capture
    _return_value: Any = None
    _returned: bool = False

    # Debug tracing
    _traces: List[str] = field(default_factory=list)
    _trace_enabled: bool = True

    @property
    def sender(self) -> Any:
        """Effective sender: the innermost active prank, otherwise ``_sender``."""
        if self._prank_stack:
            return self._prank_stack[-1]
        return self._sender

    @sender.setter
    def sender(self, addr: Any) -> None:
        self._sender = addr
        self._refresh_gl_message()

    @property
    def value(self) -> int:
        """Value attached to the current message."""
        return self._value

    @value.setter
    def value(self, val: int) -> None:
        self._value = val
        # gl.message embeds the value as well, so keep it in sync like sender.
        self._refresh_gl_message()

    @property
    def origin(self) -> Any:
        """Transaction origin; falls back to the configured (non-pranked) sender."""
        return self._origin or self._sender

    @origin.setter
    def origin(self, addr: Any) -> None:
        self._origin = addr
        # gl.message embeds the origin as well, so keep it in sync like sender.
        self._refresh_gl_message()

    def warp(self, timestamp: str) -> None:
        """Set block timestamp (ISO format)."""
        self._datetime = timestamp

    def deal(self, address: Any, amount: int) -> None:
        """Set balance for an address (stored under its raw bytes, see _to_bytes)."""
        addr_bytes = self._to_bytes(address)
        self._balances[addr_bytes] = amount

    def snapshot(self) -> int:
        """Take a snapshot of storage and balances. Returns the snapshot ID."""
        snap_id = self._snapshot_counter
        self._snapshot_counter += 1

        self._snapshots[snap_id] = Snapshot(
            id=snap_id,
            storage_data=self._storage.snapshot(),
            balances=dict(self._balances),
        )

        return snap_id

    def revert(self, snapshot_id: int) -> None:
        """Restore storage and balances from a previous snapshot.

        Snapshots taken after ``snapshot_id`` are discarded; the target snapshot
        itself is kept, so it can be reverted to again.

        Raises:
            ValueError: if ``snapshot_id`` is unknown.
        """
        if snapshot_id not in self._snapshots:
            raise ValueError(f"Snapshot {snapshot_id} not found")

        snap = self._snapshots[snapshot_id]
        self._storage.restore(snap.storage_data)
        self._balances = dict(snap.balances)

        self._snapshots = {
            k: v for k, v in self._snapshots.items()
            if k <= snapshot_id
        }

    def mock_web(
        self,
        url_pattern: str,
        response: MockedWebResponseData,
    ) -> None:
        """Mock web requests whose URL matches the regex ``url_pattern`` (re.search)."""
        pattern = re.compile(url_pattern)
        self._web_mocks.append((pattern, response))

    def mock_llm(self, prompt_pattern: str, response: str) -> None:
        """Mock LLM prompts matching the regex ``prompt_pattern`` (re.search)."""
        pattern = re.compile(prompt_pattern)
        self._llm_mocks.append((pattern, response))

    def clear_mocks(self) -> None:
        """Clear all registered mocks."""
        self._web_mocks.clear()
        self._llm_mocks.clear()

    @contextmanager
    def expect_revert(self, message: Optional[str] = None):
        """Context manager expecting the wrapped code to revert.

        A ContractRollback counts as a revert; when ``message`` is given it must
        be a substring of the rollback message. Other exceptions also count,
        subject to the same substring check against ``str(exc)``. AssertionErrors
        raised inside the block propagate unchanged.

        Raises:
            AssertionError: if the block completes without raising, or a
                rollback's message does not contain ``message``.
        """
        from .wasi_mock import ContractRollback

        self._expect_revert = message
        self._expect_revert_any = message is None

        try:
            yield
        except ContractRollback as e:
            if message is not None and message not in e.message:
                raise AssertionError(
                    f"Expected revert with message '{message}', got '{e.message}'"
                ) from e
        except AssertionError:
            raise
        except Exception as e:
            if message is not None and message not in str(e):
                raise
        else:
            # Raised from `else` so it is not intercepted by the handlers above.
            raise AssertionError(
                f"Expected revert{f' with message: {message}' if message else ''}, but call succeeded"
            )
        finally:
            self._expect_revert = None
            self._expect_revert_any = False

    @contextmanager
    def prank(self, address: Any):
        """Context manager to temporarily change sender."""
        self._prank_stack.append(address)
        self._refresh_gl_message()
        try:
            yield
        finally:
            self._prank_stack.pop()
            self._refresh_gl_message()

    def startPrank(self, address: Any) -> None:
        """Start pranking as address (persists until stopPrank)."""
        self._prank_stack.append(address)
        self._refresh_gl_message()

    def stopPrank(self) -> None:
        """Stop the current prank.

        Raises:
            RuntimeError: if no prank is active.
        """
        if self._prank_stack:
            self._prank_stack.pop()
            self._refresh_gl_message()
        else:
            raise RuntimeError("No active prank to stop")

    @contextmanager
    def activate(self):
        """
        Activate this VM context for contract execution.

        Binds the VM in the wasi mock, installs the mock as
        ``sys.modules['_genlayer_wasi']`` and patches ``os.fdopen``. Every step
        registers its undo on the ExitStack *before* the next step runs, so a
        failure half-way through still unwinds what was already done. On exit
        the steps are undone in reverse order, and cached genlayer/contract
        modules are purged last.
        """
        from . import wasi_mock

        with ExitStack() as stack:
            stack.callback(self._cleanup_after_deactivate)

            wasi_mock.set_vm(self)
            stack.callback(wasi_mock.clear_vm)

            previous = sys.modules.get('_genlayer_wasi')
            sys.modules['_genlayer_wasi'] = wasi_mock
            stack.callback(self._restore_wasi_module, previous)

            stack.enter_context(
                patch('os.fdopen', wasi_mock.patched_fdopen)
            )

            yield self

    @staticmethod
    def _restore_wasi_module(previous: Any) -> None:
        """Put back the ``_genlayer_wasi`` entry that existed before activate()."""
        if previous is None:
            sys.modules.pop('_genlayer_wasi', None)
        else:
            sys.modules['_genlayer_wasi'] = previous

    def _cleanup_after_deactivate(self) -> None:
        """Drop modules named ``genlayer*`` / ``_contract_*`` from sys.modules.

        Since the SDK caches gl.message at import time (see
        _refresh_gl_message), the next activation must import it afresh.
        """
        modules_to_remove = [
            key for key in list(sys.modules)
            if key.startswith(('genlayer', '_contract_'))
        ]
        for mod in modules_to_remove:
            sys.modules.pop(mod, None)

    def _match_web_mock(self, url: str, method: str = "GET") -> Optional[MockedWebResponseData]:
        """Return the first mock whose pattern matches ``url`` and whose method matches.

        Methods are compared case-insensitively; a missing/empty method means "GET".
        """
        wanted = (method or "GET").upper()
        for pattern, response in self._web_mocks:
            if not pattern.search(url):
                continue
            if (response.get("method") or "GET").upper() == wanted:
                return response
        return None

    def _match_llm_mock(self, prompt: str) -> Optional[str]:
        """Return the response of the first LLM mock whose pattern matches ``prompt``."""
        for pattern, response in self._llm_mocks:
            if pattern.search(prompt):
                return response
        return None

    def _trace(self, message: str) -> None:
        """Record ``message`` in the trace log when tracing is enabled."""
        if self._trace_enabled:
            self._traces.append(message)

    def _to_bytes(self, addr: Any) -> bytes:
        """Convert an address-like value to raw bytes.

        Accepts bytes-like objects, objects exposing ``as_bytes`` or
        ``__bytes__``, and hex strings with or without a ``0x``/``0X`` prefix.

        Raises:
            ValueError: for unsupported types or invalid hex.
        """
        if isinstance(addr, bytes):
            return addr
        if isinstance(addr, (bytearray, memoryview)):
            return bytes(addr)
        if hasattr(addr, 'as_bytes'):
            return addr.as_bytes
        if hasattr(addr, '__bytes__'):
            return bytes(addr)
        if isinstance(addr, str):
            if addr[:2] in ("0x", "0X"):
                return bytes.fromhex(addr[2:])
            return bytes.fromhex(addr)
        raise ValueError(f"Cannot convert {type(addr)} to bytes")

    @staticmethod
    def _coerce_address(addr: Any, address_type: type) -> Any:
        """Wrap raw bytes or ``as_bytes`` objects in ``address_type``; pass others through."""
        if addr is None or isinstance(addr, address_type):
            return addr
        if isinstance(addr, bytes):
            return address_type(addr)
        if hasattr(addr, 'as_bytes'):
            return address_type(addr.as_bytes)
        return addr

    def _refresh_gl_message(self) -> None:
        """
        Refresh gl.message and gl.message_raw to reflect the current sender,
        origin and value.

        GenLayer SDK caches gl.message at import time. This method updates
        the cached values so contracts see the current vm.sender.

        Only updates if genlayer.gl is already imported - we must not trigger
        a fresh import as that would read from stdin before message is injected.
        """
        if 'genlayer.gl' not in sys.modules:
            return
        gl = sys.modules['genlayer.gl']

        try:
            from genlayer.py.types import Address, u256
        except ImportError:
            # genlayer not loaded yet, nothing to update
            return

        sender = self._coerce_address(self.sender, Address)
        origin = self._coerce_address(self.origin, Address)

        # Update message_raw dict (mutable)
        if hasattr(gl, 'message_raw') and gl.message_raw is not None:
            gl.message_raw['sender_address'] = sender
            gl.message_raw['origin_address'] = origin

        # Replace gl.message with new NamedTuple (immutable, must recreate)
        if hasattr(gl, 'message') and gl.message is not None:
            gl.message = gl.MessageType(
                contract_address=gl.message.contract_address,
                sender_address=sender,
                origin_address=origin,
                value=u256(self._value),
                chain_id=u256(self._chain_id),
            )

    def get_message_raw(self) -> Dict[str, Any]:
        """Get MessageRawType dict for stdin injection."""
        return {
            "contract_address": self._contract_address,
            "sender_address": self.sender,
            "origin_address": self.origin,
            "stack": [],
            "value": self._value,
            "datetime": self._datetime,
            "is_init": False,
            "chain_id": self._chain_id,
            "entry_kind": 0,
            "entry_data": b"",
            "entry_stage_data": None,
        }
gltest/direct/wasi_mock.py ADDED
@@ -0,0 +1,219 @@
1
+ """
2
+ Mock implementation of _genlayer_wasi module.
3
+
4
+ Provides drop-in replacements for WASI functions that contracts use:
5
+ - storage_read / storage_write
6
+ - get_balance / get_self_balance
7
+ - gl_call
8
+
9
+ This module is injected into sys.modules before importing contracts.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import io
15
+ import os
16
+ import threading
17
+ import warnings
18
+ from typing import TYPE_CHECKING, Any
19
+
20
+ if TYPE_CHECKING:
21
+ from .vm import VMContext
22
+
23
# Per-thread storage for the active VM and its fake-fd bookkeeping, so tests
# running in different threads do not see each other's VM.
_local = threading.local()

# The genuine os.fdopen, captured when this module is first imported, so that
# patched_fdopen can delegate descriptors it does not own.
_original_fdopen = os.fdopen
28
+
29
+
30
def set_vm(vm: "VMContext") -> None:
    """Make ``vm`` the active VM for WASI operations on this thread.

    Also restarts fake file descriptors at 100 and starts a fresh (empty)
    table of pending gl_call response buffers.
    """
    state = _local
    state.fd_buffers = {}
    state.fd_counter = 100
    state.vm = vm
35
+
36
+
37
def get_vm() -> "VMContext":
    """Return this thread's active VM.

    Raises:
        RuntimeError: if no VM is bound (never set, or cleared by clear_vm()).
    """
    active = getattr(_local, 'vm', None)
    if active is not None:
        return active
    raise RuntimeError("No VM context active. Call VMContext.activate() first.")
43
+
44
+
45
def clear_vm() -> None:
    """Unbind this thread's VM, closing any response buffers nobody consumed.

    The fake file-descriptor counter is reset to 100.
    """
    if hasattr(_local, 'fd_buffers'):
        pending = _local.fd_buffers
        for stream in pending.values():
            stream.close()
        pending.clear()
    _local.vm = None
    _local.fd_counter = 100
53
+
54
+
55
def storage_read(slot: bytes, off: int, buf: bytearray, /) -> None:
    """Fill ``buf`` in place with ``len(buf)`` bytes of ``slot`` starting at ``off``."""
    buf[:] = get_vm()._storage.do_read(slot, off, len(buf))
60
+
61
+
62
def storage_write(slot: bytes, off: int, what: bytes, /) -> None:
    """Store ``what`` into ``slot`` at offset ``off`` in the active VM's storage."""
    get_vm()._storage.do_write(slot, off, what)
66
+
67
+
68
def get_balance(address: bytes, /) -> int:
    """Return the balance recorded for ``address`` in the active VM (0 if none)."""
    balances = get_vm()._balances
    return balances.get(address, 0)
72
+
73
+
74
def get_self_balance() -> int:
    """
    Return the balance recorded for the active VM's contract address.

    Returns 0 when no contract address is set or no balance exists for it.
    The address is normalized with ``VMContext._to_bytes``, the same
    conversion ``VMContext.deal`` uses when storing balances. As a result, an
    address given as an Address object, hex string or raw bytes finds the
    balance that was dealt to it. Values ``_to_bytes`` cannot convert are
    looked up as-is.
    """
    vm = get_vm()
    contract_addr = vm._contract_address
    if contract_addr is None:
        return 0
    try:
        key = vm._to_bytes(contract_addr)
    except ValueError:
        # Unconvertible address: fall back to using the object itself as key.
        key = contract_addr
    return vm._balances.get(key, 0)
82
+
83
+
84
def gl_call(data: bytes, /) -> int:
    """
    Execute a GenVM call operation against the active VM.

    The request is calldata-decoded and dispatched through _handle_gl_call.
    The calldata-encoded response is then parked in a per-thread BytesIO under
    a freshly allocated fake file descriptor, which patched_fdopen later hands
    back to the reader.

    Returns the fake file descriptor, or 2^32-1 when decoding fails, the
    handler produces no response, or encoding fails. Exceptions raised by
    the handler itself (e.g. ContractRollback) propagate to the caller.
    """
    failure = 2**32 - 1
    vm = get_vm()
    pending = getattr(_local, 'fd_buffers', {})

    try:
        from genlayer.py import calldata
        request = calldata.decode(data)
    except Exception as exc:
        vm._trace(f"gl_call decode error: {exc}")
        return failure

    response = _handle_gl_call(vm, request)
    if response is None:
        return failure

    try:
        payload = calldata.encode(response)
    except Exception as exc:
        vm._trace(f"gl_call encode error: {exc}")
        return failure

    new_fd = getattr(_local, 'fd_counter', 100)
    _local.fd_counter = new_fd + 1
    pending[new_fd] = io.BytesIO(payload)
    _local.fd_buffers = pending
    return new_fd
121
+
122
+
123
def _handle_gl_call(vm: "VMContext", request: Any) -> Any:
    """
    Dispatch one decoded gl_call request and return the response to encode.

    Request kinds are recognised by key, in this priority order: Return,
    Rollback, Trace, Sandbox, RunNondet, GetWebsite/WebRequest, ExecPrompt.

    Returns None (meaning "no response") for non-dict requests, for "Return"
    (whose payload is captured on the VM), and for unknown kinds (which are
    traced).

    Raises:
        ContractRollback: for a "Rollback" request.
    """
    if not isinstance(request, dict):
        return None

    if "Return" in request:
        vm._return_value, vm._returned = request["Return"], True
        return None

    if "Rollback" in request:
        raise ContractRollback(request["Rollback"])

    if "Trace" in request:
        payload = request["Trace"]
        if "Message" in payload:
            vm._trace(payload["Message"])
        return {"ok": None}

    if "Sandbox" in request or "RunNondet" in request:
        if "Sandbox" in request:
            warnings.warn(
                "gl.sandbox is not fully isolated in direct test mode.",
                RuntimeWarning,
                stacklevel=3,
            )
        return {"ok": None}

    if "GetWebsite" in request or "WebRequest" in request:
        web_payload = request.get("GetWebsite") or request.get("WebRequest", {})
        return _handle_web_request(vm, web_payload)

    if "ExecPrompt" in request:
        return _handle_llm_request(vm, request["ExecPrompt"])

    vm._trace(f"Unknown gl_call request type: {list(request.keys())}")
    return None
163
+
164
+
165
def _handle_web_request(vm: "VMContext", data: Any) -> Any:
    """
    Resolve a web request against the VM's registered web mocks.

    Args:
        vm: active VM whose ``_match_web_mock`` is consulted.
        data: request payload; ``url`` and ``method`` are read with defaults
            ``""`` and ``"GET"``.

    Returns:
        ``{"ok": response}`` with the mock returned by ``vm._match_web_mock``.

    Raises:
        MockNotFoundError: if no mock matches.
    """
    url = data.get("url", "")
    method = data.get("method", "GET")

    response = vm._match_web_mock(url, method)
    # _match_web_mock signals "no match" with None; compare against None so a
    # registered mock is still served even if it happens to be falsy.
    if response is not None:
        return {"ok": response}

    raise MockNotFoundError(f"No web mock for {method} {url}")
175
+
176
+
177
def _handle_llm_request(vm: "VMContext", data: Any) -> Any:
    """
    Resolve an LLM prompt request against the VM's registered LLM mocks.

    Args:
        vm: active VM whose ``_match_llm_mock`` is consulted.
        data: request payload; ``prompt`` is read with default ``""``.

    Returns:
        ``{"ok": response}`` with the mocked response text.

    Raises:
        MockNotFoundError: if no mock matches.
    """
    prompt = data.get("prompt", "")

    response = vm._match_llm_mock(prompt)
    # None means "no match"; an empty-string mock response is a legitimate
    # mock and must be returned rather than reported as missing.
    if response is not None:
        return {"ok": response}

    raise MockNotFoundError(f"No LLM mock for prompt: {prompt[:100]}...")
186
+
187
+
188
def patched_fdopen(fd_arg: int, mode: str = "r", *args, **kwargs):
    """Drop-in replacement for os.fdopen that serves gl_call responses.

    If ``fd_arg`` is a fake descriptor issued by gl_call, its buffer is
    removed from this thread's table, rewound and returned (``mode`` and any
    extra arguments are ignored). Any other descriptor is forwarded to the
    real os.fdopen captured at import time.
    """
    pending = getattr(_local, 'fd_buffers', {})
    if fd_arg not in pending:
        return _original_fdopen(fd_arg, mode, *args, **kwargs)

    stream = pending.pop(fd_arg)
    stream.seek(0)
    return stream
198
+
199
+
200
class ContractRollback(Exception):
    """Signals that the contract called gl.rollback(); ``message`` holds its reason."""

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
206
+
207
+
208
class MockNotFoundError(Exception):
    """Signals that a nondet operation (web request / LLM prompt) had no matching mock."""
211
+
212
+
213
# WASI entry points exposed by this mock (it stands in for _genlayer_wasi).
__all__ = (
    "storage_read", "storage_write",
    "get_balance", "get_self_balance",
    "gl_call",
)
gltest/types.py CHANGED
@@ -20,6 +20,20 @@ class MockedLLMResponse(TypedDict):
20
20
  eq_principle_prompt_non_comparative: Dict[str, bool]
21
21
 
22
22
 
23
class _MockedWebResponseDataBase(TypedDict):
    """Required fields of a mocked web response."""

    status: int  # status code of the response
    body: str  # body of the response


class MockedWebResponseData(_MockedWebResponseDataBase, total=False):
    """Mocked web response data with method for matching.

    ``status`` and ``body`` are required. ``method`` is optional: direct-mode
    matching (VMContext._match_web_mock) treats a missing method as "GET".
    """

    method: str  # GET, POST, PUT, DELETE, etc.
29
+
30
+
31
class MockedWebResponse(TypedDict, total=True):
    """Web mock configuration mapping each URL to its mocked response."""

    # URL -> response to serve for that URL
    nondet_web_request: Dict[str, MockedWebResponseData]
35
+
36
+
23
37
  class ValidatorConfig(TypedDict):
24
38
  """Validator information."""
25
39