klongpy 0.7.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
klongpy/dyads.py CHANGED
@@ -1,14 +1,9 @@
1
1
  from .core import *
2
2
  from .autograd import grad_of_fn, numeric_grad, jacobian_of_fn, multi_jacobian_of_fn, multi_grad_of_fn
3
- from .backend import (
4
- to_numpy, safe_equal, detach_if_needed, to_int_array, power as backend_power, has_gradient,
5
- kg_asarray, str_to_chr_arr, kg_equal, is_integer, is_float, get_dtype_kind, array_size
6
- )
7
- import sys
8
3
  import numpy
9
4
 
10
5
 
11
- def eval_dyad_add(a, b):
6
+ def eval_dyad_add(a, b, backend):
12
7
  """
13
8
 
14
9
  a+b [Plus]
@@ -23,10 +18,10 @@ def eval_dyad_add(a, b):
23
18
  1+0.3 --> 1.3
24
19
 
25
20
  """
26
- return np.add(a, b)
21
+ return backend.np.add(a, b)
27
22
 
28
23
 
29
- def eval_dyad_amend(a, b):
24
+ def eval_dyad_amend(a, b, backend):
30
25
  """
31
26
 
32
27
  a:=b [Amend]
@@ -52,13 +47,14 @@ def eval_dyad_amend(a, b):
52
47
  "abc":="def",3 --> "abcdef"
53
48
 
54
49
  """
55
- if not (isinstance(a, (str,list)) or np.isarray(a)):
50
+ np_backend = backend.np
51
+ if not (isinstance(a, (str,list)) or np_backend.isarray(a)):
56
52
  raise RuntimeError(f"a must be list or str: {a}")
57
53
  if len(b) <= 1:
58
54
  return a
59
55
  if isinstance(a, str):
60
- r = str_to_chr_arr(a)
61
- q = str_to_chr_arr(b[0])
56
+ r = backend.str_to_chr_arr(a)
57
+ q = backend.str_to_chr_arr(b[0])
62
58
  for i in b[1:]:
63
59
  try:
64
60
  r[i:i+len(q)] = q
@@ -67,29 +63,29 @@ def eval_dyad_amend(a, b):
67
63
  if i > len(r):
68
64
  RangeError(i)
69
65
  elif i == len(r):
70
- r = np.append(r, b[0])
66
+ r = numpy.append(r, b[0])
71
67
  else:
72
68
  r[i] = b[0]
73
69
  return "".join(["".join(x) for x in r])
74
- r = np.array(a) # clone
75
- if is_list(b[0]): # TOOD: use np.put if we can
70
+ r = np_backend.array(a) # clone
71
+ if is_list(b[0]): # TOOD: use bknp.put if we can
76
72
  r = r.tolist()
77
73
  for i in b[1:]:
78
74
  r[i] = b[0]
79
- r = kg_asarray(r)
75
+ r = backend.kg_asarray(r)
80
76
  else:
81
- np.put(r, np.asarray(b[1:],dtype=int), b[0])
77
+ numpy.put(r, numpy.asarray(b[1:],dtype=int), b[0])
82
78
  return r
83
79
 
84
80
 
85
81
  def _e_dyad_amend_in_depth(p, q, v):
86
- if np.isarray(q) and len(q) > 1:
82
+ if bknp.isarray(q) and len(q) > 1:
87
83
  r = _e_dyad_amend_in_depth(p[q[0]], q[1:] if len(q) > 2 else q[1], v)
88
- p = np.array(p, dtype=r.dtype)
84
+ p = bknp.array(p, dtype=r.dtype)
89
85
  p[q[0]] = r
90
86
  return p
91
87
  else:
92
- p = np.array(p, dtype=object) if isinstance(v, (str, KGSym)) else np.array(p)
88
+ p = bknp.array(p, dtype=object) if isinstance(v, (str, KGSym)) else bknp.array(p)
93
89
  p[q] = v
94
90
  return p
95
91
 
@@ -111,7 +107,7 @@ def eval_dyad_amend_in_depth(a, b):
111
107
  return _e_dyad_amend_in_depth(a, b[1:], b[0])
112
108
 
113
109
 
114
- def eval_dyad_cut(a, b):
110
+ def eval_dyad_cut(a, b, backend):
115
111
  """
116
112
 
117
113
  a:_b [Cut]
@@ -133,12 +129,12 @@ def eval_dyad_cut(a, b):
133
129
 
134
130
  """
135
131
  j = isinstance(b, str)
