klongpy 0.6.9__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- klongpy/__init__.py +17 -1
- klongpy/adverbs.py +84 -82
- klongpy/autograd.py +299 -0
- klongpy/backend.py +38 -103
- klongpy/backends/__init__.py +26 -0
- klongpy/backends/base.py +469 -0
- klongpy/backends/numpy_backend.py +123 -0
- klongpy/backends/registry.py +76 -0
- klongpy/backends/torch_backend.py +1047 -0
- klongpy-0.6.9.data/scripts/kgpy → klongpy/cli.py +110 -90
- klongpy/core.py +113 -974
- klongpy/db/sys_fn_db.py +7 -6
- klongpy/db/sys_fn_kvs.py +2 -4
- klongpy/dyads.py +332 -160
- klongpy/interpreter.py +60 -15
- klongpy/monads.py +121 -75
- klongpy/parser.py +328 -0
- klongpy/repl.py +23 -5
- klongpy/sys_fn.py +170 -21
- klongpy/sys_fn_autograd.py +290 -0
- klongpy/sys_fn_ipc.py +22 -15
- klongpy/sys_fn_timer.py +13 -3
- klongpy/types.py +503 -0
- klongpy/web/sys_fn_web.py +14 -4
- klongpy/writer.py +122 -0
- klongpy/ws/sys_fn_ws.py +5 -8
- klongpy-0.7.1.dist-info/METADATA +544 -0
- klongpy-0.7.1.dist-info/RECORD +52 -0
- {klongpy-0.6.9.dist-info → klongpy-0.7.1.dist-info}/WHEEL +1 -1
- klongpy-0.7.1.dist-info/entry_points.txt +2 -0
- {klongpy-0.6.9.dist-info → klongpy-0.7.1.dist-info}/top_level.txt +0 -1
- klongpy-0.6.9.dist-info/METADATA +0 -448
- klongpy-0.6.9.dist-info/RECORD +0 -77
- tests/__init__.py +0 -6
- tests/gen_join_over.py +0 -119
- tests/gen_py_suite.py +0 -77
- tests/gen_test_fn.py +0 -259
- tests/perf_async.py +0 -25
- tests/perf_avg.py +0 -18
- tests/perf_duckdb.py +0 -32
- tests/perf_gen.py +0 -38
- tests/perf_ipc_overhead.py +0 -34
- tests/perf_join.py +0 -53
- tests/perf_load.py +0 -17
- tests/perf_prog.py +0 -18
- tests/perf_serdes.py +0 -52
- tests/perf_sys_fn_db.py +0 -263
- tests/perf_vector.py +0 -40
- tests/test_accel.py +0 -227
- tests/test_df_cache.py +0 -85
- tests/test_eval_monad_list.py +0 -34
- tests/test_examples.py +0 -64
- tests/test_extra_suite.py +0 -382
- tests/test_file_cache.py +0 -185
- tests/test_interop.py +0 -180
- tests/test_kg_asarray.py +0 -94
- tests/test_kgtests.py +0 -65
- tests/test_known_bugs.py +0 -206
- tests/test_prog.py +0 -107
- tests/test_reshape_strings.py +0 -33
- tests/test_suite.py +0 -1480
- tests/test_suite_file.py +0 -153
- tests/test_sys_fn.py +0 -420
- tests/test_sys_fn_db.py +0 -88
- tests/test_sys_fn_ipc.py +0 -587
- tests/test_sys_fn_timer.py +0 -133
- tests/test_sys_fn_web.py +0 -50
- tests/test_util.py +0 -233
- tests/utils.py +0 -126
- {klongpy-0.6.9.dist-info → klongpy-0.7.1.dist-info}/licenses/LICENSE +0 -0
klongpy/__init__.py
CHANGED
```diff
@@ -1,2 +1,18 @@
 from .interpreter import KlongInterpreter, KlongException
-
+from .backends import (
+    get_backend,
+    register_backend,
+    list_backends,
+    BackendProvider,
+    UnsupportedDtypeError,
+)
+
+__all__ = [
+    "KlongInterpreter",
+    "KlongException",
+    "UnsupportedDtypeError",
+    "get_backend",
+    "register_backend",
+    "list_backends",
+    "BackendProvider",
+]
```
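A minimal sketch of how the newly exported backend registry might be used. Only the names come from the `__all__` above; the call signatures and the registered backend names are assumptions inferred from the new `klongpy/backends/` modules listed in this diff:

```python
# Hypothetical usage of the 0.7.1 public backend API. The imported names
# (get_backend, register_backend, list_backends, BackendProvider) are
# confirmed by __all__ above; signatures and backend names are assumptions.
from klongpy import KlongInterpreter, get_backend, list_backends

print(list_backends())          # plausibly ['numpy', 'torch'], matching the
                                # new numpy_backend.py / torch_backend.py
backend = get_backend("numpy")  # assumed to look up a BackendProvider by name

klong = KlongInterpreter()
print(klong("1+2"))             # 3
```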
klongpy/adverbs.py
CHANGED
```diff
@@ -1,34 +1,9 @@
 from .core import *
-from .dyads import eval_dyad_add, eval_dyad_subtract, eval_dyad_multiply, eval_dyad_divide
 import functools
 import itertools
 
 
-def get_adverb_fn(klong, s, arity):
-    if s == "'":
-        return eval_adverb_each2 if arity == 2 else eval_adverb_each
-    elif s == '/':
-        return eval_adverb_over_neutral if arity == 2 else eval_adverb_over
-    elif s == '\\':
-        return eval_adverb_scan_over_neutral if arity == 2 else eval_adverb_scan_over
-    elif s == '\\~':
-        return (lambda f,a,b,k=klong: eval_adverb_scan_while(k,f,a,b)) if arity == 2 else eval_adverb_scan_converging
-    elif s == '\\*':
-        return eval_adverb_scan_iterating
-    elif s == ':\\':
-        return eval_adverb_each_left
-    elif s == ':\'':
-        return eval_adverb_each_pair
-    elif s == ':/':
-        return eval_adverb_each_right
-    elif s == ':*':
-        return eval_dyad_adverb_iterate
-    elif s == ':~':
-        return (lambda f,a,b,k=klong: eval_adverb_while(k,f,a,b)) if arity == 2 else eval_adverb_converge
-    raise RuntimeError(f"unknown adverb: {s}")
-
-
-def eval_adverb_converge(f, a, op):
+def eval_adverb_converge(f, a, op, backend):
     """
     f:~a [Converge]
@@ -54,10 +29,10 @@ def eval_adverb_converge(f, a, op):
     def _e(p,q):
         if not isinstance(p, type(q)):
             return False
-        if is_number(p):
-            return np.isclose(p,q)
-        elif is_list(p):
-            return kg_equal(p,q)
+        if backend.is_number(p):
+            return backend.np.isclose(p,q)
+        elif backend.is_array(p):
+            return backend.kg_equal(p, q)
         return p == q
     x = f(a)
     xx = f(x)
@@ -67,7 +42,7 @@ def eval_adverb_converge(f, a, op):
     return x
 
 
-def eval_adverb_each(f, a, op):
+def eval_adverb_each(f, a, op, backend):
     """
 
     f'a [Each]
@@ -92,17 +67,17 @@ def eval_adverb_each(f, a, op):
         return a
     has_str = False
     r = []
-    for x in str_to_chr_arr(a):
+    for x in backend.str_to_chr_arr(a):
         u = f(x)
         has_str |= isinstance(u,str)
         r.append(u)
-    return ''.join(r) if has_str else kg_asarray(r)
+    return ''.join(r) if has_str else backend.kg_asarray(r)
     if is_iterable(a):
         r = [f(x) for x in a]
-        return a if is_empty(a) else kg_asarray(r)
+        return a if is_empty(a) else backend.kg_asarray(r)
     elif is_dict(a):
-        r = [f(kg_asarray(x)) for x in a.items()]
-        return kg_asarray(r)
+        r = [f(backend.kg_asarray(x)) for x in a.items()]
+        return backend.kg_asarray(r)
     return f(a)
@@ -124,14 +99,14 @@ def eval_adverb_each2(f, a, b):
 
     """
     if is_empty(a) or is_empty(b):
-        return
+        return bknp.asarray([]) if is_list(a) or is_list(b) else ""
     if is_atom(a) and is_atom(b):
         return f(a,b)
-    r = np.asarray([f(x,y) for x,y in zip(a,b)])
+    r = bknp.asarray([f(x,y) for x,y in zip(a,b)])
     return ''.join(r) if r.dtype == '<U1' else r
 
 
-def eval_adverb_each_left(f, a, b):
+def eval_adverb_each_left(f, a, b, backend):
     """
     a f:\b [Each-Left]
     a f:/b [Each-Right]
@@ -153,20 +128,19 @@ def eval_adverb_each_left(f, a, b):
     Examples: 1,:\[2 3 4] --> [[1 2] [1 3] [1 4]]
               1,:/[2 3 4] --> [[2 1] [3 1] [4 1]]
     """
-    b = str_to_chr_arr(b) if isinstance(b,str) else b
-    return kg_asarray([f(a,x) for x in b])
+    b = backend.str_to_chr_arr(b) if isinstance(b,str) else b
+    return backend.kg_asarray([f(a,x) for x in b])
 
 
-def eval_adverb_each_right(f, a, b):
+def eval_adverb_each_right(f, a, b, backend):
     """
     see: eval_dyad_adverb_each_left
     """
-    b = str_to_chr_arr(b) if isinstance(b,str) else b
-    return kg_asarray([f(x,a) for x in b])
+    b = backend.str_to_chr_arr(b) if isinstance(b,str) else b
+    return backend.kg_asarray([f(x,a) for x in b])
 
 
-
-def eval_adverb_each_pair(f, a, op):
+def eval_adverb_each_pair(f, a, op, backend):
     """
 
     f:'a [Each-Pair]
@@ -185,8 +159,8 @@ def eval_adverb_each_pair(f, a, op):
     if is_atom(a) or (is_iterable(a) and len(a) == 1):
         return a
     j = isinstance(a, str)
-    a = str_to_chr_arr(a) if j else a
-    return kg_asarray([f(x,y) for x,y in zip(a[::],a[1::])])
+    a = backend.str_to_chr_arr(a) if j else a
+    return backend.kg_asarray([f(x,y) for x,y in zip(a[::],a[1::])])
 
 
 def eval_dyad_adverb_iterate(f, a, b):
@@ -208,7 +182,7 @@ def eval_dyad_adverb_iterate(f, a, b):
     return b
 
 
-def eval_adverb_over(f, a, op):
+def eval_adverb_over(f, a, op, backend):
     """
     f/a [Over]
 
@@ -227,23 +201,23 @@ def eval_adverb_over(f, a, op):
         return a
     if len(a) == 1:
         return a[0]
-    # …
-    …
+    # Use backend's ufunc reduce when available for better performance
+    np_backend = backend.np
     if isinstance(op, KGOp):
         if safe_eq(op.a,'+'):
-            return np.add.reduce(a)
+            return np_backend.add.reduce(a)
         elif safe_eq(op.a, '-'):
-            return np.subtract.reduce(a)
-        elif safe_eq(op.a, '*') and hasattr(np.multiply,'reduce'):
-            return np.multiply.reduce(a)
-        elif safe_eq(op.a, '%') and hasattr(np.divide,'reduce'):
-            return np.divide.reduce(a)
+            return np_backend.subtract.reduce(a)
+        elif safe_eq(op.a, '*') and hasattr(np_backend.multiply,'reduce'):
+            return np_backend.multiply.reduce(a)
+        elif safe_eq(op.a, '%') and hasattr(np_backend.divide,'reduce'):
+            return np_backend.divide.reduce(a)
         elif safe_eq(op.a, '&') and a.ndim == 1:
-            return np.min(a)
+            return np_backend.min(a)
         elif safe_eq(op.a, '|') and a.ndim == 1:
-            return np.max(a)
-        elif safe_eq(op.a, ',') and np.isarray(a) and a.dtype != 'O':
-            return a if a.ndim == 1 else np.concatenate(a, axis=0)
+            return np_backend.max(a)
+        elif safe_eq(op.a, ',') and np_backend.isarray(a) and a.dtype != 'O':
+            return a if a.ndim == 1 else np_backend.concatenate(a, axis=0)
     return functools.reduce(f, a)
@@ -280,7 +254,7 @@ def eval_adverb_over_neutral(f, a, b):
     return functools.reduce(f,b[1:],f(a,b[0]))
 
 
-def eval_adverb_scan_over_neutral(f, a, b):
+def eval_adverb_scan_over_neutral(f, a, b, backend):
     """
 
     f\a [Scan-Over]
@@ -309,31 +283,33 @@ def eval_adverb_scan_over_neutral(f, a, b):
         b = [b]
     b = [f(a,b[0]), *b[1:]]
     r = list(itertools.accumulate(b,f))
-    q = kg_asarray(r)
+    q = backend.kg_asarray(r)
     r = [a, *q]
-    return kg_asarray(r)
+    return backend.kg_asarray(r)
 
 
-def eval_adverb_scan_over(f, a, op):
+def eval_adverb_scan_over(f, a, op, backend):
     """
     see eval_adverb_scan_over_neutral
     """
     if is_atom(a):
         return a
-    # …
-    …
+    # Use backend's ufunc accumulate when available for better performance
+    np_backend = backend.np
+    if isinstance(op, KGOp):
+        if safe_eq(op.a, '+') and hasattr(np_backend.add, 'accumulate'):
+            return np_backend.add.accumulate(a)
+        elif safe_eq(op.a, '-') and hasattr(np_backend.subtract, 'accumulate'):
+            return np_backend.subtract.accumulate(a)
+        elif safe_eq(op.a, '*') and hasattr(np_backend.multiply, 'accumulate'):
+            return np_backend.multiply.accumulate(a)
+        elif safe_eq(op.a, '%') and hasattr(np_backend.divide, 'accumulate'):
+            return np_backend.divide.accumulate(a)
     r = list(itertools.accumulate(a, f))
-    return kg_asarray(r)
+    return backend.kg_asarray(r)
 
 
-def eval_adverb_scan_converging(f, a, op):
+def eval_adverb_scan_converging(f, a, op, backend):
     """
 
     f\~a [Scan-Converging]
@@ -356,15 +332,15 @@ def eval_adverb_scan_converging(f, a, op):
     x = a
     xx = f(a)
     r = [a, xx]
-    while not kg_equal(x,xx):
+    while not backend.kg_equal(x, xx):
         x = xx
         xx = f(x)
         r.append(xx)
     r.pop()
-    return kg_asarray(r)
+    return backend.kg_asarray(r)
 
 
-def eval_adverb_scan_while(klong, f, a, b):
+def eval_adverb_scan_while(klong, f, a, b, backend):
     """
 
     a f\~b [Scan-While]
@@ -389,10 +365,10 @@ def eval_adverb_scan_while(klong, f, a, b):
         b = f(b)
         r.append(b)
     r.pop()
-    return kg_asarray(r)
+    return backend.kg_asarray(r)
 
 
-def eval_adverb_scan_iterating(f, a, b):
+def eval_adverb_scan_iterating(f, a, b, backend):
     """
 
     a f\*b [Scan-Iterating]
@@ -410,7 +386,7 @@ def eval_adverb_scan_iterating(f, a, b):
         b = f(b)
         r.append(b)
         a = a - 1
-    return kg_asarray(r)
+    return backend.kg_asarray(r)
 
 
 def eval_adverb_while(klong, f, a, b):
@@ -429,3 +405,29 @@ def eval_adverb_while(klong, f, a, b):
     while klong.eval(KGCall(a, b, arity=1)):
         b = f(b)
     return b
+
+
+def get_adverb_fn(klong, s, arity):
+    backend = klong._backend
+
+    if s == "'":
+        return eval_adverb_each2 if arity == 2 else lambda f,a,op: eval_adverb_each(f,a,op,backend)
+    elif s == '/':
+        return eval_adverb_over_neutral if arity == 2 else lambda f,a,op: eval_adverb_over(f,a,op,backend)
+    elif s == '\\':
+        return (lambda f,a,b: eval_adverb_scan_over_neutral(f,a,b,backend)) if arity == 2 else lambda f,a,op: eval_adverb_scan_over(f,a,op,backend)
+    elif s == '\\~':
+        return (lambda f,a,b: eval_adverb_scan_while(klong,f,a,b,backend)) if arity == 2 else lambda f,a,op: eval_adverb_scan_converging(f,a,op,backend)
+    elif s == '\\*':
+        return lambda f,a,b: eval_adverb_scan_iterating(f,a,b,backend)
+    elif s == ':\\':
+        return lambda f,a,b: eval_adverb_each_left(f,a,b,backend)
+    elif s == ':\'':
+        return lambda f,a,op: eval_adverb_each_pair(f,a,op,backend)
+    elif s == ':/':
+        return lambda f,a,b: eval_adverb_each_right(f,a,b,backend)
+    elif s == ':*':
+        return eval_dyad_adverb_iterate
+    elif s == ':~':
+        return (lambda f,a,b: eval_adverb_while(klong,f,a,b)) if arity == 2 else lambda f,a,op: eval_adverb_converge(f,a,op,backend)
+    raise RuntimeError(f"unknown adverb: {s}")
```
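Both Over (`f/a`) and the Scan-Over fast path lean on the backend's NumPy-compatible ufunc methods, so `+/a` becomes `add.reduce` and `+\a` becomes `add.accumulate`; the `hasattr` guards cover backends whose ops do not expose ufunc methods, in which case the code falls back to `functools.reduce` / `itertools.accumulate`. The NumPy identities the fast path relies on, as a standalone check:

```python
import functools
import numpy as np

a = np.array([1, 2, 3, 4])

# f/a (Over) with '+' maps onto the ufunc reduce...
assert np.add.reduce(a) == functools.reduce(lambda x, y: x + y, a) == 10

# ...and f\a (Scan-Over) maps onto the ufunc accumulate.
assert np.add.accumulate(a).tolist() == [1, 3, 6, 10]

# A backend op without ufunc methods fails the hasattr checks above,
# and the adverb falls back to the generic itertools/functools path.
```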
klongpy/autograd.py
ADDED
@@ -0,0 +1,299 @@
```python
import numpy as np
from .core import KGLambda, KGCall, KGSym, KGFn


class AutogradError(Exception):
    """Base class for autograd-related errors."""
    pass


class AutogradChainBrokenError(AutogradError):
    """Raised when the gradient computation chain is broken."""

    def __init__(self, context, expected, actual, suggestion=None):
        self.context = context
        self.expected = expected
        self.actual = actual
        self.suggestion = suggestion
        msg = f"Autograd chain broken at {context}: expected {expected}, got {actual}."
        if suggestion:
            msg += f" {suggestion}"
        super().__init__(msg)


class NonScalarLossError(AutogradError):
    """Raised when the loss function returns a non-scalar value."""

    def __init__(self, shape):
        self.shape = shape
        super().__init__(
            f"Loss function must return a scalar, got shape {shape}. "
            "Use sum (+/) or mean (%#) to reduce to a scalar."
        )


def _get_float_dtype(backend):
    """Get the appropriate float dtype for the current backend."""
    # MPS doesn't support float64
    if hasattr(backend, 'supports_float64') and not backend.supports_float64():
        return np.float32
    return np.float64


def _scalar_value(x, backend):
    """Extract scalar value from various array/tensor types.

    Raises:
        NonScalarLossError: If x is not a scalar value.
    """
    x = backend.to_numpy(x) if backend.is_backend_array(x) else x
    if isinstance(x, np.ndarray):
        if x.ndim == 0:
            return float(x.item())
        elif x.size == 1:
            return float(x.flat[0])
        else:
            raise NonScalarLossError(tuple(x.shape))
    return float(x)


def _to_func_input(x, backend, require_grad=False):
    """Convert numpy array to appropriate input type for function call.

    Args:
        x: Input array (numpy)
        backend: Backend provider
        require_grad: If True and backend supports autograd, create grad tensor.
                      For numeric gradient, this should be False.
    """
    if require_grad and backend.supports_autograd():
        return backend.create_grad_tensor(x)
    return x


def _invoke_fn(klong, fn, args):
    """Invoke a Klong function with the given arguments.

    Handles all function types uniformly:
    - KGSym, KGLambda: wrap in KGCall with args
    - KGFn, KGCall: extract inner function, wrap in KGCall with args
    - callable: call directly with args
    """
    if callable(fn) and not isinstance(fn, (KGSym, KGLambda, KGFn)):
        return fn(*args)
    inner = fn.a if isinstance(fn, KGFn) else fn
    return klong.call(KGCall(inner, list(args), len(args)))


def numeric_grad(func, x, backend, eps=None):
    """Compute numeric gradient of scalar-valued function."""
    # Get appropriate float dtype
    float_dtype = _get_float_dtype(backend)

    # Use larger epsilon for float32 to maintain precision
    if eps is None:
        eps = 1e-4 if float_dtype == np.float32 else 1e-6

    # Convert backend tensors to numpy for gradient computation
    if backend.is_backend_array(x):
        x = backend.to_numpy(x)
    x = np.asarray(x, dtype=float_dtype)

    grad = np.zeros_like(x, dtype=float_dtype)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        orig = float(x[idx])
        x[idx] = orig + eps
        f_pos = _scalar_value(func(_to_func_input(x.copy(), backend)), backend)
        x[idx] = orig - eps
        f_neg = _scalar_value(func(_to_func_input(x.copy(), backend)), backend)
        grad[idx] = (f_pos - f_neg) / (2 * eps)
        x[idx] = orig
        it.iternext()
    return grad
```
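`numeric_grad` is plain central differencing: each coordinate is nudged by ±eps and the slope estimated as (f_pos - f_neg)/(2·eps), which carries O(eps²) truncation error; the larger 1e-4 step for float32 backends trades truncation error against floating-point round-off in the two function evaluations. A self-contained sanity check of the same scheme against an analytic gradient, with no klongpy backend involved (the helper name below is hypothetical, not part of the package):

```python
import numpy as np

def central_diff_grad(f, x, eps=1e-6):
    """Hypothetical standalone version of the central-difference loop above."""
    x = np.asarray(x, dtype=np.float64)
    g = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        i = it.multi_index
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        g[i] = (f(xp) - f(xm)) / (2 * eps)   # central difference per coordinate
        it.iternext()
    return g

f = lambda v: float(np.sum(v ** 2))          # analytic gradient is 2v
x = np.array([1.0, -2.0, 3.0])
assert np.allclose(central_diff_grad(f, x), 2 * x, atol=1e-6)
```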
```python
def grad_of_fn(klong, fn, x):
    """
    Return gradient of Klong or Python function ``fn`` at ``x``.

    Uses PyTorch autograd when available (USE_TORCH=1), otherwise
    falls back to numeric differentiation.
    """
    backend = klong._backend
    call_fn = lambda v: _invoke_fn(klong, fn, [v])

    if backend.supports_autograd():
        return backend.compute_autograd(call_fn, x)
    else:
        return numeric_grad(call_fn, x, backend)
```
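When `backend.supports_autograd()` is true, `grad_of_fn` hands the whole computation to `backend.compute_autograd`, implemented in the new torch_backend.py whose body is not shown in this diff. A plausible sketch of what that reverse-mode path looks like with PyTorch; the function below is an illustration under stated assumptions, not the package's actual implementation:

```python
import torch

def compute_autograd_sketch(call_fn, x):
    """Hypothetical sketch of a torch backend's compute_autograd:
    build a leaf tensor, evaluate the scalar loss, backpropagate."""
    t = torch.as_tensor(x, dtype=torch.float32).requires_grad_(True)
    loss = call_fn(t)        # must evaluate to a scalar tensor
    loss.backward()          # reverse-mode autograd fills t.grad
    return t.grad.detach().numpy()

# d/dx sum(x^2) = 2x
g = compute_autograd_sketch(lambda t: (t * t).sum(), [1.0, -2.0, 3.0])
print(g)  # [ 2. -4.  6.]
```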
```python
def numeric_jacobian(func, x, backend, eps=None):
    """
    Compute Jacobian matrix of func at point x using finite differences.

    For f: R^n -> R^m, returns m x n matrix where J[i,j] = df_i/dx_j.

    Args:
        func: Callable that takes an array and returns an array
        x: Input point (array)
        backend: Backend provider
        eps: Step size for finite differences (default: 1e-6 or 1e-4 for float32)

    Returns:
        Jacobian matrix as numpy array
    """
    float_dtype = _get_float_dtype(backend)
    if eps is None:
        eps = 1e-4 if float_dtype == np.float32 else 1e-6

    # Convert to numpy
    if backend.is_backend_array(x):
        x = backend.to_numpy(x)
    x = np.asarray(x, dtype=float_dtype).flatten()

    # Evaluate function at x to get output shape
    f0 = func(_to_func_input(x.copy(), backend))
    if backend.is_backend_array(f0):
        f0 = backend.to_numpy(f0)
    f0 = np.asarray(f0, dtype=float_dtype).flatten()

    n = len(x)   # Input dimension
    m = len(f0)  # Output dimension
    jacobian = np.zeros((m, n), dtype=float_dtype)

    for j in range(n):
        x_plus = x.copy()
        x_plus[j] += eps
        x_minus = x.copy()
        x_minus[j] -= eps

        f_plus = func(_to_func_input(x_plus, backend))
        f_minus = func(_to_func_input(x_minus, backend))

        if backend.is_backend_array(f_plus):
            f_plus = backend.to_numpy(f_plus)
        if backend.is_backend_array(f_minus):
            f_minus = backend.to_numpy(f_minus)

        f_plus = np.asarray(f_plus, dtype=float_dtype).flatten()
        f_minus = np.asarray(f_minus, dtype=float_dtype).flatten()

        jacobian[:, j] = (f_plus - f_minus) / (2 * eps)

    return jacobian
```
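`numeric_jacobian` builds the matrix column by column: column j is (f(x+eps·e_j) - f(x-eps·e_j))/(2·eps). A standalone check of that construction against a function with a known Jacobian, again independent of the package helpers:

```python
import numpy as np

eps = 1e-6
f = lambda v: np.array([v[0] * v[1], np.sin(v[0])])  # f: R^2 -> R^2
x = np.array([1.0, 2.0])

J = np.zeros((2, 2))
for j in range(2):
    xp, xm = x.copy(), x.copy()
    xp[j] += eps
    xm[j] -= eps
    J[:, j] = (f(xp) - f(xm)) / (2 * eps)   # column j = df/dx_j

# Analytic Jacobian: [[x1, x0], [cos(x0), 0]]
J_true = np.array([[x[1], x[0]], [np.cos(x[0]), 0.0]])
assert np.allclose(J, J_true, atol=1e-5)
```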
```python
def jacobian_of_fn(klong, fn, x):
    """
    Compute Jacobian matrix of Klong function fn at point x.

    For f: R^n -> R^m, returns m x n matrix where J[i,j] = df_i/dx_j.

    Args:
        klong: KlongInterpreter instance
        fn: Function (KGSym, KGLambda, KGFn, KGCall, or callable)
        x: Input point

    Returns:
        Jacobian matrix
    """
    backend = klong._backend
    call_fn = lambda v: _invoke_fn(klong, fn, [v])

    if backend.supports_autograd():
        try:
            return backend.compute_jacobian(call_fn, x)
        except Exception:
            # Fall back to numeric if torch jacobian fails
            return numeric_jacobian(call_fn, x, backend=backend)
    else:
        return numeric_jacobian(call_fn, x, backend=backend)


def multi_jacobian_of_fn(klong, fn, param_syms):
    """
    Compute Jacobians for multiple parameters in one call.

    Args:
        klong: KlongInterpreter instance
        fn: Function (KGSym, KGLambda, KGFn, KGCall, or callable)
            Should be a niladic function that references the parameters
        param_syms: List of KGSym parameter symbols to differentiate with respect to

    Returns:
        List of Jacobian matrices, one per parameter
    """
    backend = klong._backend
    param_values = [klong[sym] for sym in param_syms]
    call_fn = lambda: _invoke_fn(klong, fn, [])

    jacobians = []
    for sym, val in zip(param_syms, param_values):
        original = klong[sym]

        def single_param_fn(v, s=sym, orig=original):
            """Wrapper that sets param to v, calls fn, restores param."""
            klong[s] = v
            try:
                return call_fn()
            finally:
                klong[s] = orig

        if backend.supports_autograd():
            try:
                jac = backend.compute_jacobian(single_param_fn, val)
            except Exception:
                jac = numeric_jacobian(single_param_fn, val, backend=backend)
        else:
            jac = numeric_jacobian(single_param_fn, val, backend=backend)

        # Restore original value after jacobian computation
        klong[sym] = original
        jacobians.append(jac)

    return jacobians


def multi_grad_of_fn(klong, fn, param_syms):
    """
    Compute gradients for multiple parameters in one call.

    Args:
        klong: KlongInterpreter instance
        fn: Loss function (KGSym, KGLambda, KGFn, KGCall, or callable)
            Should be a niladic function that references the parameters
        param_syms: List of KGSym parameter symbols to differentiate with respect to

    Returns:
        List of gradients, one per parameter
    """
    backend = klong._backend
    # Access context directly to avoid KGFnWrapper wrapping
    param_values = [klong._context[sym] for sym in param_syms]

    def call_fn_with_tensors(tensors):
        """Call the loss function with tensor values temporarily bound to symbols."""
        originals = {sym: klong._context[sym] for sym in param_syms}
        try:
            for sym, tensor in zip(param_syms, tensors):
                klong[sym] = tensor
            return _invoke_fn(klong, fn, [])
        finally:
            for sym, orig in originals.items():
                klong[sym] = orig

    if backend.supports_autograd():
        return backend.compute_multi_autograd(call_fn_with_tensors, param_values)
    else:
        # Fallback: compute numeric gradients one at a time
        grads = []
        for i, sym in enumerate(param_syms):
            def single_param_fn(v, idx=i):
                vals = list(param_values)
                vals[idx] = v
                return call_fn_with_tensors(vals)
            grads.append(numeric_grad(single_param_fn, param_values[i], backend))
        return grads
```
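Together with the new sys_fn_autograd.py, which presumably exposes these helpers as Klong system functions, the Python-level API can be driven directly. A hedged end-to-end sketch: the `KlongInterpreter` usage is standard klongpy, but the assumption that an anonymous Klong function evaluates to a value `_invoke_fn` accepts as-is is not confirmed by this diff:

```python
import numpy as np
from klongpy import KlongInterpreter
from klongpy.autograd import grad_of_fn   # module path confirmed by this diff

klong = KlongInterpreter()
f = klong('{(+/x*x)%2}')        # f(x) = sum(x^2)/2, whose gradient is x itself

x = np.array([1.0, -2.0, 3.0])
g = grad_of_fn(klong, f, x)     # torch autograd if available, else numeric
print(g)                        # approximately [ 1. -2.  3.]
```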