klongpy 0.7.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- klongpy/__init__.py +0 -2
- klongpy/adverbs.py +84 -82
- klongpy/autograd.py +0 -9
- klongpy/backend.py +9 -142
- klongpy/backends/__init__.py +9 -77
- klongpy/backends/base.py +154 -5
- klongpy/backends/numpy_backend.py +2 -1
- klongpy/backends/registry.py +76 -0
- klongpy/backends/torch_backend.py +83 -31
- klongpy/cli.py +50 -7
- klongpy/core.py +113 -1094
- klongpy/db/sys_fn_db.py +3 -3
- klongpy/db/sys_fn_kvs.py +2 -4
- klongpy/dyads.py +203 -162
- klongpy/interpreter.py +32 -15
- klongpy/monads.py +99 -89
- klongpy/parser.py +328 -0
- klongpy/repl.py +2 -2
- klongpy/sys_fn.py +53 -15
- klongpy/sys_fn_ipc.py +4 -9
- klongpy/types.py +503 -0
- klongpy/writer.py +122 -0
- klongpy/ws/sys_fn_ws.py +5 -8
- {klongpy-0.7.0.dist-info → klongpy-0.7.1.dist-info}/METADATA +146 -95
- klongpy-0.7.1.dist-info/RECORD +52 -0
- klongpy-0.7.0.dist-info/RECORD +0 -48
- {klongpy-0.7.0.dist-info → klongpy-0.7.1.dist-info}/WHEEL +0 -0
- {klongpy-0.7.0.dist-info → klongpy-0.7.1.dist-info}/entry_points.txt +0 -0
- {klongpy-0.7.0.dist-info → klongpy-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {klongpy-0.7.0.dist-info → klongpy-0.7.1.dist-info}/top_level.txt +0 -0
klongpy/__init__.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
from .interpreter import KlongInterpreter, KlongException
|
|
2
|
-
from .backend import TorchUnsupportedDtypeError
|
|
3
2
|
from .backends import (
|
|
4
3
|
get_backend,
|
|
5
4
|
register_backend,
|
|
@@ -11,7 +10,6 @@ from .backends import (
|
|
|
11
10
|
__all__ = [
|
|
12
11
|
"KlongInterpreter",
|
|
13
12
|
"KlongException",
|
|
14
|
-
"TorchUnsupportedDtypeError",
|
|
15
13
|
"UnsupportedDtypeError",
|
|
16
14
|
"get_backend",
|
|
17
15
|
"register_backend",
|
klongpy/adverbs.py
CHANGED
|
@@ -1,35 +1,9 @@
|
|
|
1
1
|
from .core import *
|
|
2
|
-
from .dyads import eval_dyad_add, eval_dyad_subtract, eval_dyad_multiply, eval_dyad_divide
|
|
3
|
-
from .backend import kg_asarray, is_number, str_to_chr_arr, kg_equal
|
|
4
2
|
import functools
|
|
5
3
|
import itertools
|
|
6
4
|
|
|
7
5
|
|
|
8
|
-
def
|
|
9
|
-
if s == "'":
|
|
10
|
-
return eval_adverb_each2 if arity == 2 else eval_adverb_each
|
|
11
|
-
elif s == '/':
|
|
12
|
-
return eval_adverb_over_neutral if arity == 2 else eval_adverb_over
|
|
13
|
-
elif s == '\\':
|
|
14
|
-
return eval_adverb_scan_over_neutral if arity == 2 else eval_adverb_scan_over
|
|
15
|
-
elif s == '\\~':
|
|
16
|
-
return (lambda f,a,b,k=klong: eval_adverb_scan_while(k,f,a,b)) if arity == 2 else eval_adverb_scan_converging
|
|
17
|
-
elif s == '\\*':
|
|
18
|
-
return eval_adverb_scan_iterating
|
|
19
|
-
elif s == ':\\':
|
|
20
|
-
return eval_adverb_each_left
|
|
21
|
-
elif s == ':\'':
|
|
22
|
-
return eval_adverb_each_pair
|
|
23
|
-
elif s == ':/':
|
|
24
|
-
return eval_adverb_each_right
|
|
25
|
-
elif s == ':*':
|
|
26
|
-
return eval_dyad_adverb_iterate
|
|
27
|
-
elif s == ':~':
|
|
28
|
-
return (lambda f,a,b,k=klong: eval_adverb_while(k,f,a,b)) if arity == 2 else eval_adverb_converge
|
|
29
|
-
raise RuntimeError(f"unknown adverb: {s}")
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def eval_adverb_converge(f, a, op):
|
|
6
|
+
def eval_adverb_converge(f, a, op, backend):
|
|
33
7
|
"""
|
|
34
8
|
f:~a [Converge]
|
|
35
9
|
|
|
@@ -55,10 +29,10 @@ def eval_adverb_converge(f, a, op):
|
|
|
55
29
|
def _e(p,q):
|
|
56
30
|
if not isinstance(p, type(q)):
|
|
57
31
|
return False
|
|
58
|
-
if is_number(p):
|
|
59
|
-
return np.isclose(p,q)
|
|
60
|
-
elif
|
|
61
|
-
return kg_equal(p,q)
|
|
32
|
+
if backend.is_number(p):
|
|
33
|
+
return backend.np.isclose(p,q)
|
|
34
|
+
elif backend.is_array(p):
|
|
35
|
+
return backend.kg_equal(p, q)
|
|
62
36
|
return p == q
|
|
63
37
|
x = f(a)
|
|
64
38
|
xx = f(x)
|
|
@@ -68,7 +42,7 @@ def eval_adverb_converge(f, a, op):
|
|
|
68
42
|
return x
|
|
69
43
|
|
|
70
44
|
|
|
71
|
-
def eval_adverb_each(f, a, op):
|
|
45
|
+
def eval_adverb_each(f, a, op, backend):
|
|
72
46
|
"""
|
|
73
47
|
|
|
74
48
|
f'a [Each]
|
|
@@ -93,17 +67,17 @@ def eval_adverb_each(f, a, op):
|
|
|
93
67
|
return a
|
|
94
68
|
has_str = False
|
|
95
69
|
r = []
|
|
96
|
-
for x in str_to_chr_arr(a):
|
|
70
|
+
for x in backend.str_to_chr_arr(a):
|
|
97
71
|
u = f(x)
|
|
98
72
|
has_str |= isinstance(u,str)
|
|
99
73
|
r.append(u)
|
|
100
|
-
return ''.join(r) if has_str else kg_asarray(r)
|
|
74
|
+
return ''.join(r) if has_str else backend.kg_asarray(r)
|
|
101
75
|
if is_iterable(a):
|
|
102
76
|
r = [f(x) for x in a]
|
|
103
|
-
return a if is_empty(a) else kg_asarray(r)
|
|
77
|
+
return a if is_empty(a) else backend.kg_asarray(r)
|
|
104
78
|
elif is_dict(a):
|
|
105
|
-
r = [f(kg_asarray(x)) for x in a.items()]
|
|
106
|
-
return kg_asarray(r)
|
|
79
|
+
r = [f(backend.kg_asarray(x)) for x in a.items()]
|
|
80
|
+
return backend.kg_asarray(r)
|
|
107
81
|
return f(a)
|
|
108
82
|
|
|
109
83
|
|
|
@@ -125,14 +99,14 @@ def eval_adverb_each2(f, a, b):
|
|
|
125
99
|
|
|
126
100
|
"""
|
|
127
101
|
if is_empty(a) or is_empty(b):
|
|
128
|
-
return
|
|
102
|
+
return bknp.asarray([]) if is_list(a) or is_list(b) else ""
|
|
129
103
|
if is_atom(a) and is_atom(b):
|
|
130
104
|
return f(a,b)
|
|
131
|
-
r =
|
|
105
|
+
r = bknp.asarray([f(x,y) for x,y in zip(a,b)])
|
|
132
106
|
return ''.join(r) if r.dtype == '<U1' else r
|
|
133
107
|
|
|
134
108
|
|
|
135
|
-
def eval_adverb_each_left(f, a, b):
|
|
109
|
+
def eval_adverb_each_left(f, a, b, backend):
|
|
136
110
|
"""
|
|
137
111
|
a f:\b [Each-Left]
|
|
138
112
|
a f:/b [Each-Right]
|
|
@@ -154,20 +128,19 @@ def eval_adverb_each_left(f, a, b):
|
|
|
154
128
|
Examples: 1,:\[2 3 4] --> [[1 2] [1 3] [1 4]]
|
|
155
129
|
1,:/[2 3 4] --> [[2 1] [3 1] [4 1]]
|
|
156
130
|
"""
|
|
157
|
-
b = str_to_chr_arr(b) if isinstance(b,str) else b
|
|
158
|
-
return kg_asarray([f(a,x) for x in b])
|
|
131
|
+
b = backend.str_to_chr_arr(b) if isinstance(b,str) else b
|
|
132
|
+
return backend.kg_asarray([f(a,x) for x in b])
|
|
159
133
|
|
|
160
134
|
|
|
161
|
-
def eval_adverb_each_right(f, a, b):
|
|
135
|
+
def eval_adverb_each_right(f, a, b, backend):
|
|
162
136
|
"""
|
|
163
137
|
see: eval_dyad_adverb_each_left
|
|
164
138
|
"""
|
|
165
|
-
b = str_to_chr_arr(b) if isinstance(b,str) else b
|
|
166
|
-
return kg_asarray([f(x,a) for x in b])
|
|
139
|
+
b = backend.str_to_chr_arr(b) if isinstance(b,str) else b
|
|
140
|
+
return backend.kg_asarray([f(x,a) for x in b])
|
|
167
141
|
|
|
168
142
|
|
|
169
|
-
|
|
170
|
-
def eval_adverb_each_pair(f, a, op):
|
|
143
|
+
def eval_adverb_each_pair(f, a, op, backend):
|
|
171
144
|
"""
|
|
172
145
|
|
|
173
146
|
f:'a [Each-Pair]
|
|
@@ -186,8 +159,8 @@ def eval_adverb_each_pair(f, a, op):
|
|
|
186
159
|
if is_atom(a) or (is_iterable(a) and len(a) == 1):
|
|
187
160
|
return a
|
|
188
161
|
j = isinstance(a, str)
|
|
189
|
-
a = str_to_chr_arr(a) if j else a
|
|
190
|
-
return kg_asarray([f(x,y) for x,y in zip(a[::],a[1::])])
|
|
162
|
+
a = backend.str_to_chr_arr(a) if j else a
|
|
163
|
+
return backend.kg_asarray([f(x,y) for x,y in zip(a[::],a[1::])])
|
|
191
164
|
|
|
192
165
|
|
|
193
166
|
def eval_dyad_adverb_iterate(f, a, b):
|
|
@@ -209,7 +182,7 @@ def eval_dyad_adverb_iterate(f, a, b):
|
|
|
209
182
|
return b
|
|
210
183
|
|
|
211
184
|
|
|
212
|
-
def eval_adverb_over(f, a, op):
|
|
185
|
+
def eval_adverb_over(f, a, op, backend):
|
|
213
186
|
"""
|
|
214
187
|
f/a [Over]
|
|
215
188
|
|
|
@@ -228,22 +201,23 @@ def eval_adverb_over(f, a, op):
|
|
|
228
201
|
return a
|
|
229
202
|
if len(a) == 1:
|
|
230
203
|
return a[0]
|
|
231
|
-
# Use
|
|
204
|
+
# Use backend's ufunc reduce when available for better performance
|
|
205
|
+
np_backend = backend.np
|
|
232
206
|
if isinstance(op, KGOp):
|
|
233
207
|
if safe_eq(op.a,'+'):
|
|
234
|
-
return
|
|
208
|
+
return np_backend.add.reduce(a)
|
|
235
209
|
elif safe_eq(op.a, '-'):
|
|
236
|
-
return
|
|
237
|
-
elif safe_eq(op.a, '*') and hasattr(
|
|
238
|
-
return
|
|
239
|
-
elif safe_eq(op.a, '%') and hasattr(
|
|
240
|
-
return
|
|
210
|
+
return np_backend.subtract.reduce(a)
|
|
211
|
+
elif safe_eq(op.a, '*') and hasattr(np_backend.multiply,'reduce'):
|
|
212
|
+
return np_backend.multiply.reduce(a)
|
|
213
|
+
elif safe_eq(op.a, '%') and hasattr(np_backend.divide,'reduce'):
|
|
214
|
+
return np_backend.divide.reduce(a)
|
|
241
215
|
elif safe_eq(op.a, '&') and a.ndim == 1:
|
|
242
|
-
return
|
|
216
|
+
return np_backend.min(a)
|
|
243
217
|
elif safe_eq(op.a, '|') and a.ndim == 1:
|
|
244
|
-
return
|
|
245
|
-
elif safe_eq(op.a, ',') and
|
|
246
|
-
return a if a.ndim == 1 else
|
|
218
|
+
return np_backend.max(a)
|
|
219
|
+
elif safe_eq(op.a, ',') and np_backend.isarray(a) and a.dtype != 'O':
|
|
220
|
+
return a if a.ndim == 1 else np_backend.concatenate(a, axis=0)
|
|
247
221
|
return functools.reduce(f, a)
|
|
248
222
|
|
|
249
223
|
|
|
@@ -280,7 +254,7 @@ def eval_adverb_over_neutral(f, a, b):
|
|
|
280
254
|
return functools.reduce(f,b[1:],f(a,b[0]))
|
|
281
255
|
|
|
282
256
|
|
|
283
|
-
def eval_adverb_scan_over_neutral(f, a, b):
|
|
257
|
+
def eval_adverb_scan_over_neutral(f, a, b, backend):
|
|
284
258
|
"""
|
|
285
259
|
|
|
286
260
|
f\a [Scan-Over]
|
|
@@ -309,31 +283,33 @@ def eval_adverb_scan_over_neutral(f, a, b):
|
|
|
309
283
|
b = [b]
|
|
310
284
|
b = [f(a,b[0]), *b[1:]]
|
|
311
285
|
r = list(itertools.accumulate(b,f))
|
|
312
|
-
q = kg_asarray(r)
|
|
286
|
+
q = backend.kg_asarray(r)
|
|
313
287
|
r = [a, *q]
|
|
314
|
-
return kg_asarray(r)
|
|
288
|
+
return backend.kg_asarray(r)
|
|
315
289
|
|
|
316
290
|
|
|
317
|
-
def eval_adverb_scan_over(f, a, op):
|
|
291
|
+
def eval_adverb_scan_over(f, a, op, backend):
|
|
318
292
|
"""
|
|
319
293
|
see eval_adverb_scan_over_neutral
|
|
320
294
|
"""
|
|
321
295
|
if is_atom(a):
|
|
322
296
|
return a
|
|
323
|
-
# Use
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
297
|
+
# Use backend's ufunc accumulate when available for better performance
|
|
298
|
+
np_backend = backend.np
|
|
299
|
+
if isinstance(op, KGOp):
|
|
300
|
+
if safe_eq(op.a, '+') and hasattr(np_backend.add, 'accumulate'):
|
|
301
|
+
return np_backend.add.accumulate(a)
|
|
302
|
+
elif safe_eq(op.a, '-') and hasattr(np_backend.subtract, 'accumulate'):
|
|
303
|
+
return np_backend.subtract.accumulate(a)
|
|
304
|
+
elif safe_eq(op.a, '*') and hasattr(np_backend.multiply, 'accumulate'):
|
|
305
|
+
return np_backend.multiply.accumulate(a)
|
|
306
|
+
elif safe_eq(op.a, '%') and hasattr(np_backend.divide, 'accumulate'):
|
|
307
|
+
return np_backend.divide.accumulate(a)
|
|
332
308
|
r = list(itertools.accumulate(a, f))
|
|
333
|
-
return kg_asarray(r)
|
|
309
|
+
return backend.kg_asarray(r)
|
|
334
310
|
|
|
335
311
|
|
|
336
|
-
def eval_adverb_scan_converging(f, a, op):
|
|
312
|
+
def eval_adverb_scan_converging(f, a, op, backend):
|
|
337
313
|
"""
|
|
338
314
|
|
|
339
315
|
f\~a [Scan-Converging]
|
|
@@ -356,15 +332,15 @@ def eval_adverb_scan_converging(f, a, op):
|
|
|
356
332
|
x = a
|
|
357
333
|
xx = f(a)
|
|
358
334
|
r = [a, xx]
|
|
359
|
-
while not kg_equal(x,xx):
|
|
335
|
+
while not backend.kg_equal(x, xx):
|
|
360
336
|
x = xx
|
|
361
337
|
xx = f(x)
|
|
362
338
|
r.append(xx)
|
|
363
339
|
r.pop()
|
|
364
|
-
return kg_asarray(r)
|
|
340
|
+
return backend.kg_asarray(r)
|
|
365
341
|
|
|
366
342
|
|
|
367
|
-
def eval_adverb_scan_while(klong, f, a, b):
|
|
343
|
+
def eval_adverb_scan_while(klong, f, a, b, backend):
|
|
368
344
|
"""
|
|
369
345
|
|
|
370
346
|
a f\~b [Scan-While]
|
|
@@ -389,10 +365,10 @@ def eval_adverb_scan_while(klong, f, a, b):
|
|
|
389
365
|
b = f(b)
|
|
390
366
|
r.append(b)
|
|
391
367
|
r.pop()
|
|
392
|
-
return kg_asarray(r)
|
|
368
|
+
return backend.kg_asarray(r)
|
|
393
369
|
|
|
394
370
|
|
|
395
|
-
def eval_adverb_scan_iterating(f, a, b):
|
|
371
|
+
def eval_adverb_scan_iterating(f, a, b, backend):
|
|
396
372
|
"""
|
|
397
373
|
|
|
398
374
|
a f\*b [Scan-Iterating]
|
|
@@ -410,7 +386,7 @@ def eval_adverb_scan_iterating(f, a, b):
|
|
|
410
386
|
b = f(b)
|
|
411
387
|
r.append(b)
|
|
412
388
|
a = a - 1
|
|
413
|
-
return kg_asarray(r)
|
|
389
|
+
return backend.kg_asarray(r)
|
|
414
390
|
|
|
415
391
|
|
|
416
392
|
def eval_adverb_while(klong, f, a, b):
|
|
@@ -429,3 +405,29 @@ def eval_adverb_while(klong, f, a, b):
|
|
|
429
405
|
while klong.eval(KGCall(a, b, arity=1)):
|
|
430
406
|
b = f(b)
|
|
431
407
|
return b
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def get_adverb_fn(klong, s, arity):
|
|
411
|
+
backend = klong._backend
|
|
412
|
+
|
|
413
|
+
if s == "'":
|
|
414
|
+
return eval_adverb_each2 if arity == 2 else lambda f,a,op: eval_adverb_each(f,a,op,backend)
|
|
415
|
+
elif s == '/':
|
|
416
|
+
return eval_adverb_over_neutral if arity == 2 else lambda f,a,op: eval_adverb_over(f,a,op,backend)
|
|
417
|
+
elif s == '\\':
|
|
418
|
+
return (lambda f,a,b: eval_adverb_scan_over_neutral(f,a,b,backend)) if arity == 2 else lambda f,a,op: eval_adverb_scan_over(f,a,op,backend)
|
|
419
|
+
elif s == '\\~':
|
|
420
|
+
return (lambda f,a,b: eval_adverb_scan_while(klong,f,a,b,backend)) if arity == 2 else lambda f,a,op: eval_adverb_scan_converging(f,a,op,backend)
|
|
421
|
+
elif s == '\\*':
|
|
422
|
+
return lambda f,a,b: eval_adverb_scan_iterating(f,a,b,backend)
|
|
423
|
+
elif s == ':\\':
|
|
424
|
+
return lambda f,a,b: eval_adverb_each_left(f,a,b,backend)
|
|
425
|
+
elif s == ':\'':
|
|
426
|
+
return lambda f,a,op: eval_adverb_each_pair(f,a,op,backend)
|
|
427
|
+
elif s == ':/':
|
|
428
|
+
return lambda f,a,b: eval_adverb_each_right(f,a,b,backend)
|
|
429
|
+
elif s == ':*':
|
|
430
|
+
return eval_dyad_adverb_iterate
|
|
431
|
+
elif s == ':~':
|
|
432
|
+
return (lambda f,a,b: eval_adverb_while(klong,f,a,b)) if arity == 2 else lambda f,a,op: eval_adverb_converge(f,a,op,backend)
|
|
433
|
+
raise RuntimeError(f"unknown adverb: {s}")
|
klongpy/autograd.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
from .core import KGLambda, KGCall, KGSym, KGFn
|
|
3
|
-
from .backend import get_default_backend
|
|
4
3
|
|
|
5
4
|
|
|
6
5
|
class AutogradError(Exception):
|
|
@@ -131,14 +130,6 @@ def grad_of_fn(klong, fn, x):
|
|
|
131
130
|
return numeric_grad(call_fn, x, backend)
|
|
132
131
|
|
|
133
132
|
|
|
134
|
-
def torch_autograd(func, x):
|
|
135
|
-
"""Compute gradient using PyTorch autograd (requires torch backend)."""
|
|
136
|
-
backend = get_default_backend()
|
|
137
|
-
if not backend.supports_autograd():
|
|
138
|
-
raise RuntimeError("PyTorch autograd requires torch backend (USE_TORCH=1)")
|
|
139
|
-
return backend.compute_autograd(func, x)
|
|
140
|
-
|
|
141
|
-
|
|
142
133
|
def numeric_jacobian(func, x, backend, eps=None):
|
|
143
134
|
"""
|
|
144
135
|
Compute Jacobian matrix of func at point x using finite differences.
|
klongpy/backend.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Backend
|
|
2
|
+
Backend module for KlongPy.
|
|
3
3
|
|
|
4
|
-
|
|
5
|
-
New code should use the backends package directly:
|
|
4
|
+
Prefer using the backends package directly:
|
|
6
5
|
|
|
7
6
|
from klongpy.backends import get_backend, BackendProvider
|
|
8
7
|
|
|
@@ -10,162 +9,30 @@ For per-interpreter backends, use:
|
|
|
10
9
|
|
|
11
10
|
klong = KlongInterpreter(backend='torch')
|
|
12
11
|
"""
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
from .backends import (
|
|
16
|
-
get_backend,
|
|
17
|
-
register_backend,
|
|
18
|
-
list_backends,
|
|
12
|
+
from .backends.base import (
|
|
19
13
|
BackendProvider,
|
|
20
14
|
UnsupportedDtypeError,
|
|
21
|
-
TorchUnsupportedDtypeError,
|
|
22
|
-
NumpyBackendProvider,
|
|
23
|
-
TorchBackendProvider,
|
|
24
|
-
KGChar,
|
|
25
15
|
is_jagged_array,
|
|
26
16
|
is_supported_type,
|
|
27
17
|
)
|
|
18
|
+
from .backends.numpy_backend import KGChar, NumpyBackendProvider
|
|
19
|
+
from .backends.registry import get_backend, list_backends, register_backend, TorchBackendProvider
|
|
28
20
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
_default_backend = get_backend()
|
|
33
|
-
|
|
34
|
-
# Backward compatibility: expose np and use_torch at module level
|
|
35
|
-
np = _default_backend.np
|
|
36
|
-
use_torch = _default_backend.name == 'torch'
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def get_default_backend():
|
|
40
|
-
"""Get the default backend provider."""
|
|
41
|
-
return _default_backend
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def to_numpy(x):
|
|
45
|
-
"""Convert tensor/array to numpy, handling device transfers and 0-dim arrays."""
|
|
46
|
-
result = _default_backend.to_numpy(x)
|
|
47
|
-
if isinstance(result, real_np.ndarray) and result.ndim == 0:
|
|
48
|
-
return result.item()
|
|
49
|
-
return result
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def to_display(x):
|
|
53
|
-
"""Convert value to display-friendly format."""
|
|
54
|
-
return _default_backend.to_display(x)
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
def array_size(a):
|
|
58
|
-
"""
|
|
59
|
-
Get the total number of elements in an array/tensor.
|
|
60
|
-
|
|
61
|
-
Works with both numpy arrays and torch tensors.
|
|
62
|
-
"""
|
|
63
|
-
return _default_backend.array_size(a)
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def safe_equal(x, y):
|
|
67
|
-
"""Compare two values for equality, handling backend-specific array types."""
|
|
68
|
-
return _default_backend.safe_equal(x, y)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def detach_if_needed(x):
|
|
72
|
-
"""Detach array from computation graph if needed."""
|
|
73
|
-
return _default_backend.detach_if_needed(x)
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def to_int_array(a):
|
|
77
|
-
"""Convert array to integer type."""
|
|
78
|
-
return _default_backend.to_int_array(a)
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def power(a, b):
|
|
82
|
-
"""Compute a^b, handling gradient tracking if applicable."""
|
|
83
|
-
return _default_backend.power(a, b)
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
def has_gradient(x):
|
|
87
|
-
"""Check if x is tracking gradients (for autograd)."""
|
|
88
|
-
return _default_backend.has_gradient(x)
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def kg_asarray(a):
|
|
92
|
-
"""Convert input to array using the default backend's kg_asarray method."""
|
|
93
|
-
return _default_backend.kg_asarray(a)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def is_integer(x):
|
|
97
|
-
"""Check if x is an integer type using the default backend."""
|
|
98
|
-
return _default_backend.is_integer(x)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def is_float(x):
|
|
102
|
-
"""Check if x is a float type using the default backend."""
|
|
103
|
-
return _default_backend.is_float(x)
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def is_number(a):
|
|
107
|
-
"""Check if a is a number (integer or float) using the default backend."""
|
|
108
|
-
return _default_backend.is_number(a)
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
def get_dtype_kind(arr):
|
|
112
|
-
"""Get the dtype 'kind' character for an array using the default backend."""
|
|
113
|
-
return _default_backend.get_dtype_kind(arr)
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
def str_to_chr_arr(s):
|
|
117
|
-
"""Convert string to character array using the default backend."""
|
|
118
|
-
return _default_backend.str_to_chr_arr(s)
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
def kg_argsort(a, descending=False):
|
|
122
|
-
"""Argsort array using the default backend."""
|
|
123
|
-
from .core import kg_argsort as core_kg_argsort
|
|
124
|
-
return core_kg_argsort(a, _default_backend, descending=descending)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
def vec_fn(a, f):
|
|
128
|
-
"""Apply a function f to an array a, with support for nested arrays."""
|
|
129
|
-
from .core import vec_fn as core_vec_fn
|
|
130
|
-
return core_vec_fn(a, f, _default_backend)
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
def kg_equal(a, b):
|
|
134
|
-
"""Compare two values or arrays for equality using the default backend."""
|
|
135
|
-
from .core import kg_equal as core_kg_equal
|
|
136
|
-
return core_kg_equal(a, b, _default_backend)
|
|
137
|
-
|
|
21
|
+
_default_np_backend = get_backend('numpy')
|
|
22
|
+
np = _default_np_backend.np
|
|
23
|
+
bknp = np
|
|
138
24
|
|
|
139
25
|
__all__ = [
|
|
140
26
|
'np',
|
|
141
|
-
'
|
|
27
|
+
'bknp',
|
|
142
28
|
'get_backend',
|
|
143
|
-
'get_default_backend',
|
|
144
29
|
'register_backend',
|
|
145
30
|
'list_backends',
|
|
146
31
|
'BackendProvider',
|
|
147
32
|
'UnsupportedDtypeError',
|
|
148
|
-
'TorchUnsupportedDtypeError',
|
|
149
33
|
'NumpyBackendProvider',
|
|
150
34
|
'TorchBackendProvider',
|
|
151
35
|
'KGChar',
|
|
152
36
|
'is_supported_type',
|
|
153
37
|
'is_jagged_array',
|
|
154
|
-
'to_numpy',
|
|
155
|
-
'to_display',
|
|
156
|
-
'array_size',
|
|
157
|
-
'safe_equal',
|
|
158
|
-
'detach_if_needed',
|
|
159
|
-
'to_int_array',
|
|
160
|
-
'power',
|
|
161
|
-
'has_gradient',
|
|
162
|
-
'kg_asarray',
|
|
163
|
-
'is_integer',
|
|
164
|
-
'is_float',
|
|
165
|
-
'is_number',
|
|
166
|
-
'get_dtype_kind',
|
|
167
|
-
'str_to_chr_arr',
|
|
168
|
-
'kg_argsort',
|
|
169
|
-
'vec_fn',
|
|
170
|
-
'kg_equal',
|
|
171
38
|
]
|
klongpy/backends/__init__.py
CHANGED
|
@@ -1,94 +1,26 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Backend
|
|
2
|
+
Backend public API for KlongPy.
|
|
3
3
|
|
|
4
|
-
|
|
5
|
-
The default backend is 'numpy'.
|
|
4
|
+
Re-exports registry helpers and core backend types.
|
|
6
5
|
"""
|
|
7
|
-
import
|
|
8
|
-
|
|
9
|
-
|
|
6
|
+
from .base import (
|
|
7
|
+
BackendProvider,
|
|
8
|
+
UnsupportedDtypeError,
|
|
9
|
+
is_jagged_array,
|
|
10
|
+
is_supported_type,
|
|
11
|
+
)
|
|
10
12
|
from .numpy_backend import NumpyBackendProvider, KGChar
|
|
11
|
-
|
|
12
|
-
# Registry of available backends
|
|
13
|
-
_BACKENDS = {}
|
|
14
|
-
|
|
15
|
-
# Default backend name
|
|
16
|
-
_DEFAULT_BACKEND = 'numpy'
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def register_backend(name: str, provider_class):
|
|
20
|
-
"""Register a backend provider class."""
|
|
21
|
-
_BACKENDS[name] = provider_class
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def get_backend(name: str = None, **kwargs) -> BackendProvider:
|
|
25
|
-
"""
|
|
26
|
-
Get a backend provider instance.
|
|
27
|
-
|
|
28
|
-
Parameters
|
|
29
|
-
----------
|
|
30
|
-
name : str, optional
|
|
31
|
-
Backend name ('numpy' or 'torch'). If None, uses default.
|
|
32
|
-
**kwargs
|
|
33
|
-
Additional arguments passed to the backend provider constructor.
|
|
34
|
-
|
|
35
|
-
Returns
|
|
36
|
-
-------
|
|
37
|
-
BackendProvider
|
|
38
|
-
The backend provider instance.
|
|
39
|
-
"""
|
|
40
|
-
if name is None:
|
|
41
|
-
# Check environment variable for default
|
|
42
|
-
env_backend = os.environ.get('KLONGPY_BACKEND', '').lower()
|
|
43
|
-
if env_backend == 'torch' or os.environ.get('USE_TORCH') == '1':
|
|
44
|
-
name = 'torch'
|
|
45
|
-
else:
|
|
46
|
-
name = _DEFAULT_BACKEND
|
|
47
|
-
|
|
48
|
-
if name not in _BACKENDS:
|
|
49
|
-
available = ', '.join(_BACKENDS.keys())
|
|
50
|
-
raise ValueError(f"Unknown backend: '{name}'. Available: {available}")
|
|
51
|
-
|
|
52
|
-
return _BACKENDS[name](**kwargs)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def list_backends():
|
|
56
|
-
"""Return list of available backend names."""
|
|
57
|
-
return list(_BACKENDS.keys())
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def set_default_backend(name: str):
|
|
61
|
-
"""Set the default backend name."""
|
|
62
|
-
global _DEFAULT_BACKEND
|
|
63
|
-
if name not in _BACKENDS:
|
|
64
|
-
raise ValueError(f"Unknown backend: '{name}'")
|
|
65
|
-
_DEFAULT_BACKEND = name
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
# Register built-in backends
|
|
69
|
-
register_backend('numpy', NumpyBackendProvider)
|
|
70
|
-
|
|
71
|
-
# Try to register torch backend if available
|
|
72
|
-
try:
|
|
73
|
-
from .torch_backend import TorchBackendProvider, TorchUnsupportedDtypeError
|
|
74
|
-
register_backend('torch', TorchBackendProvider)
|
|
75
|
-
except ImportError:
|
|
76
|
-
# Torch not available
|
|
77
|
-
TorchBackendProvider = None
|
|
78
|
-
TorchUnsupportedDtypeError = UnsupportedDtypeError
|
|
79
|
-
|
|
13
|
+
from .registry import get_backend, list_backends, register_backend, TorchBackendProvider
|
|
80
14
|
|
|
81
15
|
__all__ = [
|
|
82
16
|
'BackendProvider',
|
|
83
17
|
'UnsupportedDtypeError',
|
|
84
|
-
'TorchUnsupportedDtypeError',
|
|
85
18
|
'NumpyBackendProvider',
|
|
86
19
|
'TorchBackendProvider',
|
|
87
20
|
'KGChar',
|
|
88
21
|
'get_backend',
|
|
89
22
|
'register_backend',
|
|
90
23
|
'list_backends',
|
|
91
|
-
'set_default_backend',
|
|
92
24
|
'is_jagged_array',
|
|
93
25
|
'is_supported_type',
|
|
94
26
|
]
|