136
- b = np.asarray(str_to_chr_arr(b) if j else b)
137
- a = a if np.isarray(a) else [a]
138
- r = np.array_split(b, a)
132
+ b = bknp.asarray(backend.str_to_chr_arr(b) if j else b)
133
+ a = a if bknp.isarray(a) else [a]
134
+ r = bknp.array_split(b, a)
139
135
  if len(b) == 0 and len(a) > 0:
140
136
  r = r[1:]
141
- return np.asarray(["".join(x) for x in r]) if j else kg_asarray(r)
137
+ return bknp.asarray(["".join(x) for x in r]) if j else backend.kg_asarray(r)
142
138
 
143
139
 
144
140
  def eval_dyad_at_index(klong, a, b):
@@ -172,24 +168,25 @@ def eval_dyad_at_index(klong, a, b):
172
168
 
173
169
  """
174
170
  if isinstance(a, (KGFn, KGSym)) or issubclass(type(a), KGLambda):
175
- b = [x for x in b] if np.isarray(b) else b
171
+ b = [x for x in b] if klong._backend.is_array(b) else b
176
172
  return klong.eval(KGCall(a, b, arity=1))
173
+ backend = klong._backend
177
174
  j = isinstance(a,str)
178
- a = str_to_chr_arr(a) if j else a
175
+ a = backend.str_to_chr_arr(a) if j else a
179
176
  if is_list(b):
180
177
  if is_empty(b):
181
- r = np.asarray([])
178
+ r = bknp.asarray([])
182
179
  else:
183
180
  # TODO: return None for missing keys? or raise?
184
- r = kg_asarray([a[x] for x in b])
185
- elif is_integer(b):
181
+ r = backend.kg_asarray([a[x] for x in b])
182
+ elif backend.is_integer(b):
186
183
  r = a[b]
187
184
  j = False
188
185
  else:
189
186
  r = a
190
187
  if j:
191
- if np.isarray(r) and r.ndim > 1:
192
- return np.asarray(["".join(x) for x in r], dtype=object)
188
+ if bknp.isarray(r) and r.ndim > 1:
189
+ return bknp.asarray(["".join(x) for x in r], dtype=object)
193
190
  return "".join(r)
194
191
  return r
195
192
 
@@ -214,7 +211,7 @@ def eval_dyad_define(klong, n, v):
214
211
  return v
215
212
 
216
213
 
217
- def eval_dyad_divide(a, b):
214
+ def eval_dyad_divide(a, b, backend):
218
215
  """
219
216
 
220
217
  a%b [Divide]
@@ -228,7 +225,11 @@ def eval_dyad_divide(a, b):
228
225
  10%8 --> 1.25
229
226
 
230
227
  """
231
- return np.divide(a, b)
228
+ if not is_list(a) and not is_list(b) and backend.is_number(b):
229
+ b_val = backend.scalar_to_python(b) if backend.is_backend_array(b) or (hasattr(b, 'ndim') and b.ndim == 0) else b
230
+ if b_val == 0:
231
+ return KLONG_UNDEFINED
232
+ return backend.np.divide(a, b)
232
233
 
233
234
 
234
235
  def eval_dyad_drop(a, b):
@@ -260,12 +261,12 @@ def eval_dyad_drop(a, b):
260
261
  return b[a:] if a >= 0 else b[:a]
261
262
 
262
263
 
263
- def _safe_equal(x, y):
264
- """Compare two values, handling torch tensors correctly."""
265
- return kg_truth(safe_equal(x, y))
264
+ def _safe_equal(x, y, backend):
265
+ """Compare two values, handling backend arrays correctly."""
266
+ return kg_truth(backend.safe_equal(x, y))
266
267
 
267
268
 
268
- def eval_dyad_equal(a, b):
269
+ def eval_dyad_equal(a, b, backend):
269
270
  """
270
271
 
271
272
  a=b [Equal]
@@ -290,7 +291,7 @@ def eval_dyad_equal(a, b):
290
291
  [1 2 3]=[1 4 3] --> [1 0 1]
291
292
 
292
293
  """
293
- return vec_fn2(a, b, _safe_equal)
294
+ return backend.vec_fn2(a, b, lambda x, y: _safe_equal(x, y, backend))
294
295
 
295
296
 
296
297
  def finditer(s, sub):
@@ -303,7 +304,7 @@ def finditer(s, sub):
303
304
  i += 1
304
305
 
305
306
 
306
- def eval_dyad_find(a, b):
307
+ def eval_dyad_find(a, b, backend):
307
308
  """
308
309
 
309
310
  a?b [Find]
@@ -332,44 +333,44 @@ def eval_dyad_find(a, b):
332
333
 
