py-adtools 0.1.2__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of py-adtools might be problematic. Click here for more details.
- {py_adtools-0.1.2 → py_adtools-0.1.4}/PKG-INFO +11 -9
- {py_adtools-0.1.2 → py_adtools-0.1.4}/README.md +9 -7
- py_adtools-0.1.4/adtools/__init__.py +3 -0
- {py_adtools-0.1.2 → py_adtools-0.1.4}/adtools/evaluator.py +81 -47
- py_adtools-0.1.4/adtools/evaluator_pool.py +82 -0
- py_adtools-0.1.4/adtools/lm_base.py +403 -0
- {py_adtools-0.1.2 → py_adtools-0.1.4}/adtools/py_code.py +69 -47
- {py_adtools-0.1.2 → py_adtools-0.1.4}/py_adtools.egg-info/PKG-INFO +11 -9
- {py_adtools-0.1.2 → py_adtools-0.1.4}/py_adtools.egg-info/SOURCES.txt +2 -0
- {py_adtools-0.1.2 → py_adtools-0.1.4}/setup.py +2 -2
- py_adtools-0.1.2/adtools/__init__.py +0 -2
- {py_adtools-0.1.2 → py_adtools-0.1.4}/LICENSE +0 -0
- {py_adtools-0.1.2 → py_adtools-0.1.4}/py_adtools.egg-info/dependency_links.txt +0 -0
- {py_adtools-0.1.2 → py_adtools-0.1.4}/py_adtools.egg-info/requires.txt +0 -0
- {py_adtools-0.1.2 → py_adtools-0.1.4}/py_adtools.egg-info/top_level.txt +0 -0
- {py_adtools-0.1.2 → py_adtools-0.1.4}/setup.cfg +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: py-adtools
|
|
3
|
-
Version: 0.1.2
|
|
4
|
-
Summary: Useful tools for parsing and evaluating Python programs for algorithm design.
|
|
3
|
+
Version: 0.1.4
|
|
4
|
+
Summary: Useful tools for parsing and evaluating Python programs for LLM-based algorithm design.
|
|
5
5
|
Home-page: https://github.com/RayZhhh/py-adtools
|
|
6
6
|
Author: Rui Zhang
|
|
7
7
|
Author-email: rzhang.cs@gmail.com
|
|
@@ -34,7 +34,7 @@ Dynamic: summary
|
|
|
34
34
|
|
|
35
35
|
------
|
|
36
36
|
|
|
37
|
-
The figure demonstrates how a Python program is parsed into `
|
|
37
|
+
The figure demonstrates how a Python program is parsed into `PyCodeBlock`, `PyFunction`, `PyClass`, and `PyProgram` via `adtools`.
|
|
38
38
|
|
|
39
39
|

|
|
40
40
|
|
|
@@ -68,7 +68,7 @@ Parse your code (in string) into Python code instances, so that you can check ea
|
|
|
68
68
|
from adtools import PyProgram
|
|
69
69
|
|
|
70
70
|
code = r'''
|
|
71
|
-
import ast, numba # This part will be parsed into
|
|
71
|
+
import ast, numba # This part will be parsed into PyCodeBlock
|
|
72
72
|
import numpy as np
|
|
73
73
|
|
|
74
74
|
@numba.jit() # This part will be parsed into PyFunction
|
|
@@ -80,8 +80,9 @@ def function(arg1, arg2=True):
|
|
|
80
80
|
|
|
81
81
|
@some.decorators() # This part will be parsed into PyClass
|
|
82
82
|
class PythonClass(BaseClass):
|
|
83
|
-
|
|
84
|
-
|
|
83
|
+
|
|
84
|
+
class_var1 = 1 # This part will be parsed into PyCodeBlock
|
|
85
|
+
class_var2 = 2 # and placed in PyClass.class_vars_and_code
|
|
85
86
|
|
|
86
87
|
def __init__(self, x): # This part will be parsed into PyFunction
|
|
87
88
|
self.x = x # and placed in PyClass.functions
|
|
@@ -93,11 +94,11 @@ class PythonClass(BaseClass):
|
|
|
93
94
|
def method2(self, x, y):
|
|
94
95
|
return x + y + self.method1(x)
|
|
95
96
|
|
|
96
|
-
class InnerClass: # This part will be parsed into
|
|
97
|
+
class InnerClass: # This part will be parsed into PyCodeBlock
|
|
97
98
|
def __init__(self): # and placed in PyClass.class_vars_and_code
|
|
98
99
|
...
|
|
99
100
|
|
|
100
|
-
if __name__ == '__main__': # This part will be parsed into
|
|
101
|
+
if __name__ == '__main__': # This part will be parsed into PyCodeBlock
|
|
101
102
|
res = function(1)
|
|
102
103
|
print(res)
|
|
103
104
|
res = PythonClass().method2(1, 2)
|
|
@@ -116,7 +117,7 @@ print(p.functions[0].name)
|
|
|
116
117
|
Evaluate Python programs in a separate, secure process so that a crashing or hanging program does not abort the main process. Two steps:
|
|
117
118
|
|
|
118
119
|
- Extend the `PyEvaluator` class and override the `evaluate_program` method.
|
|
119
|
-
- Evaluate the program (in str) by calling the `evaluate` (directly evaluate without executing in a sandbox process) or the `secure_evaluate` (evaluate in a sandbox process) methods.
|
|
120
|
+
- Evaluate the program (in str) by calling the `evaluate` (directly evaluate without executing in a sandbox process) or the `secure_evaluate` (evaluate in a sandbox process) methods.
|
|
120
121
|
|
|
121
122
|
```python
|
|
122
123
|
import time
|
|
@@ -190,6 +191,7 @@ def merge(left, right):
|
|
|
190
191
|
|
|
191
192
|
harmful_code_generated_by_llm = '''
|
|
192
193
|
def merge_sort(arr):
|
|
194
|
+
print('I am harmful') # There will be no output since we redirect STDOUT to /dev/null by default.
|
|
193
195
|
while True:
|
|
194
196
|
pass
|
|
195
197
|
'''
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
|
|
9
9
|
------
|
|
10
10
|
|
|
11
|
-
The figure demonstrates how a Python program is parsed into `
|
|
11
|
+
The figure demonstrates how a Python program is parsed into `PyCodeBlock`, `PyFunction`, `PyClass`, and `PyProgram` via `adtools`.
|
|
12
12
|
|
|
13
13
|

|
|
14
14
|
|
|
@@ -42,7 +42,7 @@ Parse your code (in string) into Python code instances, so that you can check ea
|
|
|
42
42
|
from adtools import PyProgram
|
|
43
43
|
|
|
44
44
|
code = r'''
|
|
45
|
-
import ast, numba # This part will be parsed into
|
|
45
|
+
import ast, numba # This part will be parsed into PyCodeBlock
|
|
46
46
|
import numpy as np
|
|
47
47
|
|
|
48
48
|
@numba.jit() # This part will be parsed into PyFunction
|
|
@@ -54,8 +54,9 @@ def function(arg1, arg2=True):
|
|
|
54
54
|
|
|
55
55
|
@some.decorators() # This part will be parsed into PyClass
|
|
56
56
|
class PythonClass(BaseClass):
|
|
57
|
-
|
|
58
|
-
|
|
57
|
+
|
|
58
|
+
class_var1 = 1 # This part will be parsed into PyCodeBlock
|
|
59
|
+
class_var2 = 2 # and placed in PyClass.class_vars_and_code
|
|
59
60
|
|
|
60
61
|
def __init__(self, x): # This part will be parsed into PyFunction
|
|
61
62
|
self.x = x # and placed in PyClass.functions
|
|
@@ -67,11 +68,11 @@ class PythonClass(BaseClass):
|
|
|
67
68
|
def method2(self, x, y):
|
|
68
69
|
return x + y + self.method1(x)
|
|
69
70
|
|
|
70
|
-
class InnerClass: # This part will be parsed into
|
|
71
|
+
class InnerClass: # This part will be parsed into PyCodeBlock
|
|
71
72
|
def __init__(self): # and placed in PyClass.class_vars_and_code
|
|
72
73
|
...
|
|
73
74
|
|
|
74
|
-
if __name__ == '__main__': # This part will be parsed into
|
|
75
|
+
if __name__ == '__main__': # This part will be parsed into PyCodeBlock
|
|
75
76
|
res = function(1)
|
|
76
77
|
print(res)
|
|
77
78
|
res = PythonClass().method2(1, 2)
|
|
@@ -90,7 +91,7 @@ print(p.functions[0].name)
|
|
|
90
91
|
Evaluate Python programs in a separate, secure process so that a crashing or hanging program does not abort the main process. Two steps:
|
|
91
92
|
|
|
92
93
|
- Extend the `PyEvaluator` class and override the `evaluate_program` method.
|
|
93
|
-
- Evaluate the program (in str) by calling the `evaluate` (directly evaluate without executing in a sandbox process) or the `secure_evaluate` (evaluate in a sandbox process) methods.
|
|
94
|
+
- Evaluate the program (in str) by calling the `evaluate` (directly evaluate without executing in a sandbox process) or the `secure_evaluate` (evaluate in a sandbox process) methods.
|
|
94
95
|
|
|
95
96
|
```python
|
|
96
97
|
import time
|
|
@@ -164,6 +165,7 @@ def merge(left, right):
|
|
|
164
165
|
|
|
165
166
|
harmful_code_generated_by_llm = '''
|
|
166
167
|
def merge_sort(arr):
|
|
168
|
+
print('I am harmful') # There will be no output since we redirect STDOUT to /dev/null by default.
|
|
167
169
|
while True:
|
|
168
170
|
pass
|
|
169
171
|
'''
|
|
@@ -1,7 +1,13 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2025 Rui Zhang <rzhang.cs@gmail.com>
|
|
3
|
+
|
|
4
|
+
NOTICE: This code is under MIT license. This code is intended for academic/research purposes only.
|
|
5
|
+
Commercial use of this software or its derivatives requires prior written permission.
|
|
6
|
+
"""
|
|
7
|
+
|
|
1
8
|
import multiprocessing
|
|
2
9
|
import os
|
|
3
10
|
import sys
|
|
4
|
-
import time
|
|
5
11
|
from abc import ABC, abstractmethod
|
|
6
12
|
from queue import Empty
|
|
7
13
|
from typing import Any, Literal, Dict, Callable, List
|
|
@@ -12,15 +18,26 @@ from .py_code import PyProgram
|
|
|
12
18
|
|
|
13
19
|
class PyEvaluator(ABC):
|
|
14
20
|
|
|
15
|
-
def __init__(
|
|
16
|
-
|
|
21
|
+
def __init__(
|
|
22
|
+
self,
|
|
23
|
+
exec_code: bool = True,
|
|
24
|
+
debug_mode: bool = False,
|
|
25
|
+
*,
|
|
26
|
+
join_timeout_seconds: int = 10
|
|
27
|
+
):
|
|
28
|
+
"""Evaluator interface for evaluating the Python algorithm program. Override this class and implement
|
|
29
|
+
'evaluate_program' method, then invoke 'self.evaluate()' or 'self.secure_evaluate()' for evaluation.
|
|
17
30
|
Args:
|
|
31
|
+
exec_code: Using 'exec()' to execute the program code and obtain the callable functions and classes,
|
|
32
|
+
which will be passed to 'self.evaluate_program()'. Set this parameter to 'False' if you are going to
|
|
33
|
+
evaluate a Python script. Note that if the parameter is set to 'False', the arguments 'callable_...'
|
|
34
|
+
in 'self.evaluate_program()' will no longer be effective (they are all passed as None).
|
|
18
35
|
debug_mode: Debug mode.
|
|
19
|
-
|
|
36
|
+
join_timeout_seconds: Timeout in seconds to wait for the process to finish. Kill the process if timeout.
|
|
20
37
|
"""
|
|
21
|
-
self.
|
|
22
|
-
self.
|
|
23
|
-
self.
|
|
38
|
+
self.debug_mode = debug_mode
|
|
39
|
+
self.exec_code = exec_code
|
|
40
|
+
self.join_timeout_seconds = join_timeout_seconds
|
|
24
41
|
|
|
25
42
|
@abstractmethod
|
|
26
43
|
def evaluate_program(
|
|
@@ -31,19 +48,21 @@ class PyEvaluator(ABC):
|
|
|
31
48
|
callable_classes_dict: Dict[str, Callable] | None,
|
|
32
49
|
callable_classes_list: List[Callable] | None,
|
|
33
50
|
**kwargs
|
|
34
|
-
) -> Any
|
|
51
|
+
) -> Any:
|
|
35
52
|
"""Evaluate a given program.
|
|
36
53
|
Args:
|
|
37
|
-
program_str
|
|
54
|
+
program_str: The raw program text.
|
|
38
55
|
callable_functions_dict: A dict maps function name to callable function.
|
|
39
56
|
callable_functions_list: A list of callable functions.
|
|
40
|
-
callable_classes_dict
|
|
41
|
-
callable_classes_list
|
|
42
|
-
|
|
57
|
+
callable_classes_dict: A dict maps class name to callable class.
|
|
58
|
+
callable_classes_list: A list of callable classes.
|
|
59
|
+
Returns:
|
|
43
60
|
Returns the evaluation result.
|
|
44
61
|
"""
|
|
45
|
-
raise NotImplementedError(
|
|
46
|
-
|
|
62
|
+
raise NotImplementedError(
|
|
63
|
+
'Must provide an evaluator for a python program. '
|
|
64
|
+
'Override this method in a subclass.'
|
|
65
|
+
)
|
|
47
66
|
|
|
48
67
|
def _kill_process_and_its_children(self, process: multiprocessing.Process):
|
|
49
68
|
# Find all children processes
|
|
@@ -54,51 +73,57 @@ class PyEvaluator(ABC):
|
|
|
54
73
|
children_processes = []
|
|
55
74
|
# Terminate parent process
|
|
56
75
|
process.terminate()
|
|
57
|
-
process.join(timeout=self.
|
|
76
|
+
process.join(timeout=self.join_timeout_seconds)
|
|
58
77
|
if process.is_alive():
|
|
59
78
|
process.kill()
|
|
60
79
|
process.join()
|
|
61
80
|
# Kill all children processes
|
|
62
81
|
for child in children_processes:
|
|
63
|
-
if self.
|
|
82
|
+
if self.debug_mode:
|
|
64
83
|
print(f"Killing process {process.pid}'s children process {child.pid}")
|
|
65
84
|
child.terminate()
|
|
66
85
|
|
|
67
|
-
def evaluate(self,
|
|
86
|
+
def evaluate(self, program: str | PyProgram, **kwargs):
|
|
87
|
+
"""Evaluate a program.
|
|
88
|
+
Args:
|
|
89
|
+
program: the program to be evaluated.
|
|
90
|
+
**kwargs: additional keyword arguments to pass to 'evaluate_program'.
|
|
91
|
+
"""
|
|
68
92
|
try:
|
|
69
93
|
# Parse to program instance
|
|
70
|
-
program
|
|
94
|
+
if isinstance(program, str):
|
|
95
|
+
program = PyProgram.from_text(program)
|
|
71
96
|
function_names = [f.name for f in program.functions]
|
|
72
97
|
class_names = [c.name for c in program.classes]
|
|
73
|
-
|
|
74
|
-
|
|
98
|
+
|
|
99
|
+
# Execute the code and get callable instances
|
|
100
|
+
if self.exec_code:
|
|
75
101
|
all_globals_namespace = {}
|
|
76
102
|
# Execute the program, map func/var/class to global namespace
|
|
77
|
-
exec(
|
|
103
|
+
exec(str(program), all_globals_namespace)
|
|
78
104
|
# Get callable functions
|
|
79
|
-
|
|
80
|
-
|
|
105
|
+
callable_funcs_list = [all_globals_namespace[f_name] for f_name in function_names]
|
|
106
|
+
callable_funcs_dict = dict(zip(function_names, callable_funcs_list))
|
|
81
107
|
# Get callable classes
|
|
82
|
-
|
|
83
|
-
|
|
108
|
+
callable_cls_list = [all_globals_namespace[c_name] for c_name in class_names]
|
|
109
|
+
callable_cls_dict = dict(zip(class_names, callable_cls_list))
|
|
84
110
|
else:
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
callable_classes_dict = None
|
|
111
|
+
callable_funcs_list, callable_funcs_dict, callable_cls_list, callable_cls_dict = (
|
|
112
|
+
None, None, None, None
|
|
113
|
+
)
|
|
89
114
|
|
|
90
115
|
# Get evaluate result
|
|
91
116
|
res = self.evaluate_program(
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
117
|
+
str(program),
|
|
118
|
+
callable_funcs_dict,
|
|
119
|
+
callable_funcs_list,
|
|
120
|
+
callable_cls_dict,
|
|
121
|
+
callable_cls_list,
|
|
97
122
|
**kwargs
|
|
98
123
|
)
|
|
99
124
|
return res
|
|
100
125
|
except Exception as e:
|
|
101
|
-
if self.
|
|
126
|
+
if self.debug_mode:
|
|
102
127
|
print(e)
|
|
103
128
|
return None
|
|
104
129
|
|
|
@@ -109,10 +134,13 @@ class PyEvaluator(ABC):
|
|
|
109
134
|
redirect_to_devnull: bool,
|
|
110
135
|
**kwargs
|
|
111
136
|
):
|
|
137
|
+
# Redirect STDOUT and STDERR to '/dev/null'
|
|
112
138
|
if redirect_to_devnull:
|
|
113
|
-
with open(
|
|
139
|
+
with open(os.devnull, 'w') as devnull:
|
|
114
140
|
os.dup2(devnull.fileno(), sys.stdout.fileno())
|
|
115
141
|
os.dup2(devnull.fileno(), sys.stderr.fileno())
|
|
142
|
+
|
|
143
|
+
# Evaluate and put the results to the queue
|
|
116
144
|
res = self.evaluate(program_str, **kwargs)
|
|
117
145
|
result_queue.put(res)
|
|
118
146
|
|
|
@@ -120,24 +148,27 @@ class PyEvaluator(ABC):
|
|
|
120
148
|
self,
|
|
121
149
|
program: str | PyProgram,
|
|
122
150
|
timeout_seconds: int | float = None,
|
|
123
|
-
redirect_to_devnull: bool =
|
|
124
|
-
multiprocessing_start_method
|
|
151
|
+
redirect_to_devnull: bool = False,
|
|
152
|
+
multiprocessing_start_method: Literal['default', 'auto', 'fork', 'spawn'] = 'auto',
|
|
125
153
|
**kwargs
|
|
126
154
|
):
|
|
127
|
-
"""
|
|
155
|
+
"""Evaluate program in a new process. This enables timeout restriction and output redirection.
|
|
128
156
|
Args:
|
|
129
157
|
program: the program to be evaluated.
|
|
130
158
|
timeout_seconds: return 'None' if the execution time exceeds 'timeout_seconds'.
|
|
131
159
|
redirect_to_devnull: redirect any output to '/dev/null'.
|
|
132
|
-
multiprocessing_start_method: start a process using 'fork' or 'spawn'.
|
|
160
|
+
multiprocessing_start_method: start a process using 'fork' or 'spawn'. If set to 'auto',
|
|
161
|
+
the process will be started using 'fork' with Linux/macOS and 'spawn' with Windows.
|
|
162
|
+
If set to 'default', there will be no changes to system default.
|
|
163
|
+
**kwargs: additional keyword arguments to pass to 'evaluate_program'.
|
|
133
164
|
"""
|
|
134
165
|
if multiprocessing_start_method == 'auto':
|
|
135
|
-
# Force
|
|
166
|
+
# Force macOS and Linux use 'fork' to generate new process
|
|
136
167
|
if sys.platform.startswith('darwin') or sys.platform.startswith('linux'):
|
|
137
168
|
multiprocessing.set_start_method('fork', force=True)
|
|
138
169
|
elif multiprocessing_start_method == 'fork':
|
|
139
170
|
multiprocessing.set_start_method('fork', force=True)
|
|
140
|
-
|
|
171
|
+
elif multiprocessing_start_method == 'spawn':
|
|
141
172
|
multiprocessing.set_start_method('spawn', force=True)
|
|
142
173
|
|
|
143
174
|
try:
|
|
@@ -156,22 +187,25 @@ class PyEvaluator(ABC):
|
|
|
156
187
|
result = result_queue.get(timeout=timeout_seconds)
|
|
157
188
|
# After getting the result, terminate/kill the process
|
|
158
189
|
self._kill_process_and_its_children(process)
|
|
159
|
-
except Empty:
|
|
160
|
-
|
|
161
|
-
if self._debug_mode:
|
|
190
|
+
except Empty: # The queue is empty indicates a timeout
|
|
191
|
+
if self.debug_mode:
|
|
162
192
|
print(f'DEBUG: the evaluation time exceeds {timeout_seconds}s.')
|
|
193
|
+
# Terminate/kill all processes if timeout happens
|
|
163
194
|
self._kill_process_and_its_children(process)
|
|
164
195
|
result = None
|
|
165
196
|
except Exception as e:
|
|
166
|
-
if self.
|
|
197
|
+
if self.debug_mode:
|
|
167
198
|
print(f'DEBUG: evaluation failed with exception:\n{e}')
|
|
199
|
+
# Terminate/kill all processes if meet exceptions
|
|
168
200
|
self._kill_process_and_its_children(process)
|
|
169
201
|
result = None
|
|
170
202
|
else:
|
|
203
|
+
# If there is no timeout limit, wait execution to finish
|
|
171
204
|
result = result_queue.get()
|
|
205
|
+
# Terminate/kill all processes after evaluation
|
|
172
206
|
self._kill_process_and_its_children(process)
|
|
173
207
|
return result
|
|
174
208
|
except Exception as e:
|
|
175
|
-
if self.
|
|
209
|
+
if self.debug_mode:
|
|
176
210
|
print(e)
|
|
177
211
|
return None
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2025 Rui Zhang <rzhang.cs@gmail.com>
|
|
3
|
+
|
|
4
|
+
NOTICE: This code is under MIT license. This code is intended for academic/research purposes only.
|
|
5
|
+
Commercial use of this software or its derivatives requires prior written permission.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
|
10
|
+
from typing import Literal, Optional
|
|
11
|
+
|
|
12
|
+
from .evaluator import PyEvaluator
|
|
13
|
+
from .py_code import PyProgram
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class EvaluatorExecutorPool:
|
|
17
|
+
def __init__(
|
|
18
|
+
self,
|
|
19
|
+
evaluator: PyEvaluator,
|
|
20
|
+
max_workers: int,
|
|
21
|
+
pool_type: Literal['thread', 'process'] = 'thread'
|
|
22
|
+
):
|
|
23
|
+
"""Multi-thread/process executor pool for parallel evaluation.
|
|
24
|
+
Args:
|
|
25
|
+
evaluator: The PyEvaluator instance.
|
|
26
|
+
max_workers: The maximum number of workers.
|
|
27
|
+
pool_type: Type of the executor pool.
|
|
28
|
+
"""
|
|
29
|
+
self.evaluator = evaluator
|
|
30
|
+
self.max_workers = max_workers
|
|
31
|
+
if pool_type == 'thread':
|
|
32
|
+
self.pool = ThreadPoolExecutor(max_workers=self.max_workers)
|
|
33
|
+
else:
|
|
34
|
+
self.pool = ProcessPoolExecutor(max_workers=self.max_workers)
|
|
35
|
+
|
|
36
|
+
def evaluate(self, program: str | PyProgram, return_time=True, **kwargs):
|
|
37
|
+
"""Evaluate program.
|
|
38
|
+
Args:
|
|
39
|
+
program: the program to be evaluated.
|
|
40
|
+
**kwargs: additional keyword arguments to pass to 'evaluate_program'.
|
|
41
|
+
"""
|
|
42
|
+
start_time = time.time()
|
|
43
|
+
future = self.pool.submit(self.evaluator.evaluate, program, **kwargs)
|
|
44
|
+
res = future.result()
|
|
45
|
+
duration = time.time() - start_time
|
|
46
|
+
if return_time:
|
|
47
|
+
return res, duration
|
|
48
|
+
else:
|
|
49
|
+
return res
|
|
50
|
+
|
|
51
|
+
def secure_evaluate(
|
|
52
|
+
self,
|
|
53
|
+
program: str | PyProgram,
|
|
54
|
+
timeout_seconds: Optional[float],
|
|
55
|
+
redirect_to_devnull: bool = False,
|
|
56
|
+
multiprocessing_start_method: Literal['default', 'auto', 'fork', 'spawn'] = 'auto',
|
|
57
|
+
return_time=True,
|
|
58
|
+
**kwargs
|
|
59
|
+
):
|
|
60
|
+
"""Evaluate program in a new process. This enables timeout restriction and output redirection.
|
|
61
|
+
Args:
|
|
62
|
+
program: the program to be evaluated.
|
|
63
|
+
timeout_seconds: return 'None' if the execution time exceeds 'timeout_seconds'.
|
|
64
|
+
redirect_to_devnull: redirect any output to '/dev/null'.
|
|
65
|
+
multiprocessing_start_method: start a process using 'fork' or 'spawn'.
|
|
66
|
+
**kwargs: additional keyword arguments to pass to 'evaluate_program'.
|
|
67
|
+
"""
|
|
68
|
+
start_time = time.time()
|
|
69
|
+
future = self.pool.submit(
|
|
70
|
+
self.evaluator.secure_evaluate,
|
|
71
|
+
program,
|
|
72
|
+
timeout_seconds,
|
|
73
|
+
redirect_to_devnull,
|
|
74
|
+
multiprocessing_start_method,
|
|
75
|
+
**kwargs
|
|
76
|
+
)
|
|
77
|
+
res = future.result()
|
|
78
|
+
duration = time.time() - start_time
|
|
79
|
+
if return_time:
|
|
80
|
+
return res, duration
|
|
81
|
+
else:
|
|
82
|
+
return res
|
|
@@ -0,0 +1,403 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2025 Rui Zhang <rzhang.cs@gmail.com>
|
|
3
|
+
|
|
4
|
+
NOTICE: This code is under MIT license. This code is intended for academic/research purposes only.
|
|
5
|
+
Commercial use of this software or its derivatives requires prior written permission.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import abstractmethod
|
|
9
|
+
from typing import Optional, List, Literal, Dict, Any
|
|
10
|
+
import os
|
|
11
|
+
import subprocess
|
|
12
|
+
import sys
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
import psutil
|
|
15
|
+
import requests
|
|
16
|
+
import time
|
|
17
|
+
|
|
18
|
+
import openai.types.chat
|
|
19
|
+
|
|
20
|
+
__all__ = ['LanguageModel', 'OpenAIAPI', 'VLLMServer']
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LanguageModel:
|
|
24
|
+
"""Base class for language model interface."""
|
|
25
|
+
|
|
26
|
+
@abstractmethod
|
|
27
|
+
def chat_completion(
|
|
28
|
+
self,
|
|
29
|
+
message: str | List[openai.types.chat.ChatCompletionMessageParam],
|
|
30
|
+
max_tokens: int,
|
|
31
|
+
timeout_seconds: float,
|
|
32
|
+
*args,
|
|
33
|
+
**kwargs
|
|
34
|
+
):
|
|
35
|
+
"""Send a chat completion query in OpenAI format to the language model. Return the response content.
|
|
36
|
+
Args:
|
|
37
|
+
message: The message in str or openai format.
|
|
38
|
+
max_tokens: The maximum number of tokens to generate.
|
|
39
|
+
timeout_seconds: The timeout seconds.
|
|
40
|
+
"""
|
|
41
|
+
pass
|
|
42
|
+
|
|
43
|
+
def close(self):
|
|
44
|
+
"""Release resources (if necessary)."""
|
|
45
|
+
pass
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class OpenAIAPI(LanguageModel):
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
model: str,
|
|
52
|
+
base_url: str = None,
|
|
53
|
+
api_key: str = None,
|
|
54
|
+
**openai_init_kwargs
|
|
55
|
+
):
|
|
56
|
+
super().__init__()
|
|
57
|
+
# If base_url is set to None, find 'OPENAI_BASE_URL' in environment variables
|
|
58
|
+
if base_url is None:
|
|
59
|
+
if 'OPENAI_BASE_URL' not in os.environ:
|
|
60
|
+
raise RuntimeError('If "base_url" is None, the environment variable OPENAI_BASE_URL must be set.')
|
|
61
|
+
else:
|
|
62
|
+
base_url = os.environ['OPENAI_BASE_URL']
|
|
63
|
+
|
|
64
|
+
# If api_key is set to None, find 'OPENAI_API_KEY' in environment variables
|
|
65
|
+
if api_key is None:
|
|
66
|
+
if 'OPENAI_API_KEY' not in os.environ:
|
|
67
|
+
raise RuntimeError('If "api_key" is None, OPENAI_API_KEY must be set.')
|
|
68
|
+
else:
|
|
69
|
+
api_key = os.environ['OPENAI_API_KEY']
|
|
70
|
+
|
|
71
|
+
self._model = model
|
|
72
|
+
self._client = openai.OpenAI(
|
|
73
|
+
api_key=api_key,
|
|
74
|
+
base_url=base_url,
|
|
75
|
+
**openai_init_kwargs
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
def chat_completion(
|
|
79
|
+
self,
|
|
80
|
+
message: str | List[openai.types.chat.ChatCompletionMessageParam],
|
|
81
|
+
max_tokens: int,
|
|
82
|
+
timeout_seconds: float,
|
|
83
|
+
*args,
|
|
84
|
+
**kwargs
|
|
85
|
+
):
|
|
86
|
+
"""Send a chat completion query in OpenAI format to the configured OpenAI-compatible endpoint ('base_url'). Return the response content.
|
|
87
|
+
Args:
|
|
88
|
+
message: The message in str or openai format.
|
|
89
|
+
max_tokens: The maximum number of tokens to generate.
|
|
90
|
+
timeout_seconds: The timeout seconds.
|
|
91
|
+
"""
|
|
92
|
+
if isinstance(message, str):
|
|
93
|
+
message = [{'role': 'user', 'content': message.strip()}]
|
|
94
|
+
|
|
95
|
+
response = self._client.chat.completions.create(
|
|
96
|
+
model=self._model,
|
|
97
|
+
messages=message,
|
|
98
|
+
stream=False,
|
|
99
|
+
max_tokens=max_tokens,
|
|
100
|
+
timeout=timeout_seconds,
|
|
101
|
+
*args,
|
|
102
|
+
**kwargs,
|
|
103
|
+
)
|
|
104
|
+
return response.choices[0].message.content
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _print_cmd_list(cmd_list, gpus, host, port):
|
|
108
|
+
print('\n' + '=' * 80)
|
|
109
|
+
print(f'[vLLM] Launching vLLM on GPU:{gpus}; URL: https://{host}:{port}')
|
|
110
|
+
print('=' * 80)
|
|
111
|
+
cmd = cmd_list[0] + ' \\\n'
|
|
112
|
+
for c in cmd_list[1:]:
|
|
113
|
+
cmd += ' ' + c + ' \\\n'
|
|
114
|
+
print(cmd.strip())
|
|
115
|
+
print('=' * 80 + '\n', flush=True)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class VLLMServer:
|
|
119
|
+
def __init__(self,
|
|
120
|
+
model_path: str,
|
|
121
|
+
port: int,
|
|
122
|
+
gpus: int | list[int],
|
|
123
|
+
tokenizer_path: Optional[str] = None,
|
|
124
|
+
max_model_len: int = 16384,
|
|
125
|
+
max_lora_rank: Optional[int] = None,
|
|
126
|
+
host: str = '0.0.0.0',
|
|
127
|
+
mem_util: float = 0.85,
|
|
128
|
+
deploy_timeout_seconds: int = 600,
|
|
129
|
+
enforce_eager: bool = False,
|
|
130
|
+
vllm_log_level: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO',
|
|
131
|
+
silent_mode: bool = False,
|
|
132
|
+
env_variable_dict: Optional[Dict[str, str]] = None,
|
|
133
|
+
vllm_serve_args: Optional[List[str]] = None,
|
|
134
|
+
vllm_serve_kwargs: Optional[Dict[str, str]] = None,
|
|
135
|
+
chat_template_kwargs: Optional[Dict[str, Any]] = None):
|
|
136
|
+
"""Deploy an LLM on specified GPUs.
|
|
137
|
+
Args:
|
|
138
|
+
model_path: Path to the model to deploy.
|
|
139
|
+
tokenizer_path: Path to the tokenizer to use.
|
|
140
|
+
port: Port on which the vLLM server listens.
|
|
141
|
+
gpus: GPU index, or list of GPU indices, to deploy on (a list of N GPUs sets tensor-parallel size N).
|
|
142
|
+
max_lora_rank: Max rank of LoRA adapter. Defaults to `None` which disables LoRA adapter.
|
|
143
|
+
host: Host address for vLLM server.
|
|
144
|
+
mem_util: Fraction of GPU memory the vLLM server may use (passed as '--gpu-memory-utilization').
|
|
145
|
+
deploy_timeout_seconds: Timeout to deploy (in seconds).
|
|
146
|
+
enforce_eager: Enforce eager mode.
|
|
147
|
+
vllm_log_level: Log level of vLLM server.
|
|
148
|
+
silent_mode: Silent mode.
|
|
149
|
+
env_variable_dict: Environment variables to use for vLLM server, e.g., {'KEY': 'VALUE'}.
|
|
150
|
+
vllm_serve_args: Arguments to pass to vLLM server, e.g., ['--enable-reasoning'].
|
|
151
|
+
vllm_serve_kwargs: Keyword arguments to pass to vLLM server, e.g., {'--reasoning-parser': 'deepseek-r1'}.
|
|
152
|
+
|
|
153
|
+
Example:
|
|
154
|
+
# deploy a model on GPU 0 and 1
|
|
155
|
+
llm = VLLMServer(
|
|
156
|
+
model_path='path/to/model',
|
|
157
|
+
tokenizer_path='path/to/tokenizer',
|
|
158
|
+
gpus=[0, 1], # set gpus=0 or gpus=[0] if you only use one GPU
|
|
159
|
+
port=12001,
|
|
160
|
+
mem_util=0.8
|
|
161
|
+
)
|
|
162
|
+
# draw sample using base model
|
|
163
|
+
llm.draw_sample('hello')
|
|
164
|
+
|
|
165
|
+
# load adapter and draw sample
|
|
166
|
+
llm.load_lora_adapter('adapter_1', '/path/to/adapter')
|
|
167
|
+
llm.draw_sample('hello', lora_name='adapter_1')
|
|
168
|
+
|
|
169
|
+
# unload adapter
|
|
170
|
+
llm.unload_lora_adapter('adapter_1')
|
|
171
|
+
|
|
172
|
+
# release resources
|
|
173
|
+
llm.close()
|
|
174
|
+
"""
|
|
175
|
+
self._model_path = model_path
|
|
176
|
+
self._port = port
|
|
177
|
+
self._gpus = gpus
|
|
178
|
+
self._tokenizer_path = tokenizer_path if tokenizer_path is not None else model_path
|
|
179
|
+
self._max_model_len = max_model_len
|
|
180
|
+
self._max_lora_rank = max_lora_rank
|
|
181
|
+
self._host = host
|
|
182
|
+
self._mem_util = mem_util
|
|
183
|
+
self._deploy_timeout_seconds = deploy_timeout_seconds
|
|
184
|
+
self._enforce_eager = enforce_eager
|
|
185
|
+
self._vllm_log_level = vllm_log_level
|
|
186
|
+
self._silent_mode = silent_mode
|
|
187
|
+
self._env_variable_dict = env_variable_dict
|
|
188
|
+
self._vllm_serve_args = vllm_serve_args
|
|
189
|
+
self._vllm_serve_kwargs = vllm_serve_kwargs
|
|
190
|
+
self._chat_template_kwargs = chat_template_kwargs
|
|
191
|
+
|
|
192
|
+
# Deploy vLLMs
|
|
193
|
+
self._process = self._launch_vllm()
|
|
194
|
+
self._wait_for_vllm()
|
|
195
|
+
|
|
196
|
+
def _launch_vllm(self):
|
|
197
|
+
"""Launch a vLLM server and return the subprocess.
|
|
198
|
+
"""
|
|
199
|
+
if isinstance(self._gpus, int):
|
|
200
|
+
gpus = str(self._gpus)
|
|
201
|
+
else:
|
|
202
|
+
gpus = ','.join([str(g) for g in self._gpus])
|
|
203
|
+
|
|
204
|
+
executable_path = sys.executable
|
|
205
|
+
cmd = [
|
|
206
|
+
executable_path, '-m',
|
|
207
|
+
'vllm.entrypoints.openai.api_server',
|
|
208
|
+
'--model', self._model_path,
|
|
209
|
+
'--tokenizer', self._tokenizer_path,
|
|
210
|
+
'--max_model_len', str(self._max_model_len),
|
|
211
|
+
'--host', self._host,
|
|
212
|
+
'--port', str(self._port),
|
|
213
|
+
'--gpu-memory-utilization', str(self._mem_util),
|
|
214
|
+
'--tensor-parallel-size', str(len(self._gpus)) if isinstance(self._gpus, list) else '1',
|
|
215
|
+
'--trust-remote-code',
|
|
216
|
+
'--chat-template-content-format', 'string',
|
|
217
|
+
]
|
|
218
|
+
|
|
219
|
+
if self._enforce_eager:
|
|
220
|
+
cmd.append('--enforce_eager')
|
|
221
|
+
|
|
222
|
+
# Other args for vllm serve
|
|
223
|
+
if self._vllm_serve_args is not None:
|
|
224
|
+
for arg in self._vllm_serve_args:
|
|
225
|
+
cmd.append(arg)
|
|
226
|
+
|
|
227
|
+
# Other kwargs for vllm serve
|
|
228
|
+
if self._vllm_serve_kwargs is not None:
|
|
229
|
+
for kwarg, value in self._vllm_serve_kwargs.items():
|
|
230
|
+
cmd.extend([kwarg, value])
|
|
231
|
+
|
|
232
|
+
# Environmental variables
|
|
233
|
+
env = os.environ.copy()
|
|
234
|
+
env['CUDA_VISIBLE_DEVICES'] = gpus
|
|
235
|
+
env['VLLM_LOGGING_LEVEL'] = self._vllm_log_level
|
|
236
|
+
|
|
237
|
+
# FIXME: These code are required for my machine :-(
|
|
238
|
+
# FIXME: This may due to the bad NCCL environment configuration :-(
|
|
239
|
+
if isinstance(self._gpus, list) and len(self._gpus) > 1:
|
|
240
|
+
# set NCCL environment variable
|
|
241
|
+
env['NCCL_P2P_DISABLE'] = '1'
|
|
242
|
+
# disable custom all reduce
|
|
243
|
+
cmd.append('--disable-custom-all-reduce')
|
|
244
|
+
|
|
245
|
+
# Enable LoRA dynamic loading
|
|
246
|
+
if self._max_lora_rank is not None:
|
|
247
|
+
cmd.extend([
|
|
248
|
+
'--enable-lora',
|
|
249
|
+
'--max-lora-rank', str(self._max_lora_rank),
|
|
250
|
+
])
|
|
251
|
+
env['VLLM_ALLOW_RUNTIME_LORA_UPDATING'] = 'True'
|
|
252
|
+
|
|
253
|
+
# Other env variables
|
|
254
|
+
if self._env_variable_dict is not None:
|
|
255
|
+
for k, v in self._env_variable_dict.items():
|
|
256
|
+
env[k] = v
|
|
257
|
+
|
|
258
|
+
_print_cmd_list(cmd, gpus=self._gpus, host=self._host, port=self._port)
|
|
259
|
+
|
|
260
|
+
# Launch vllm using subprocess
|
|
261
|
+
stdout = Path(os.devnull).open('w') if self._silent_mode else None
|
|
262
|
+
proc = subprocess.Popen(cmd, env=env, stdout=stdout, stderr=subprocess.STDOUT)
|
|
263
|
+
return proc
|
|
264
|
+
|
|
265
|
+
def _kill_vllm_process(self):
|
|
266
|
+
try:
|
|
267
|
+
# Get child processes before terminating parent
|
|
268
|
+
try:
|
|
269
|
+
parent = psutil.Process(self._process.pid)
|
|
270
|
+
children = parent.children(recursive=True)
|
|
271
|
+
except psutil.NoSuchProcess:
|
|
272
|
+
children = []
|
|
273
|
+
|
|
274
|
+
# Terminate parent process
|
|
275
|
+
self._process.terminate()
|
|
276
|
+
self._process.wait(timeout=5)
|
|
277
|
+
print(f'[vLLM] terminated process: {self._process.pid}')
|
|
278
|
+
|
|
279
|
+
# Kill any remaining children
|
|
280
|
+
for child in children:
|
|
281
|
+
try:
|
|
282
|
+
child.terminate()
|
|
283
|
+
child.wait(timeout=2)
|
|
284
|
+
except (psutil.NoSuchProcess, psutil.TimeoutExpired):
|
|
285
|
+
try:
|
|
286
|
+
child.kill()
|
|
287
|
+
except psutil.NoSuchProcess:
|
|
288
|
+
pass
|
|
289
|
+
except subprocess.TimeoutExpired:
|
|
290
|
+
self._process.kill()
|
|
291
|
+
print(f'[vLLM] killed process: {self._process.pid}')
|
|
292
|
+
|
|
293
|
+
def _wait_for_vllm(self):
    """Block until the vLLM server answers HTTP 200 on its ``/health`` endpoint.

    Polls about once per second for at most ``self._deploy_timeout_seconds`` seconds of
    wall-clock time. Calls ``sys.exit`` if the server process exits before becoming
    healthy. If the deadline passes first, it kills the server process tree and then
    calls ``sys.exit``.
    """
    health = f'http://{self._host}:{self._port}/health'
    # Use a wall-clock deadline: each poll may also spend up to 1s inside the HTTP request,
    # so counting iterations alone could wait almost twice the configured timeout.
    deadline = time.monotonic() + self._deploy_timeout_seconds
    while time.monotonic() < deadline:
        # The server process died while starting up
        if self._process.poll() is not None:
            sys.exit(f'[vLLM] crashed (exit {self._process.returncode})')

        # check server status
        try:
            if requests.get(health, timeout=1).status_code == 200:
                return
        except requests.exceptions.RequestException:
            # Server is not accepting connections yet; keep polling
            pass
        time.sleep(1)

    # Servers fail to initialize
    print('[vLLM] failed to start within timeout')
    self._kill_vllm_process()
    sys.exit('[vLLM] failed to start within timeout')
|
|
314
|
+
|
|
315
|
+
def unload_lora_adapter(self, lora_name: str):
    """Ask the vLLM server to drop the LoRA adapter registered as ``lora_name``.

    Best-effort: the server's response is not inspected, so an error reply (e.g. for a
    name that is not loaded) is ignored, and network errors are swallowed.

    Args:
        lora_name: Name under which the adapter was loaded.
    """
    endpoint = f'http://{self._host}:{self._port}/v1/unload_lora_adapter'
    try:
        requests.post(
            endpoint,
            json={'lora_name': lora_name},
            headers={'Content-Type': 'application/json'},
            timeout=10,
        )
    except requests.exceptions.RequestException:
        # Deliberately ignored: a failed unload must not abort the caller
        pass
|
|
327
|
+
|
|
328
|
+
def load_lora_adapter(self, lora_name: str, new_adapter_path: str, num_trails: int = 5):
    """Dynamically (re)load a LoRA adapter into the running vLLM server.

    Any adapter already registered under ``lora_name`` is unloaded first. The load request
    is then sent, and is retried only when it fails with a network-level error.

    Args:
        lora_name: LoRA adapter name.
        new_adapter_path: Path to the new LoRA adapter weights.
        num_trails: Maximum number of attempts when the request raises a network error
            (the name is kept as-is for backward compatibility).

    Returns:
        True if the server accepted the adapter (HTTP 200), False otherwise.

    Raises:
        ValueError: If LoRA is not enabled for this instance (``max_lora_rank`` unset).
    """
    # Validate before touching the server so a misconfigured instance has no side effects
    if self._max_lora_rank is None:
        raise ValueError('LoRA is not enabled for this VLLMServer instance, since "max_lora_rank" is not set.')

    # First unload lora adapter
    self.unload_lora_adapter(lora_name)

    # Prepare the payload for LoRA update
    payload = {'lora_name': lora_name, 'lora_path': new_adapter_path}
    headers = {'Content-Type': 'application/json'}
    lora_api_url = f'http://{self._host}:{self._port}/v1/load_lora_adapter'

    # Keep the last error in a separate variable: Python unbinds the `except ... as e`
    # name when the except clause ends, so `e` cannot be used after the loop.
    last_error = None
    for _ in range(num_trails):
        try:
            response = requests.post(lora_api_url, json=payload, headers=headers, timeout=60)
        except requests.exceptions.RequestException as e:
            last_error = e
            continue

        if response.status_code == 200:
            print(f'[vLLM] Successfully load LoRA adapter: {lora_name} from {new_adapter_path}')
            return True
        # The server replied but rejected the request: report failure instead of success.
        # Only network errors are retried, as before.
        print(f'[vLLM] Failed to load LoRA adapter. '
              f'Status code: {response.status_code}, Response: {response.text}')
        return False

    print(f'[vLLM] Error loading LoRA adapter: {str(last_error)}')
    return False
|
|
360
|
+
|
|
361
|
+
def close(self):
    """Stop the vLLM server by terminating its process and any child processes it spawned."""
    self._kill_vllm_process()
|
|
364
|
+
|
|
365
|
+
def chat_completion(self,
                    message: str | List[openai.types.chat.ChatCompletionMessageParam],
                    max_tokens: Optional[int] = None,
                    timeout_seconds: Optional[int] = None,
                    lora_name: Optional[str] = None,
                    temperature: float = 0.9,
                    top_p: float = 0.9,
                    chat_template_kwargs: Optional[Dict[str, Any]] = None) -> str:
    """Send an OpenAI-format chat completion request to the vLLM server and return the reply text.

    Args:
        message: Either a plain string, which is sent as a single user turn after
            stripping, or a list of OpenAI-format messages, which is sent as the whole
            conversation.
        max_tokens: The maximum number of tokens to generate.
        timeout_seconds: The request timeout in seconds (None waits indefinitely).
        lora_name: Lora adapter name. Defaults to None, which uses the base model.
        temperature: The temperature parameter.
        top_p: The top p parameter.
        chat_template_kwargs: The chat template kwargs, e.g., {'enable_thinking': False}.
            Merged over the instance-level chat template kwargs; on key conflicts the
            per-call value wins.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        requests.HTTPError: If the server responds with an HTTP error status.
    """
    # A string becomes one user turn; a list already is the full conversation and must
    # not be wrapped in another list (that produced an invalid nested payload).
    if isinstance(message, str):
        messages = [{'role': 'user', 'content': message.strip()}]
    else:
        messages = list(message)

    data = {
        'messages': messages,
        'temperature': temperature,
        'top_p': top_p,
        'max_tokens': max_tokens,
    }
    # Use the specified lora adapter
    if lora_name is not None:
        data['model'] = lora_name

    # Instance-level kwargs act as defaults; per-call kwargs override matching keys
    # instead of being silently ignored.
    if self._chat_template_kwargs is not None or chat_template_kwargs is not None:
        data['chat_template_kwargs'] = {
            **(self._chat_template_kwargs or {}),
            **(chat_template_kwargs or {}),
        }

    # Request
    url = f'http://{self._host}:{self._port}/v1/chat/completions'
    headers = {'Content-Type': 'application/json'}
    response = requests.post(url, headers=headers, json=data, timeout=timeout_seconds)
    # Surface HTTP errors explicitly rather than as an opaque KeyError on 'choices'
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']
|
|
@@ -10,24 +10,27 @@ import dataclasses
|
|
|
10
10
|
import textwrap
|
|
11
11
|
from typing import List, Optional
|
|
12
12
|
|
|
13
|
-
__all__ = ['
|
|
13
|
+
__all__ = ['PyCodeBlock', 'PyFunction', 'PyClass', 'PyProgram']
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
@dataclasses.dataclass
|
|
17
|
-
class
|
|
18
|
-
"""A parsed Python
|
|
17
|
+
class PyCodeBlock:
|
|
18
|
+
"""A parsed Python code block (e.g., top-level code that's not in classes/functions).
|
|
19
19
|
"""
|
|
20
20
|
code: str
|
|
21
21
|
|
|
22
22
|
def __str__(self) -> str:
|
|
23
|
-
return self.code
|
|
23
|
+
return self.code
|
|
24
|
+
|
|
25
|
+
def __repr__(self) -> str:
|
|
26
|
+
return self.__str__() + '\n'
|
|
24
27
|
|
|
25
28
|
|
|
26
29
|
@dataclasses.dataclass
|
|
27
30
|
class PyFunction:
|
|
28
31
|
"""A parsed Python function.
|
|
29
|
-
|
|
30
|
-
|
|
32
|
+
Part of this class is referenced from:
|
|
33
|
+
https://github.com/google-deepmind/funsearch/blob/main/implementation/code_manipulation.py
|
|
31
34
|
"""
|
|
32
35
|
decorator: str
|
|
33
36
|
name: str
|
|
@@ -38,16 +41,20 @@ class PyFunction:
|
|
|
38
41
|
|
|
39
42
|
def __str__(self) -> str:
|
|
40
43
|
return_type = f' -> {self.return_type}' if self.return_type else ''
|
|
41
|
-
function = f'{self.decorator}\
|
|
44
|
+
function = f'{self.decorator}\n' if self.decorator else ''
|
|
45
|
+
function += f'def {self.name}({self.args}){return_type}:\n'
|
|
42
46
|
if self.docstring:
|
|
43
47
|
# The self.docstring is already indented on every line except the first one.
|
|
44
|
-
# Here, we assume the indentation is always
|
|
48
|
+
# Here, we assume the indentation is always 4 spaces.
|
|
45
49
|
new_line = '\n' if self.body else ''
|
|
46
50
|
function += f' """{self.docstring}"""{new_line}'
|
|
47
51
|
# The self.body is already indented.
|
|
48
|
-
function += self.body
|
|
52
|
+
function += self.body
|
|
49
53
|
return function
|
|
50
54
|
|
|
55
|
+
def __repr__(self) -> str:
|
|
56
|
+
return self.__str__() + '\n\n'
|
|
57
|
+
|
|
51
58
|
def __setattr__(self, name: str, value: str) -> None:
|
|
52
59
|
# Ensure there aren't leading & trailing new lines in `body`
|
|
53
60
|
if name == 'body':
|
|
@@ -82,13 +89,14 @@ class PyClass:
|
|
|
82
89
|
decorator: str
|
|
83
90
|
name: str
|
|
84
91
|
bases: str
|
|
85
|
-
class_vars_and_code: List[
|
|
92
|
+
class_vars_and_code: List[PyCodeBlock] = None
|
|
86
93
|
docstring: str | None = None
|
|
87
94
|
functions: list[PyFunction] = dataclasses.field(default_factory=list)
|
|
88
|
-
functions_class_vars_and_code: List[
|
|
95
|
+
functions_class_vars_and_code: List[PyCodeBlock | PyFunction] | None = None
|
|
89
96
|
|
|
90
97
|
def __str__(self) -> str:
|
|
91
|
-
class_def = f'{self.decorator}\
|
|
98
|
+
class_def = f'{self.decorator}\n' if self.decorator else ''
|
|
99
|
+
class_def += f'class {self.name}'
|
|
92
100
|
if self.bases:
|
|
93
101
|
class_def += f'({self.bases})'
|
|
94
102
|
class_def += ':\n'
|
|
@@ -96,15 +104,21 @@ class PyClass:
|
|
|
96
104
|
if self.docstring:
|
|
97
105
|
class_def += f' """{self.docstring}"""\n'
|
|
98
106
|
|
|
99
|
-
for item in self.functions_class_vars_and_code:
|
|
100
|
-
if isinstance(item,
|
|
101
|
-
|
|
107
|
+
for i, item in enumerate(self.functions_class_vars_and_code):
|
|
108
|
+
if isinstance(item, PyCodeBlock):
|
|
109
|
+
# The PyCodeBlock has already indented
|
|
110
|
+
class_def += f'{str(item)}'
|
|
102
111
|
else:
|
|
103
112
|
# Add functions with an extra level of indentation
|
|
104
113
|
class_def += textwrap.indent(str(item).strip(), ' ')
|
|
114
|
+
# Add '\n\n' if this is not the last element
|
|
115
|
+
if i != len(self.functions_class_vars_and_code) - 1:
|
|
105
116
|
class_def += '\n\n'
|
|
106
117
|
return class_def
|
|
107
118
|
|
|
119
|
+
def __repr__(self):
|
|
120
|
+
return self.__str__() + '\n\n'
|
|
121
|
+
|
|
108
122
|
def __setattr__(self, name: str, value: str) -> None:
|
|
109
123
|
# Ensure there aren't leading & trailing new lines in `body`
|
|
110
124
|
if name == 'body':
|
|
@@ -137,19 +151,19 @@ class PyClass:
|
|
|
137
151
|
class PyProgram:
|
|
138
152
|
"""A parsed Python program."""
|
|
139
153
|
|
|
140
|
-
scripts: list[
|
|
154
|
+
scripts: list[PyCodeBlock] # Top-level code that's not in classes/functions
|
|
141
155
|
functions: list[PyFunction] # Top-level functions in the code
|
|
142
156
|
classes: list[PyClass] # Top-level classes in the code
|
|
143
|
-
classes_functions_scripts: list[PyFunction | PyClass |
|
|
157
|
+
classes_functions_scripts: list[PyFunction | PyClass | PyCodeBlock]
|
|
144
158
|
|
|
145
159
|
def __str__(self) -> str:
|
|
146
160
|
program = ''
|
|
147
161
|
for class_or_func_or_script in self.classes_functions_scripts:
|
|
148
|
-
program += str(class_or_func_or_script) + '\n'
|
|
162
|
+
program += str(class_or_func_or_script) + '\n\n'
|
|
149
163
|
return program
|
|
150
164
|
|
|
151
165
|
@classmethod
|
|
152
|
-
def from_text(cls, text: str) ->
|
|
166
|
+
def from_text(cls, text: str) -> 'PyProgram':
|
|
153
167
|
tree = ast.parse(text)
|
|
154
168
|
visitor = _ProgramVisitor(text)
|
|
155
169
|
visitor.visit(tree)
|
|
@@ -163,19 +177,31 @@ class _ProgramVisitor(ast.NodeVisitor):
|
|
|
163
177
|
|
|
164
178
|
def __init__(self, sourcecode: str):
|
|
165
179
|
self._codelines: list[str] = sourcecode.splitlines()
|
|
166
|
-
self._scripts: list[
|
|
180
|
+
self._scripts: list[PyCodeBlock] = []
|
|
167
181
|
self._functions: list[PyFunction] = []
|
|
168
182
|
self._classes: list[PyClass] = []
|
|
169
|
-
self._classes_functions_scripts: list[PyFunction | PyClass |
|
|
183
|
+
self._classes_functions_scripts: list[PyFunction | PyClass | PyCodeBlock] = []
|
|
170
184
|
self._last_script_end = 0
|
|
171
185
|
|
|
186
|
+
def _get_code(self, start_line: int, end_line: int, dedent=False):
|
|
187
|
+
"""Get code between start_line and end_line in 'self._codelines'.
|
|
188
|
+
"""
|
|
189
|
+
code = []
|
|
190
|
+
for line in self._codelines[start_line: end_line]:
|
|
191
|
+
if dedent:
|
|
192
|
+
code.append(line[4:])
|
|
193
|
+
else:
|
|
194
|
+
code.append(line)
|
|
195
|
+
return '\n'.join(code).rstrip()
|
|
196
|
+
|
|
172
197
|
def _add_script(self, start_line: int, end_line: int):
|
|
173
|
-
"""Add a script segment from the code.
|
|
198
|
+
"""Add a script segment from the code.
|
|
199
|
+
"""
|
|
174
200
|
if start_line >= end_line:
|
|
175
201
|
return
|
|
176
|
-
script_code =
|
|
202
|
+
script_code = self._get_code(start_line, end_line).strip()
|
|
177
203
|
if script_code:
|
|
178
|
-
script =
|
|
204
|
+
script = PyCodeBlock(code=script_code)
|
|
179
205
|
self._scripts.append(script)
|
|
180
206
|
self._classes_functions_scripts.append(script)
|
|
181
207
|
|
|
@@ -189,11 +215,11 @@ class _ProgramVisitor(ast.NodeVisitor):
|
|
|
189
215
|
if has_decorators:
|
|
190
216
|
# Find the minimum line number and retain the code above
|
|
191
217
|
decorator_start_line = min(decorator.lineno for decorator in node.decorator_list)
|
|
192
|
-
decorator =
|
|
218
|
+
decorator = self._get_code(decorator_start_line - 1, node.lineno - 1)
|
|
193
219
|
# Update script end line
|
|
194
220
|
script_end_line = decorator_start_line - 1
|
|
195
221
|
else:
|
|
196
|
-
decorator =
|
|
222
|
+
decorator = None
|
|
197
223
|
script_end_line = node.lineno - 1
|
|
198
224
|
|
|
199
225
|
# Add any script code before this function
|
|
@@ -204,6 +230,7 @@ class _ProgramVisitor(ast.NodeVisitor):
|
|
|
204
230
|
body_start_line = node.body[0].lineno - 1
|
|
205
231
|
docstring = None
|
|
206
232
|
|
|
233
|
+
# If the first node is ast.Expr, we regard it as a docstring
|
|
207
234
|
if isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Constant):
|
|
208
235
|
docstring = ast.literal_eval(ast.unparse(node.body[0])).strip()
|
|
209
236
|
if len(node.body) > 1:
|
|
@@ -211,13 +238,14 @@ class _ProgramVisitor(ast.NodeVisitor):
|
|
|
211
238
|
else:
|
|
212
239
|
body_start_line = function_end_line
|
|
213
240
|
|
|
241
|
+
# Return a PyFunction instance
|
|
214
242
|
func = PyFunction(
|
|
215
243
|
decorator=decorator,
|
|
216
244
|
name=node.name,
|
|
217
245
|
args=ast.unparse(node.args),
|
|
218
246
|
return_type=ast.unparse(node.returns) if node.returns else None,
|
|
219
247
|
docstring=docstring,
|
|
220
|
-
body=
|
|
248
|
+
body=self._get_code(body_start_line, function_end_line)
|
|
221
249
|
)
|
|
222
250
|
|
|
223
251
|
self._functions.append(func)
|
|
@@ -234,11 +262,11 @@ class _ProgramVisitor(ast.NodeVisitor):
|
|
|
234
262
|
if has_decorators:
|
|
235
263
|
# Find the minimum line number and retain the code above
|
|
236
264
|
decorator_start_line = min(decorator.lineno for decorator in node.decorator_list)
|
|
237
|
-
class_decorator =
|
|
265
|
+
class_decorator = self._get_code(decorator_start_line - 1, node.lineno - 1)
|
|
238
266
|
# Update script end line
|
|
239
267
|
script_end_line = decorator_start_line - 1
|
|
240
268
|
else:
|
|
241
|
-
class_decorator =
|
|
269
|
+
class_decorator = None
|
|
242
270
|
script_end_line = node.lineno - 1
|
|
243
271
|
|
|
244
272
|
# Add any script code before this class
|
|
@@ -250,9 +278,12 @@ class _ProgramVisitor(ast.NodeVisitor):
|
|
|
250
278
|
if isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Constant):
|
|
251
279
|
docstring = ast.literal_eval(ast.unparse(node.body[0]))
|
|
252
280
|
|
|
281
|
+
# Record methods
|
|
253
282
|
methods = []
|
|
283
|
+
# Record class variables or code that are not methods
|
|
254
284
|
class_vars_and_code = []
|
|
255
|
-
|
|
285
|
+
# Record the order of function and class vars and code
|
|
286
|
+
function_class_vars_and_code = []
|
|
256
287
|
|
|
257
288
|
# Traverse each body, if there is a docstring, skip body[0]
|
|
258
289
|
for item in node.body if docstring is None else node.body[1:]:
|
|
@@ -263,13 +294,9 @@ class _ProgramVisitor(ast.NodeVisitor):
|
|
|
263
294
|
# Find the minimum line number and retain the code above
|
|
264
295
|
decorator_start_line = min(decorator.lineno for decorator in item.decorator_list)
|
|
265
296
|
# Dedent decorator code
|
|
266
|
-
decorator =
|
|
267
|
-
for line in range(decorator_start_line - 1, item.lineno - 1):
|
|
268
|
-
dedented_decorator = self._codelines[line].strip()
|
|
269
|
-
decorator.append(dedented_decorator)
|
|
270
|
-
decorator = '\n'.join(decorator)
|
|
297
|
+
decorator = self._get_code(decorator_start_line - 1, item.lineno - 1, dedent=True)
|
|
271
298
|
else:
|
|
272
|
-
decorator =
|
|
299
|
+
decorator = None
|
|
273
300
|
|
|
274
301
|
method_end_line = item.end_lineno
|
|
275
302
|
method_body_start_line = item.body[0].lineno - 1
|
|
@@ -284,11 +311,7 @@ class _ProgramVisitor(ast.NodeVisitor):
|
|
|
284
311
|
method_body_start_line = method_end_line
|
|
285
312
|
|
|
286
313
|
# Extract function body and dedent for 4 spaces
|
|
287
|
-
body =
|
|
288
|
-
for line in range(method_body_start_line, method_end_line):
|
|
289
|
-
dedented_body = self._codelines[line][4:]
|
|
290
|
-
body.append(dedented_body)
|
|
291
|
-
body = '\n'.join(body)
|
|
314
|
+
body = self._get_code(method_body_start_line, method_end_line, dedent=True)
|
|
292
315
|
|
|
293
316
|
py_func = PyFunction(
|
|
294
317
|
decorator=decorator,
|
|
@@ -300,17 +323,16 @@ class _ProgramVisitor(ast.NodeVisitor):
|
|
|
300
323
|
)
|
|
301
324
|
methods.append(py_func)
|
|
302
325
|
function_class_vars_and_code.append(py_func)
|
|
303
|
-
else: # If the item is not a function definition,add to class variables and code
|
|
304
|
-
code =
|
|
305
|
-
|
|
306
|
-
code.append(self._codelines[i])
|
|
307
|
-
py_script = PyScript(code='\n'.join(code))
|
|
326
|
+
else: # If the item is not a function definition, add to class variables and code
|
|
327
|
+
code = self._get_code(item.lineno - 1, item.end_lineno)
|
|
328
|
+
py_script = PyCodeBlock(code=code)
|
|
308
329
|
class_vars_and_code.append(py_script)
|
|
309
330
|
function_class_vars_and_code.append(py_script)
|
|
310
331
|
|
|
311
332
|
# Get base classes
|
|
312
|
-
bases = ', '.join([ast.unparse(base) for base in node.bases]) if node.bases else
|
|
333
|
+
bases = ', '.join([ast.unparse(base) for base in node.bases]) if node.bases else None
|
|
313
334
|
|
|
335
|
+
# Return a PyClass instance
|
|
314
336
|
class_ = PyClass(
|
|
315
337
|
decorator=class_decorator,
|
|
316
338
|
name=node.name,
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: py-adtools
|
|
3
|
-
Version: 0.1.
|
|
4
|
-
Summary: Useful tools for parsing and evaluating Python programs for algorithm design.
|
|
3
|
+
Version: 0.1.4
|
|
4
|
+
Summary: Useful tools for parsing and evaluating Python programs for LLM-based algorithm design.
|
|
5
5
|
Home-page: https://github.com/RayZhhh/py-adtools
|
|
6
6
|
Author: Rui Zhang
|
|
7
7
|
Author-email: rzhang.cs@gmail.com
|
|
@@ -34,7 +34,7 @@ Dynamic: summary
|
|
|
34
34
|
|
|
35
35
|
------
|
|
36
36
|
|
|
37
|
-
The figure demonstrates how a Python program is parsed into `
|
|
37
|
+
The figure demonstrates how a Python program is parsed into `PyCodeBlock`, `PyFunction`, `PyClass`, and `PyProgram` via `adtools`.
|
|
38
38
|
|
|
39
39
|

|
|
40
40
|
|
|
@@ -68,7 +68,7 @@ Parse your code (in string) into Python code instances, so that you can check ea
|
|
|
68
68
|
from adtools import PyProgram
|
|
69
69
|
|
|
70
70
|
code = r'''
|
|
71
|
-
import ast, numba # This part will be parsed into
|
|
71
|
+
import ast, numba # This part will be parsed into PyCodeBlock
|
|
72
72
|
import numpy as np
|
|
73
73
|
|
|
74
74
|
@numba.jit() # This part will be parsed into PyFunction
|
|
@@ -80,8 +80,9 @@ def function(arg1, arg2=True):
|
|
|
80
80
|
|
|
81
81
|
@some.decorators() # This part will be parsed into PyClass
|
|
82
82
|
class PythonClass(BaseClass):
|
|
83
|
-
|
|
84
|
-
|
|
83
|
+
|
|
84
|
+
class_var1 = 1 # This part will be parsed into PyCodeBlock
|
|
85
|
+
class_var2 = 2 # and placed in PyClass.class_vars_and_code
|
|
85
86
|
|
|
86
87
|
def __init__(self, x): # This part will be parsed into PyFunction
|
|
87
88
|
self.x = x # and placed in PyClass.functions
|
|
@@ -93,11 +94,11 @@ class PythonClass(BaseClass):
|
|
|
93
94
|
def method2(self, x, y):
|
|
94
95
|
return x + y + self.method1(x)
|
|
95
96
|
|
|
96
|
-
class InnerClass: # This part will be parsed into
|
|
97
|
+
class InnerClass: # This part will be parsed into PyCodeBlock
|
|
97
98
|
def __init__(self): # and placed in PyClass.class_vars_and_code
|
|
98
99
|
...
|
|
99
100
|
|
|
100
|
-
if __name__ == '__main__': # This part will be parsed into
|
|
101
|
+
if __name__ == '__main__': # This part will be parsed into PyCodeBlock
|
|
101
102
|
res = function(1)
|
|
102
103
|
print(res)
|
|
103
104
|
res = PythonClass().method2(1, 2)
|
|
@@ -116,7 +117,7 @@ print(p.functions[0].name)
|
|
|
116
117
|
Evaluate Python programs in a separate (sandboxed) process so that a crash or hang in the evaluated code cannot abort the main process. Two steps:
|
|
117
118
|
|
|
118
119
|
- Extend the `PyEvaluator` class and override the `evaluate_program` method.
|
|
119
|
-
- Evaluate the program (in str) by calling the `evaluate` (directly evaluate without executing in a sandbox process) or the `secure_evaluate` (evaluate in a sandbox process) methods.
|
|
120
|
+
- Evaluate the program (given as a string) by calling either `evaluate` (runs the program directly, without a sandbox process) or `secure_evaluate` (runs the program in a sandbox process).
|
|
120
121
|
|
|
121
122
|
```python
|
|
122
123
|
import time
|
|
@@ -190,6 +191,7 @@ def merge(left, right):
|
|
|
190
191
|
|
|
191
192
|
harmful_code_generated_by_llm = '''
|
|
192
193
|
def merge_sort(arr):
|
|
194
|
+
print('I am harmful') # There will be no output since we redirect STDOUT to /dev/null by default.
|
|
193
195
|
while True:
|
|
194
196
|
pass
|
|
195
197
|
'''
|
|
@@ -5,10 +5,10 @@ with open('README.md', 'r', encoding='utf-8') as fh:
|
|
|
5
5
|
|
|
6
6
|
setup(
|
|
7
7
|
name='py-adtools',
|
|
8
|
-
version='0.1.
|
|
8
|
+
version='0.1.4',
|
|
9
9
|
author='Rui Zhang',
|
|
10
10
|
author_email='rzhang.cs@gmail.com',
|
|
11
|
-
description='Useful tools for parsing and evaluating Python programs for algorithm design.',
|
|
11
|
+
description='Useful tools for parsing and evaluating Python programs for LLM-based algorithm design.',
|
|
12
12
|
long_description=long_description,
|
|
13
13
|
long_description_content_type='text/markdown',
|
|
14
14
|
url='https://github.com/RayZhhh/py-adtools',
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|