333
334
  """
334
335
  if isinstance(a,str):
335
- return np.asarray(list(finditer(a,str(b))))
336
+ return bknp.asarray(list(finditer(a,str(b))))
336
337
  elif is_dict(a):
337
338
  v = a.get(b)
338
- return np.inf if v is None else v
339
+ return KLONG_UNDEFINED if v is None else v
339
340
  if is_list(b):
340
- return np.asarray([i for i,x in enumerate(a) if kg_equal(x,b)])
341
- return np.where(np.asarray(a) == b)[0]
341
+ return bknp.asarray([i for i,x in enumerate(a) if backend.kg_equal(x, b)])
342
+ return bknp.where(bknp.asarray(a) == b)[0]
342
343
 
343
344
 
344
- def __e_dyad_form(a, b):
345
+ def __e_dyad_form(a, b, backend):
345
346
  if isinstance(a,KGSym):
346
347
  if is_empty(b):
347
- return np.inf
348
+ return KLONG_UNDEFINED
348
349
  return KGSym(b[1:] if isinstance(b,str) and b.startswith(":") else b)
349
- if is_integer(a):
350
- if is_float(b) or is_empty(b) or ('.' in b and str_is_float(b)):
351
- return np.inf
350
+ if backend.is_integer(a):
351
+ if backend.is_float(b) or is_empty(b) or ('.' in b and str_is_float(b)):
352
+ return KLONG_UNDEFINED
352
353
  return int(b)
353
- if is_float(a):
354
+ if backend.is_float(a):
354
355
  if is_empty(b):
355
- return np.inf
356
+ return KLONG_UNDEFINED
356
357
  return float(b)
357
358
  if isinstance(a,KGChar):
358
359
  b = str(b)
359
360
  if len(b) != 1:
360
- return np.inf
361
+ return KLONG_UNDEFINED
361
362
  return KGChar(str(b)[0])
362
363
  return b
363
364
 
364
- def _e_dyad_form(a, b):
365
+ def _e_dyad_form(a, b, backend):
365
366
  """
366
367
  Unravel the broadcasting of a and b and apply __e_dyad_form
367
368
  """
368
- if np.isarray(a) and np.isarray(b):
369
- return np.asarray([vec_fn2(x,y,_e_dyad_form) for x,y in zip(a,b)])
370
- return __e_dyad_form(a,b)
369
+ if bknp.isarray(a) and bknp.isarray(b):
370
+ return bknp.asarray([backend.vec_fn2(x, y, lambda x, y: _e_dyad_form(x, y, backend)) for x,y in zip(a,b)])
371
+ return __e_dyad_form(a, b, backend)
371
372
 
372
- def eval_dyad_form(a, b):
373
+ def eval_dyad_form(a, b, backend):
373
374
  """
374
375
 
375
376
  a:$b [Form]
@@ -396,13 +397,17 @@ def eval_dyad_form(a, b):
396
397
  :x:$":symbol" --> :symbol
397
398
 
398
399
  """
399
- return vec_fn2(a, b, _e_dyad_form)
400
+ return backend.vec_fn2(a, b, lambda x, y: _e_dyad_form(x, y, backend))
400
401
 
401
402
 
402
- def __e_dyad_format2(a, b):
403
+ def __e_dyad_format2(a, b, backend):
404
+ if hasattr(a, 'ndim') and a.ndim == 0:
405
+ a = a.item()
406
+ if hasattr(b, 'ndim') and b.ndim == 0:
407
+ b = b.item()
403
408
  if safe_eq(int(a), 0):
404
409
  return str(b)
405
- if (is_float(b) and not isinstance(b,int)) and (is_float(a) and not isinstance(a,int)):
410
+ if (backend.is_float(b) and not isinstance(b,int)) and (backend.is_float(a) and not isinstance(a,int)):
406
411
  b = "{:Xf}".replace("X",str(a)).format(b)
407
412
  p = b.split('.')
408
413
  p[0] = p[0].rjust(int(a))
@@ -412,17 +417,17 @@ def __e_dyad_format2(a, b):
412
417
  r = str(b).ljust(abs(a)) if a >= 0 else str(b).rjust(abs(a))
413
418
  return r
414
419
 
415
- def _e_dyad_format2(a, b):
420
+ def _e_dyad_format2(a, b, backend):
416
421
  """
417
422
  Unravel the broadcasting of a and b and apply __e_dyad_format2
418
423
  """
419
424
  if is_list(a) and is_list(b):
420
- return kg_asarray([vec_fn2(x, y, _e_dyad_format2) for x, y in zip(to_list(a), to_list(b))])
421
- if np.isarray(a) and np.isarray(b):
422
- return np.asarray([vec_fn2(x, y, _e_dyad_format2) for x, y in zip(a, b)])
423
- return __e_dyad_format2(a, b)
425
+ return backend.kg_asarray([backend.vec_fn2(x, y, lambda x, y: _e_dyad_format2(x, y, backend)) for x, y in zip(to_list(a), to_list(b))])
426
+ if backend.np.isarray(a) and backend.np.isarray(b):
427
+ return backend.np.asarray([backend.vec_fn2(x, y, lambda x, y: _e_dyad_format2(x, y, backend)) for x, y in zip(a, b)])
428
+ return __e_dyad_format2(a, b, backend)
424
429
 
425
- def eval_dyad_format2(a, b):
430
+ def eval_dyad_format2(a, b, backend):
426
431
  """
427
432
 
428
433
  a$b [Format2]
@@ -447,7 +452,7 @@ def eval_dyad_format2(a, b):
447
452
  5.3$123.45 --> " 123.450"
448
453
 
449
454
  """
450
- return vec_fn2(a, b, _e_dyad_format2)
455
+ return backend.vec_fn2(a, b, lambda x, y: _e_dyad_format2(x, y, backend))
451
456
 
452
457
 
453
458
  def eval_dyad_index_in_depth(a, b):
@@ -467,15 +472,16 @@ def eval_dyad_index_in_depth(a, b):
467
472
  {y+x*x}:@[2 3] --> 7
468
473
 
469
474
  """
470
- return np.asarray(a)[tuple(b) if is_list(b) else b] if not is_empty(b) else b
475
+ return bknp.asarray(a)[tuple(b) if is_list(b) else b] if not is_empty(b) else b
471
476
 
472
477
 
473
- def _e_dyad_integer_divide(x, y):
474
- a = np.divide(x, y)
475
- a = kg_asarray(rec_fn(a, np.trunc)) if np.isarray(a) else a
476
- return to_int_array(a)
478
+ def _e_dyad_integer_divide(x, y, backend):
479
+ np_backend = backend.np
480
+ a = np_backend.divide(x, y)
481
+ a = backend.kg_asarray(backend.rec_fn(a, np_backend.trunc)) if np_backend.isarray(a) else a
482
+ return backend.to_int_array(a)
477
483
 
478
- def eval_dyad_integer_divide(a, b):
484
+ def eval_dyad_integer_divide(a, b, backend):
479
485
  """
480
486
 
481
487
  a:%b [Integer-Divide]
@@ -491,7 +497,11 @@ def eval_dyad_integer_divide(a, b):
491
497
  10:%8 --> 1
492
498
 
493
499
  """
494
- return vec_fn2(a, b, _e_dyad_integer_divide)
500
+ if not is_list(a) and not is_list(b) and backend.is_number(b):
501
+ b_val = backend.scalar_to_python(b) if backend.is_backend_array(b) or (hasattr(b, 'ndim') and b.ndim == 0) else b
502
+ if b_val == 0:
503
+ return KLONG_UNDEFINED
504
+ return backend.vec_fn2(a, b, lambda x, y: _e_dyad_integer_divide(x, y, backend))
495
505
 
496
506
 
497
507
  def _arr_to_list(a):
@@ -499,7 +509,7 @@ def _arr_to_list(a):
499
509
  return a if is_list(a) else [a]# if not is_list(a) else a
500
510
 
501
511
 
502
- def eval_dyad_join(a, b):
512
+ def eval_dyad_join(a, b, backend):
503
513
  """
504
514
 
505
515
  a,b [Join]
@@ -550,7 +560,7 @@ def eval_dyad_join(a, b):
550
560
  b[a[0]] = a[1]
551
561
  return b
552
562
 
553
- if np.isarray(a) and np.isarray(b):
563
+ if bknp.isarray(a) and bknp.isarray(b):
554
564
  # Only use fast path for 1D+ arrays (not 0D scalars)
555
565
  a_is_1d_plus = hasattr(a, 'ndim') and a.ndim >= 1
556
566
  b_is_1d_plus = hasattr(b, 'ndim') and b.ndim >= 1
@@ -558,24 +568,24 @@ def eval_dyad_join(a, b):
558
568
  if len(a) == 0:
559
569
  return b
560
570
  if len(a.shape) == len(b.shape) and a.shape[-1] == b.shape[-1]:
561
- return np.concatenate((a,b))
571
+ return bknp.concatenate((a,b))
562
572
 
563
573
  aa = _arr_to_list(a)
564
574
  bb = _arr_to_list(b)
565
575
 
566
576
  r = [*aa,*bb]
567
- nr = kg_asarray(r)
568
- # Check dtype kind for compatibility with both numpy and torch
569
- dtype_kind = get_dtype_kind(nr)
577
+ nr = backend.kg_asarray(r)
578
+ # Check dtype kind for compatibility across backends
579
+ dtype_kind = backend.get_dtype_kind(nr)
570
580
  if dtype_kind in ('i', 'f', 'u'):
571
581
  return nr
572
- # Use numpy directly for object arrays (torch backend doesn't support object dtype)
573
- # Convert any torch tensors to numpy first (needed for MPS tensors)
574
- r_numpy = [to_numpy(x) if np.isarray(x) else x for x in r]
582
+ # Use numpy directly for object arrays (backends without object dtype need this)
583
+ # Convert backend arrays to numpy first (needed for device-backed arrays)
584
+ r_numpy = [backend.to_numpy(x) if backend.is_array(x) else x for x in r]
575
585
  return numpy.asarray(r_numpy, dtype=object)
576
586
 
577
587
 
578
- def eval_dyad_less(a, b):
588
+ def eval_dyad_less(a, b, backend):
579
589
  """
580
590
 
581
591
  a<b [Less]
@@ -596,10 +606,10 @@ def eval_dyad_less(a, b):
596
606
  [1 2 3]<[1 4 3] --> [0 1 0]
597
607
 
598
608
  """
599
- return kg_truth(vec_fn2(a, b, lambda x,y: x < y if (isinstance(x,str) and isinstance(y,str)) else np.less(x,y)))
609
+ return kg_truth(backend.vec_fn2(a, b, lambda x,y: x < y if (isinstance(x,str) and isinstance(y,str)) else backend.np.less(x,y)))
600
610
 
601
611
 
602
- def eval_dyad_match(a,b):
612
+ def eval_dyad_match(a, b, backend):
603
613
  """
604
614
 
605
615
  a~b [Match]
@@ -633,10 +643,10 @@ def eval_dyad_match(a,b):
633
643
  [1 [2] 3]~[1 [4] 3] --> 0
634
644
 
635
645
  """
636
- return kg_truth(kg_equal(a,b))
646
+ return kg_truth(backend.kg_equal(a, b))
637
647
 
638
648
 
639
- def eval_dyad_maximum(a, b):
649
+ def eval_dyad_maximum(a, b, backend):
640
650
  """
641
651
 
642
652
  a|b [Max/Or]
@@ -660,10 +670,10 @@ def eval_dyad_maximum(a, b):
660
670
  1.0|1.1 --> 1.1
661
671
 
662
672
  """
663
- return np.maximum(a, b)
673
+ return backend.np.maximum(a, b)
664
674
 
665
675
 
666
- def eval_dyad_minimum(a, b):
676
+ def eval_dyad_minimum(a, b, backend):
667
677
  """
668
678
 
669
679
  a&b [Min/And]
@@ -687,10 +697,10 @@ def eval_dyad_minimum(a, b):
687
697
  1.0&1.1 --> 1.0
688
698
 
689
699
  """
690
- return np.minimum(a, b)
700
+ return backend.np.minimum(a, b)
691
701
 
692
702
 
693
- def eval_dyad_more(a, b):
703
+ def eval_dyad_more(a, b, backend):
694
704
  """
695
705
 
696
706
  a>b [More]
@@ -709,10 +719,10 @@ def eval_dyad_more(a, b):
709
719
  [1 4 3]>[1 2 3] --> [0 1 0]
710
720
 
711
721
  """
712
- return kg_truth(vec_fn2(a, b, lambda x,y: x > y if (isinstance(x,str) and isinstance(y,str)) else np.greater(x,y)))
722
+ return kg_truth(backend.vec_fn2(a, b, lambda x,y: x > y if (isinstance(x,str) and isinstance(y,str)) else backend.np.greater(x,y)))
713
723
 
714
724
 
715
- def eval_dyad_multiply(a, b):
725
+ def eval_dyad_multiply(a, b, backend):
716
726
  """
717
727
 
718
728
  a*b [Times]
@@ -726,19 +736,19 @@ def eval_dyad_multiply(a, b):
726
736
  0.3*7 --> 2.1
727
737
 
728
738
  """
729
- return np.multiply(a, b)
739
+ return backend.np.multiply(a, b)
730
740
 
731
741
 
732
- def _e_dyad_power(a, b):
742
+ def _e_dyad_power(a, b, backend):
733
743
  # Check if input requires grad - if so, preserve float for autograd
734
- input_has_grad = has_gradient(a)
735
- # Use backend power function which handles torch.pow for gradients
736
- r = backend_power(a, b)
744
+ input_has_grad = backend.has_gradient(a)
745
+ # Use backend power function which handles gradient-aware power
746
+ r = backend.power(a, b)
737
747
  # If input had gradients, keep result as float to preserve autograd
738
748
  if input_has_grad:
739
749
  return r
740
750
  # Check if result is integer using vectorized operations
741
- r_val = detach_if_needed(r)
751
+ r_val = backend.detach_if_needed(r)
742
752
  if is_list(r_val):
743
753
  # Vectorized check: trunc(r) == r for all elements
744
754
  trunc_r = numpy.trunc(r_val) if isinstance(r_val, numpy.ndarray) else r_val.trunc()
@@ -747,10 +757,10 @@ def _e_dyad_power(a, b):
747
757
  val = float(r_val) if hasattr(r_val, 'item') else r_val
748
758
  br = numpy.trunc(val) == val
749
759
  if br:
750
- return to_int_array(r)
760
+ return backend.to_int_array(r)
751
761
  return r
752
762
 
753
- def eval_dyad_power(a, b):
763
+ def eval_dyad_power(a, b, backend):
754
764
  """
755
765
 
756
766
  a^b [Power]
@@ -769,10 +779,10 @@ def eval_dyad_power(a, b):
769
779
  2^0.5 --> 1.41421356237309504
770
780
 
771
781
  """
772
- return vec_fn2(a, b, _e_dyad_power)
782
+ return backend.vec_fn2(a, b, lambda x, y: _e_dyad_power(x, y, backend))
773
783
 
774
784
 
775
- def eval_dyad_remainder(a, b):
785
+ def eval_dyad_remainder(a, b, backend):
776
786
  """
777
787
 
778
788
  a!b [Remainder]
@@ -790,10 +800,10 @@ def eval_dyad_remainder(a, b):
790
800
  -7!-5 --> -2
791
801
 
792
802
  """
793
- return np.fmod(a, b)
803
+ return backend.np.fmod(a, b)
794
804
 
795
805
 
796
- def eval_dyad_reshape(a, b):
806
+ def eval_dyad_reshape(a, b, backend):
797
807
  """
798
808
 
799
809
  a:^b [Reshape]
@@ -840,51 +850,52 @@ def eval_dyad_reshape(a, b):
840
850
  [2]:^[[1 2 3]] --> [[1 2 3] [1 2 3]]
841
851
 
842
852
  """
853
+ np_backend = backend.np
843
854
  j = isinstance(b, str)
844
- b = str_to_chr_arr(b) if j else b
845
- if np.isarray(a):
846
- if np.isarray(b):
847
- y = np.where(a < 0)[0]
855
+ b = backend.str_to_chr_arr(b) if j else b
856
+ if np_backend.isarray(a):
857
+ if np_backend.isarray(b):
858
+ y = np_backend.where(a < 0)[0]
848
859
  if len(y) > 0:
849
- a = np.copy(a)
850
- a[y] = array_size(b) // 2
851
- b_s = array_size(b)
852
- a_s = int(np.prod(a)) # Ensure it's a Python int for comparison
853
- # Convert shape to tuple of ints for torch compatibility
860
+ a = np_backend.copy(a)
861
+ a[y] = backend.array_size(b) // 2
862
+ b_s = backend.array_size(b)
863
+ a_s = int(np_backend.prod(a)) # Ensure it's a Python int for comparison
864
+ # Convert shape to tuple of ints for backend compatibility
854
865
  a_shape = tuple(int(x) for x in (a.tolist() if hasattr(a, 'tolist') else a))
855
866
  if a_s > b_s:
856
- b = np.tile(b.flatten(), (a_s // b_s))
857
- b = np.concatenate((b, b[:a_s - array_size(b)]))
858
- b_s = array_size(b)
867
+ b = np_backend.tile(b.flatten(), (a_s // b_s))
868
+ b = np_backend.concatenate((b, b[:a_s - backend.array_size(b)]))
869
+ b_s = backend.array_size(b)
859
870
  r = b.reshape(a_shape)
860
- r = np.asarray(["".join(x) for x in r]) if j else r
871
+ r = np_backend.asarray(["".join(x) for x in r]) if j else r
861
872
  j = False
862
873
  elif a_s == b_s:
863
874
  r = b.reshape(a_shape)
864
875
  else:
865
- r = np.resize(b, a_shape)
876
+ r = np_backend.resize(b, a_shape)
866
877
  else:
867
- r = np.full(a, b)
878
+ r = np_backend.full(a, b)
868
879
  else:
869
880
  if a == 0:
870
881
  r = b
871
- elif np.isarray(b):
882
+ elif np_backend.isarray(b):
872
883
  if a < b.shape[0]:
873
- r = np.resize(b, (a,))
884
+ r = np_backend.resize(b, (a,))
874
885
  else:
875
- ns = np.ones(len(b.shape),dtype=int)
886
+ ns = np_backend.ones(len(b.shape),dtype=int)
876
887
  ns[0] = a // b.shape[0]
877
- r = np.concatenate((np.tile(b,ns), b[:a - b.shape[0]*ns[0]]))
888
+ r = np_backend.concatenate((np_backend.tile(b,ns), b[:a - b.shape[0]*ns[0]]))
878
889
  else:
879
- r = np.full((a,), b)
890
+ r = np_backend.full((a,), b)
880
891
  if j:
881
- if np.isarray(r) and r.ndim > 1:
882
- return np.asarray(["".join(x) for x in r], dtype=object)
892
+ if np_backend.isarray(r) and r.ndim > 1:
893
+ return np_backend.asarray(["".join(x) for x in r], dtype=object)
883
894
  return "".join(r)
884
895
  return r
885
896
 
886
897
 
887
- def eval_dyad_rotate(a, b):
898
+ def eval_dyad_rotate(a, b, backend):
888
899
  """
889
900
 
890
901
  a:+b [Rotate]
@@ -911,12 +922,12 @@ def eval_dyad_rotate(a, b):
911
922
  if a == 0 or not is_iterable(b):
912
923
  return b
913
924
  j = isinstance(b, str)
914
- b = str_to_chr_arr(b) if j else b
915
- r = np.roll(b, a)
925
+ b = backend.str_to_chr_arr(b) if j else b
926
+ r = bknp.roll(b, a)
916
927
  return "".join(r) if j else r
917
928
 
918
929
 
919
- def eval_dyad_split(a, b):
930
+ def eval_dyad_split(a, b, backend):
920
931
  """
921
932
 
922
933
  a:#b [Split]
@@ -935,12 +946,12 @@ def eval_dyad_split(a, b):
935
946
 
936
947
  """
937
948
  if len(b) == 0:
938
- return np.asarray([])
949
+ return bknp.asarray([])
939
950
 
940
951
  j = isinstance(b, str)
941
- b = str_to_chr_arr(b) if j else b
952
+ b = backend.str_to_chr_arr(b) if j else b
942
953
 
943
- a = a if np.isarray(a) else [a]
954
+ a = a if bknp.isarray(a) else [a]
944
955
  if len(a) == 1:
945
956
  if a[0] >= len(b):
946
957
  r = [b]
@@ -948,7 +959,7 @@ def eval_dyad_split(a, b):
948
959
  k = len(b) // a[0]
949
960
  if (k*a[0]) < len(b):
950
961
  k += 1
951
- r = np.array_split(b, k)
962
+ r = bknp.array_split(b, k)
952
963
  else:
953
964
  p, q = 0, 0
954
965
  r = []
@@ -959,10 +970,10 @@ def eval_dyad_split(a, b):
959
970
  if p >= len(a):
960
971
  p = 0
961
972
 
962
- return np.asarray(["".join(x) for x in r],dtype=object) if j else kg_asarray(r)
973
+ return bknp.asarray(["".join(x) for x in r],dtype=object) if j else backend.kg_asarray(r)
963
974
 
964
975
 
965
- def eval_dyad_subtract(a, b):
976
+ def eval_dyad_subtract(a, b, backend):
966
977
  """
967
978
 
968
979
  a-b [Minus]
@@ -977,10 +988,10 @@ def eval_dyad_subtract(a, b):
977
988
  1-0.3 --> 0.7
978
989
 
979
990
  """
980
- return np.subtract(a, b)
991
+ return backend.np.subtract(a, b)
981
992
 
982
993
 
983
- def eval_dyad_take(a, b):
994
+ def eval_dyad_take(a, b, backend):
984
995
  """
985
996
 
986
997
  a#b [Take]
@@ -1003,16 +1014,18 @@ def eval_dyad_take(a, b):
1003
1014
  0#"" --> ""
1004
1015
 
1005
1016
  """
1017
+ np_backend = backend.np
1006
1018
  j = isinstance(b,str)
1007
- b = str_to_chr_arr(b) if j else np.asarray(b)
1008
- aa = int(np.abs(a)) if hasattr(np.abs(a), 'item') else np.abs(a) # Convert tensor to int
1009
- b_size = array_size(b)
1019
+ b = backend.str_to_chr_arr(b) if j else np_backend.asarray(b)
1020
+ abs_a = np_backend.abs(a)
1021
+ aa = int(abs_a) if hasattr(abs_a, 'item') else abs_a # Convert tensor to int
1022
+ b_size = backend.array_size(b)
1010
1023
  if b_size == 0:
1011
1024
  # Handle empty array/string case
1012
1025
  r = b
1013
1026
  elif aa > b_size:
1014
- b = np.tile(b, aa // len(b))
1015
- b = np.concatenate((b, b[:aa-array_size(b)]) if a > 0 else (b[-(aa-array_size(b)):], b))
1027
+ b = np_backend.tile(b, aa // len(b))
1028
+ b = np_backend.concatenate((b, b[:aa-backend.array_size(b)]) if a > 0 else (b[-(aa-backend.array_size(b)):], b))
1016
1029
  r = b[a:] if a < 0 else b[:a]
1017
1030
  else:
1018
1031
  r = b[a:] if a < 0 else b[:a]
@@ -1106,21 +1119,49 @@ def eval_dyad_autograd(klong, a, b):
1106
1119
 
1107
1120
 
1108
1121
  def create_dyad_functions(klong):
1109
- def _get_name(s):
1110
- s = s.strip()
1111
- i = s.index("a")
1112
- return s[i+1:i+s.index('b')]
1113
-
1114
- registry = {}
1115
-
1116
- m = sys.modules[__name__]
1117
- for name in filter(lambda n: n.startswith("eval_dyad_"), dir(m)):
1118
- fn = getattr(m,name)
1119
- name = _get_name(fn.__doc__)
1120
- if fn.__code__.co_argcount == 3:
1121
- fn = lambda x,y,f=fn,klong=klong: f(klong, x, y)
1122
- elif fn.__code__.co_argcount == 2 and 'klong' in fn.__code__.co_varnames:
1123
- fn = lambda x,f=fn,klong=klong: f(klong, x)
1124
- registry[name] = fn
1125
-
1126
- return registry
1122
+ backend = klong._backend
1123
+
1124
+ # Simple dyads that don't need backend or klong
1125
+ simple = {
1126
+ ':-': eval_dyad_amend_in_depth,
1127
+ '_': eval_dyad_drop,
1128
+ ':@': eval_dyad_index_in_depth,
1129
+ }
1130
+
1131
+ # Dyads needing backend
1132
+ backend_dyads = {
1133
+ '+': lambda a, b: eval_dyad_add(a, b, backend),
1134
+ '|': lambda a, b: eval_dyad_maximum(a, b, backend),
1135
+ '&': lambda a, b: eval_dyad_minimum(a, b, backend),
1136
+ '!': lambda a, b: eval_dyad_remainder(a, b, backend),
1137
+ '%': lambda a, b: eval_dyad_divide(a, b, backend),
1138
+ '*': lambda a, b: eval_dyad_multiply(a, b, backend),
1139
+ '-': lambda a, b: eval_dyad_subtract(a, b, backend),
1140
+ ':=': lambda a, b: eval_dyad_amend(a, b, backend),
1141
+ ':_': lambda a, b: eval_dyad_cut(a, b, backend),
1142
+ '=': lambda a, b: eval_dyad_equal(a, b, backend),
1143
+ '?': lambda a, b: eval_dyad_find(a, b, backend),
1144
+ ':$': lambda a, b: eval_dyad_form(a, b, backend),
1145
+ '$': lambda a, b: eval_dyad_format2(a, b, backend),
1146
+ ':%': lambda a, b: eval_dyad_integer_divide(a, b, backend),
1147
+ ',': lambda a, b: eval_dyad_join(a, b, backend),
1148
+ '<': lambda a, b: eval_dyad_less(a, b, backend),
1149
+ '~': lambda a, b: eval_dyad_match(a, b, backend),
1150
+ '>': lambda a, b: eval_dyad_more(a, b, backend),
1151
+ '^': lambda a, b: eval_dyad_power(a, b, backend),
1152
+ ':^': lambda a, b: eval_dyad_reshape(a, b, backend),
1153
+ ':+': lambda a, b: eval_dyad_rotate(a, b, backend),
1154
+ ':#': lambda a, b: eval_dyad_split(a, b, backend),
1155
+ '#': lambda a, b: eval_dyad_take(a, b, backend),
1156
+ }
1157
+
1158
+ # Dyads needing klong
1159
+ klong_dyads = {
1160
+ '@': lambda a, b: eval_dyad_at_index(klong, a, b),
1161
+ '::': lambda a, b: eval_dyad_define(klong, a, b),
1162
+ '∇': lambda a, b: eval_dyad_grad(klong, a, b),
1163
+ '∂': lambda a, b: eval_dyad_jacobian(klong, a, b),
1164
+ ':>': lambda a, b: eval_dyad_autograd(klong, a, b),
1165
+ }
1166
+
1167
+ return {**simple, **backend_dyads, **klong_dyads}