klongpy 0.7.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
klongpy/backends/base.py CHANGED
@@ -5,6 +5,7 @@ All backends must implement the BackendProvider interface to ensure
 consistent behavior across numpy, torch, and any future backends.
 """
 from abc import ABC, abstractmethod
+import numpy as np
 
 
 def is_jagged_array(x):
@@ -142,14 +143,12 @@ class BackendProvider(ABC):
 
     def is_integer(self, x) -> bool:
         """Check if x is an integer type (scalar, numpy integer, or 0-dim integer tensor)."""
-        import numpy as np
         if issubclass(type(x), (int, np.integer)):
             return True
         return self.is_scalar_integer(x)
 
     def is_float(self, x) -> bool:
         """Check if x is a float type (scalar, numpy float, int, or 0-dim float tensor)."""
-        import numpy as np
         if issubclass(type(x), (float, np.floating, int)):
             return True
         return self.is_scalar_float(x) or self.is_scalar_integer(x)
@@ -180,7 +179,6 @@ class BackendProvider(ABC):
 
         Returns a truth value (0 or 1) suitable for Klong.
         """
-        import numpy as np
         return np.asarray(x, dtype=object) == np.asarray(y, dtype=object)
 
     def detach_if_needed(self, x):
@@ -195,16 +193,21 @@ class BackendProvider(ABC):
         """
         Convert array to integer type.
         """
-        import numpy as np
         return np.asarray(a, dtype=int) if self.is_array(a) else int(a)
 
+    def floor_to_int(self, a):
+        """
+        Floor a value and convert to integer.
+        """
+        result = np.floor(np.asarray(a, dtype=float))
+        return result.astype(int) if hasattr(result, 'astype') else int(result)
+
     def power(self, a, b):
         """
         Compute a^b, handling gradient tracking if applicable.
 
         Returns integer result if the result is a whole number.
         """
-        import numpy as np
         r = np.power(float(a) if isinstance(a, (int, np.integer)) else a, b)
         return r
 
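
The new floor_to_int floors before casting, so negative non-integers round toward minus infinity rather than toward zero. A minimal usage sketch, assuming the default implementation is exercised through the numpy provider defined later in this diff:

    from klongpy.backends.numpy_backend import NumpyBackendProvider

    backend = NumpyBackendProvider()
    backend.floor_to_int(3.7)           # 3
    backend.floor_to_int([1.5, -2.5])   # array([ 1, -3]) -- floor, not truncation
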
@@ -216,6 +219,10 @@ class BackendProvider(ABC):
         """Whether this backend supports automatic differentiation."""
         return False
 
+    def array_equal(self, a, b) -> bool:
+        """Backend-native exact equality for arrays/tensors."""
+        return bool(np.array_equal(a, b))
+
     def create_grad_tensor(self, x):
         """Create a tensor that tracks gradients. Raises if not supported."""
         raise NotImplementedError("This backend does not support autograd")
@@ -315,6 +322,148 @@ class BackendProvider(ABC):
         )
 
 
+    def kg_equal(self, a, b):
+        """Compare two values or arrays for equality, handling nested arrays and tensors."""
+        if a is b:
+            return True
+
+        # Backend-native comparison for backend arrays
+        if self.is_backend_array(a) and self.is_backend_array(b):
+            return self.array_equal(a, b)
+
+        # Fast path for numpy arrays (non-object)
+        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
+            if a.dtype != object and b.dtype != object:
+                return bool(np.array_equal(a, b))
+
+        # Convert backend arrays to numpy for mixed comparisons
+        if self.is_backend_array(a):
+            a = self.to_numpy(a)
+        if self.is_backend_array(b):
+            b = self.to_numpy(b)
+
+        # Fast path for numpy arrays (after any conversion)
+        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
+            if a.dtype != object and b.dtype != object:
+                return bool(np.array_equal(a, b))
+
+        # Normalize 0-d numpy arrays to scalars for mixed comparisons
+        if isinstance(a, np.ndarray) and a.ndim == 0:
+            a = a.item()
+        if isinstance(b, np.ndarray) and b.ndim == 0:
+            b = b.item()
+
+        # List/sequence comparison
+        a_is_seq = isinstance(a, (list, tuple)) or (isinstance(a, np.ndarray) and a.ndim > 0)
+        b_is_seq = isinstance(b, (list, tuple)) or (isinstance(b, np.ndarray) and b.ndim > 0)
+        if a_is_seq or b_is_seq:
+            if not (a_is_seq and b_is_seq):
+                return False
+            if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
+                def _is_int_scalar(x):
+                    return isinstance(x, (int, bool, np.integer))
+                if len(a) == len(b) and len(a) >= 32 and all(_is_int_scalar(x) for x in a) and all(_is_int_scalar(y) for y in b):
+                    return a == b
+            # Fast path for object numpy arrays when possible
+            if isinstance(a, np.ndarray) and isinstance(b, np.ndarray) and a.dtype == object and b.dtype == object:
+                if a.size >= 128:
+                    try:
+                        return bool(np.array_equal(a, b))
+                    except Exception:
+                        pass
+            if len(a) != len(b):
+                return False
+            return all(self.kg_equal(x, y) for x, y in zip(a, b))
+
+        # Numeric scalars: tolerant comparison
+        if self.is_number(a) and self.is_number(b):
+            result = np.isclose(a, b)
+            if hasattr(result, 'item'):
+                return bool(result.item())
+            return bool(result)
+
+        # Fallback: direct equality
+        result = a == b
+        if hasattr(result, 'all'):
+            return bool(result.all())
+        if hasattr(result, 'item'):
+            return bool(result.item())
+        return bool(result)
+
+    def vec_fn(self, a, f):
+        """
+        Apply function f to array a, with support for nested object arrays.
+        """
+        if self.np.isarray(a) and a.dtype == 'O':
+            result = [self.vec_fn(x, f) if self._is_list(x) else f(x) for x in a]
+            return np.asarray(result, dtype=object)
+        return f(a)
+
+    def vec_fn2(self, a, b, f):
+        """
+        Apply function f to elements of a and b, handling nested structures.
+        """
+        if self.np.isarray(a):
+            if a.dtype == 'O':
+                if self.np.isarray(b):
+                    assert len(a) == len(b)
+                    return self.kg_asarray([self.vec_fn2(x, y, f) for x, y in zip(a, b)])
+                else:
+                    return self.kg_asarray([self.vec_fn2(x, b, f) for x in a])
+            elif self.np.isarray(b) and b.dtype == 'O':
+                assert len(a) == len(b)
+                return self.kg_asarray([self.vec_fn2(x, y, f) for x, y in zip(a, b)])
+        elif self.np.isarray(b) and b.dtype == 'O':
+            return self.kg_asarray([self.vec_fn2(a, x, f) for x in b])
+        return f(a, b)
+
+    def rec_fn(self, a, f):
+        """
+        Recursively apply function f to all elements of a nested structure.
+        """
+        return self.kg_asarray([self.rec_fn(x, f) for x in a]) if self._is_list(a) else f(a)
+
+    def _is_list(self, x):
+        """Check if x is a list-like structure (array or list, non-empty)."""
+        if isinstance(x, np.ndarray):
+            return x.size > 0
+        if isinstance(x, (list, tuple)):
+            return len(x) > 0
+        return False
+
+    @property
+    def device(self):
+        """Return the current device for this backend (e.g., 'cpu', 'cuda:0', 'mps')."""
+        return 'cpu'
+
+    def list_devices(self):
+        """
+        List available devices for this backend.
+
+        Returns:
+            list: List of available device names (e.g., ['cpu'], ['cpu', 'cuda:0', 'mps'])
+        """
+        return ['cpu']
+
+    def get_info(self):
+        """
+        Get comprehensive information about this backend.
+
+        Returns:
+            dict: Dictionary with backend name, current device, available devices,
+                and feature support flags.
+        """
+        return {
+            'name': self.name,
+            'device': self.device,
+            'devices': self.list_devices(),
+            'supports_float64': self.supports_float64(),
+            'supports_strings': self.supports_strings(),
+            'supports_object_dtype': self.supports_object_dtype(),
+            'supports_autograd': self.supports_autograd(),
+        }
+
+
 class UnsupportedDtypeError(Exception):
     """Raised when an operation requires a dtype not supported by the backend."""
     pass
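
kg_equal centralizes Klong equality: exact backend-native comparison for arrays, recursive descent for nested or object-dtype structures, and np.isclose tolerance for numeric scalars. A rough sketch of the resulting semantics, assuming the numpy provider supplies the abstract hooks (is_backend_array, to_numpy, is_number) referenced above:

    import numpy as np
    from klongpy.backends.numpy_backend import NumpyBackendProvider

    backend = NumpyBackendProvider()
    backend.kg_equal(0.1 + 0.2, 0.3)             # True  -- np.isclose absorbs float rounding
    backend.kg_equal([1, [2, 3]], [1, [2, 3]])   # True  -- recursive element-wise descent
    backend.kg_equal([1, 2], [1, 2, 3])          # False -- length mismatch short-circuits
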
klongpy/backends/numpy_backend.py CHANGED
@@ -22,7 +22,8 @@ class NumpyBackendProvider(BackendProvider):
     """NumPy-based backend provider."""
 
     def __init__(self, device=None):
-        # device parameter is ignored for numpy backend (accepted for API consistency)
+        if device is not None:
+            raise ValueError("Backend 'numpy' does not support device selection")
         self._np = np
         np.seterr(divide='ignore')
         warnings.filterwarnings("error", category=NumpyVisibleDeprecationWarning)
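
Passing a device to the numpy backend is now a hard error rather than being silently ignored; a quick sketch:

    from klongpy.backends.numpy_backend import NumpyBackendProvider

    NumpyBackendProvider()                # fine
    NumpyBackendProvider(device='cuda')   # ValueError: Backend 'numpy' does not support device selection
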
klongpy/backends/registry.py ADDED
@@ -0,0 +1,76 @@
+"""
+Backend registry for KlongPy.
+
+Owns backend registration, lookup, and lazy torch loading.
+"""
+import importlib
+import importlib.util
+
+from .base import BackendProvider
+from .numpy_backend import NumpyBackendProvider
+
+# Registry of available backends
+_BACKENDS = {}
+
+# Default backend name
+_DEFAULT_BACKEND = 'numpy'
+
+_TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None
+_TORCH_BACKEND_LOADED = False
+TorchBackendProvider = None
+
+
+def register_backend(name: str, provider_class):
+    """Register a backend provider class."""
+    _BACKENDS[name] = provider_class
+
+
+def _load_torch_backend():
+    global _TORCH_BACKEND_LOADED, TorchBackendProvider
+    if _TORCH_BACKEND_LOADED or not _TORCH_AVAILABLE:
+        return
+    _torch_backend = importlib.import_module("klongpy.backends.torch_backend")
+    TorchBackendProvider = _torch_backend.TorchBackendProvider
+    register_backend('torch', TorchBackendProvider)
+    _TORCH_BACKEND_LOADED = True
+
+
+def get_backend(name: str = None, **kwargs) -> BackendProvider:
+    """
+    Get a backend provider instance.
+
+    Parameters
+    ----------
+    name : str, optional
+        Backend name ('numpy' or 'torch'). If None, uses default.
+    **kwargs
+        Additional arguments passed to the backend provider constructor.
+
+    Returns
+    -------
+    BackendProvider
+        The backend provider instance.
+    """
+    if name is None:
+        name = _DEFAULT_BACKEND
+
+    if name == 'torch':
+        _load_torch_backend()
+
+    if name not in _BACKENDS:
+        available = ', '.join(_BACKENDS.keys())
+        raise ValueError(f"Unknown backend: '{name}'. Available: {available}")
+
+    return _BACKENDS[name](**kwargs)
+
+
+def list_backends():
+    """Return list of available backend names."""
+    backends = list(_BACKENDS.keys())
+    if _TORCH_AVAILABLE and 'torch' not in backends:
+        backends.append('torch')
+    return backends
+
+
+# Register built-in backends
+register_backend('numpy', NumpyBackendProvider)
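
The registry keeps torch optional and lazy: list_backends() only probes importlib.util.find_spec, while get_backend('torch') performs the actual module import on first use. A usage sketch, assuming the new module is importable as klongpy.backends.registry (the path is inferred; this export does not name the file):

    from klongpy.backends.registry import get_backend, list_backends

    list_backends()                               # ['numpy'] or ['numpy', 'torch'] if torch is installed
    backend = get_backend()                       # default backend: numpy
    backend = get_backend('torch', device='cpu')  # imports klongpy.backends.torch_backend on demand
    get_backend('jax')                            # ValueError: Unknown backend: 'jax'. Available: ...
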
klongpy/backends/torch_backend.py CHANGED
@@ -7,8 +7,10 @@ It does not support object dtype or string operations.
 import math
 import numpy
 import torch
+import torch.autograd.functional as torch_autograd_functional
 
 from .base import BackendProvider, UnsupportedDtypeError, is_jagged_array
+from ..autograd import AutogradChainBrokenError, NonScalarLossError, _invoke_fn
 
 # numpy 2.x moved VisibleDeprecationWarning to numpy.exceptions
 from numpy.exceptions import VisibleDeprecationWarning as NumpyVisibleDeprecationWarning
@@ -208,26 +210,26 @@ class TorchBackend:
             return torch.from_numpy(a).to(self.device)
         # Check if input is a list/tuple of tensors - use stack to preserve gradients
        if isinstance(a, (list, tuple)) and len(a) > 0 and all(isinstance(x, torch.Tensor) for x in a):
-            # torch.stack preserves requires_grad, torch.tensor does not
            result = torch.stack(a)
            if result.device != self.device:
                result = result.to(self.device)
-            # Handle float64 on MPS
            if result.dtype == torch.float64 and self.device.type == 'mps':
                result = result.to(torch.float32)
            return result
-        # Check if input contains any arrays/tensors mixed with non-arrays - these need object dtype
-        if isinstance(a, (list, tuple)) and len(a) > 0:
-            has_array = any(isinstance(x, (torch.Tensor, numpy.ndarray, list)) for x in a)
-            has_scalar = any(isinstance(x, (int, float)) and not isinstance(x, bool) for x in a)
-            if has_array and has_scalar:
-                # Mixed array/scalar list - can't represent in torch without losing structure
+        # For all other lists/tuples, convert via numpy (faster than torch.tensor for nested/mixed data)
+        if isinstance(a, (list, tuple)):
+            arr = numpy.asarray(a)
+            if arr.dtype == object:
                raise TorchUnsupportedDtypeError(
-                    "PyTorch backend cannot convert mixed array/scalar lists without losing structure."
+                    "PyTorch backend does not support object dtype arrays."
                )
+            # Convert float64 to float32 to match torch.tensor's default behavior
+            if arr.dtype == numpy.float64:
+                arr = arr.astype(numpy.float32)
+            return torch.from_numpy(arr).to(self.device)
+        # Scalar or other type
        try:
            t = torch.tensor(a, device=self.device)
-            # Handle float64 on MPS
            if t.dtype == torch.float64 and self.device.type == 'mps':
                t = t.to(torch.float32)
            return t
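
The list/tuple path now round-trips through numpy instead of torch.tensor, which is faster for nested data and also centralizes dtype policy; the float64-to-float32 downcast mirrors torch.tensor's default for Python floats. The core of the new path as a standalone sketch:

    import numpy
    import torch

    arr = numpy.asarray([[1.0, 2.0], [3.0, 4.0]])   # numpy defaults to float64
    if arr.dtype == numpy.float64:
        arr = arr.astype(numpy.float32)             # match torch.tensor's float32 default
    t = torch.from_numpy(arr)                       # tensor([[1., 2.], [3., 4.]])
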
@@ -509,6 +511,21 @@ class TorchBackendProvider(BackendProvider):
     """PyTorch-based backend provider."""
 
     def __init__(self, device=None):
+        if device is not None:
+            try:
+                torch_device = torch.device(device)
+            except Exception as exc:
+                raise ValueError(f"Invalid torch device '{device}': {exc}")
+            if torch_device.type == 'cuda':
+                if not torch.cuda.is_available():
+                    raise ValueError(f"Torch device '{device}' is not available (cuda not available)")
+                if torch_device.index is not None and torch_device.index >= torch.cuda.device_count():
+                    raise ValueError(f"Torch device '{device}' is not available (device index out of range)")
+            if torch_device.type == 'mps':
+                if not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
+                    raise ValueError(f"Torch device '{device}' is not available (mps not available)")
+            if torch_device.type not in {'cpu', 'cuda', 'mps'}:
+                raise ValueError(f"Torch device type '{torch_device.type}' is not supported")
         self._torch_backend = TorchBackend(device)
         self._device = device
@@ -524,6 +541,17 @@
     def device(self):
         return self._torch_backend.device
 
+    def list_devices(self):
+        """List available torch devices (cpu, cuda, mps)."""
+        devices = ['cpu']
+        if torch.cuda.is_available():
+            devices.append('cuda')
+            for i in range(torch.cuda.device_count()):
+                devices.append(f'cuda:{i}')
+        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+            devices.append('mps')
+        return devices
+
     def supports_object_dtype(self) -> bool:
         return False
 
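
Device problems now surface at construction time instead of at the first tensor operation. A sketch of the behavior on a machine without CUDA:

    from klongpy.backends.torch_backend import TorchBackendProvider

    TorchBackendProvider(device='cpu').list_devices()   # ['cpu'] (plus 'cuda:N' and/or 'mps' where present)
    TorchBackendProvider(device='cuda')                 # ValueError: ... (cuda not available)
    TorchBackendProvider(device='xla')                  # ValueError: device type 'xla' is not supported
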
@@ -608,6 +636,16 @@
         # Default numpy comparison
         return numpy.asarray(x, dtype=object) == numpy.asarray(y, dtype=object)
 
+    def array_equal(self, a, b) -> bool:
+        """Backend-native exact equality for torch tensors."""
+        if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor):
+            return False
+        try:
+            return bool(torch.equal(a, b))
+        except RuntimeError:
+            # Fall back to CPU comparison if devices mismatch
+            return bool(torch.equal(a.cpu(), b.cpu()))
+
     def detach_if_needed(self, x):
         """Detach tensor if it requires grad, to allow type conversions."""
         if isinstance(x, torch.Tensor) and x.requires_grad:
@@ -620,14 +658,28 @@
             return a.to(int)
         return numpy.asarray(a, dtype=int) if isinstance(a, numpy.ndarray) else int(a)
 
+    def floor_to_int(self, a):
+        """Floor a value and convert to integer."""
+        if not isinstance(a, torch.Tensor):
+            a = self.kg_asarray(a)
+        return torch.floor(a.float()).to(int)
+
     def power(self, a, b):
         """Compute a^b, handling gradient tracking for torch tensors."""
-        # Use torch.pow for tensors to maintain gradients when possible
         if isinstance(a, torch.Tensor):
+            # Handle negative exponents - torch doesn't support int^negative
+            if isinstance(b, torch.Tensor) and b.dtype in (torch.int8, torch.int16, torch.int32, torch.int64) and (b < 0).any():
+                base = a.float() if a.dtype in (torch.int8, torch.int16, torch.int32, torch.int64) else a
+                result = base.pow(b.abs()).float()
+                return torch.where(b < 0, 1.0 / result, result)
+            b_val = b.item() if isinstance(b, torch.Tensor) and b.ndim == 0 else b
+            if isinstance(b_val, (int, numpy.integer)) and b_val < 0:
+                a = a.float()
             return a.pow(b)
         # For numpy arrays or scalars
         a_val = float(a) if isinstance(a, (int, numpy.integer)) else a
-        return numpy.power(a_val, b)
+        b_val = b.item() if isinstance(b, torch.Tensor) and b.ndim == 0 else (b.cpu().numpy() if isinstance(b, torch.Tensor) else b)
+        return numpy.power(a_val, b_val)
 
     def has_gradient(self, x) -> bool:
         """Check if x is tracking gradients."""
@@ -648,8 +700,6 @@
 
     def compute_autograd(self, func, x):
         """Compute gradient using PyTorch automatic differentiation."""
-        from ..autograd import AutogradChainBrokenError, NonScalarLossError
-
         x_tensor = self.create_grad_tensor(x)
 
         # Compute the function value
@@ -700,8 +750,6 @@
         Returns:
             List of gradients, one per parameter
         """
-        from ..autograd import AutogradChainBrokenError, NonScalarLossError
-
         # Create grad tensors for all parameters
         grad_tensors = [self.create_grad_tensor(p) for p in params]
@@ -744,12 +792,10 @@
         Returns:
             Jacobian matrix J where J[i,j] = df_i/dx_j
         """
-        import torch.autograd.functional as F
-
         x_tensor = self.create_grad_tensor(x)
 
         # torch.autograd.functional.jacobian expects func(inputs) -> outputs
-        jacobian = F.jacobian(func, x_tensor)
+        jacobian = torch_autograd_functional.jacobian(func, x_tensor)
 
         return jacobian
 
@@ -812,7 +858,14 @@
         _ = compiled_fn(example_input)
 
         if output_path is None:
-            return compiled_fn
+            # Wrap with Klong-convention parameter name (x) so that
+            # KGLambda introspection binds the argument correctly when
+            # the compiled function is stored via ::
+            def klong_compiled(x):
+                if not isinstance(x, torch.Tensor):
+                    x = self.create_grad_tensor(x)
+                return compiled_fn(x)
+            return klong_compiled
 
         # Export the function graph for inspection
         try:
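
The wrapper exists because KGLambda derives Klong arity from Python parameter names (the x/y/z convention), and a torch-compiled callable does not expose such a signature. A rough sketch of the introspection being satisfied, with klong_arity as a hypothetical stand-in for the real KGLambda logic:

    import inspect

    def klong_arity(fn):
        # hypothetical: KGLambda binds positional args named x, y, z
        return [p for p in inspect.signature(fn).parameters if p in ('x', 'y', 'z')]

    def klong_compiled(x):   # the wrapper shape added above
        ...

    klong_arity(klong_compiled)   # ['x'] -> treated as a monadic Klong function
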
@@ -920,11 +973,14 @@
         Returns:
             1 if gradients are correct, raises error otherwise
         """
-        from ..autograd import _invoke_fn
+        # Gradcheck requires float64 which is only supported on CPU
+        if self.device.type != 'cpu':
+            raise RuntimeError(
+                f".gradcheck() requires CPU device, got '{self.device.type}'. "
+                "Run with: kgpy --backend torch --device cpu"
+            )
 
-        # Determine dtype based on device support
-        use_float32 = self.device.type == 'mps'  # MPS doesn't support float64
-        dtype = torch.float32 if use_float32 else torch.float64
+        dtype = torch.float64
 
         # Wrap the Klong function
         def wrapped_fn(v):
@@ -934,19 +990,15 @@
             result = result.sum()
             return result
 
-        # Convert inputs to tensor on CPU for gradcheck (avoids MPS float64 issues)
+        # Convert inputs to tensor on CPU for gradcheck
         if isinstance(inputs, (list, tuple)) and not isinstance(inputs[0], torch.Tensor):
             tensor_inputs = torch.tensor(inputs, dtype=dtype, device='cpu', requires_grad=True)
         elif not isinstance(inputs, torch.Tensor):
             tensor_inputs = torch.tensor([inputs], dtype=dtype, device='cpu', requires_grad=True)
         else:
-            tensor_inputs = inputs.to(dtype=dtype, device='cpu').requires_grad_(True)
+            tensor_inputs = inputs.detach().cpu().to(dtype=dtype).requires_grad_(True)
 
-        # Run gradcheck with adjusted tolerances for float32
-        if use_float32:
-            result = self.gradcheck(wrapped_fn, (tensor_inputs,), eps=1e-4, atol=1e-3, rtol=1e-2)
-        else:
-            result = self.gradcheck(wrapped_fn, (tensor_inputs,))
+        result = self.gradcheck(wrapped_fn, (tensor_inputs,))
 
         return 1 if result else 0
 
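
gradcheck no longer falls back to float32 with loosened tolerances on MPS; it requires float64 and therefore a CPU device, failing fast everywhere else. The error message doubles as the remedy:

    kgpy --backend torch --device cpu
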
klongpy/cli.py CHANGED
@@ -7,6 +7,8 @@ See https://t3x.org/klong/klong-ref.txt.html for additional details.
 
 import argparse
 import asyncio
+import importlib
+import importlib.util
 import importlib.metadata
 import os
 import sys
@@ -15,10 +17,13 @@ import timeit
 
 import colorama
 
-from klongpy import KlongInterpreter
+from klongpy import KlongInterpreter, list_backends
 from klongpy.core import kg_write
 from klongpy.repl import cleanup_repl, create_repl
 
+_READLINE_SPEC = importlib.util.find_spec("readline")
+readline = importlib.import_module("readline") if _READLINE_SPEC else None
+
 
 def sys_cmd_shell(klong, cmd):
     """
@@ -191,11 +196,13 @@ async def repl_eval(klong, p, verbose=True):
     return r
 
 
-def show_repl_header(ipc_addr=None):
+def show_repl_header(backend_name, device_name=None, ipc_addr=None):
     print()
     print(f"{colorama.Fore.GREEN}Welcome to KlongPy REPL v{importlib.metadata.distribution('klongpy').version}")
     print(f"{colorama.Fore.GREEN}Author: Brian Guarraci")
     print(f"{colorama.Fore.GREEN}Web: http://klongpy.org")
+    device_str = f" ({device_name})" if device_name else ""
+    print(f"{colorama.Fore.CYAN}Backend: {backend_name}{device_str}")
     print(f"{colorama.Fore.YELLOW}]h for help; Ctrl-D or ]q to quit")
     print()
     if ipc_addr:
@@ -239,7 +246,19 @@ class ConsoleInputHandler:
                 print("\rbye!")
                 break
             except KeyboardInterrupt:
-                print(failure("\nkg: error: interrupted"))
+                buf = ""
+                if readline is not None:
+                    try:
+                        buf = readline.get_line_buffer()
+                    except Exception:
+                        buf = ""
+                if buf:
+                    print()
+                    continue
+                print("\rbye!")
+                if exit_state is not None:
+                    exit_state["code"] = 130
+                break
             except Exception as e:
                 print(failure(f"Error: {e.args}"))
                 import traceback
@@ -256,6 +275,10 @@
 
 
 def run_file(klong_loop, klong, fname, verbose=False):
+    # Add script directory to sys.path so .py/.pyf imports resolve sibling modules
+    script_dir = os.path.dirname(os.path.abspath(fname))
+    if script_dir not in sys.path:
+        sys.path.insert(0, script_dir)
     with open(fname, "r") as f:
         content = f.read()
     return run_in_loop(klong_loop, run_in_klong(klong, content))
@@ -281,21 +304,35 @@
     parser.add_argument('-t', '--test', help='test program from file')
     parser.add_argument('-v', '--verbose', help='enable verbose output', action="store_true")
     parser.add_argument('-d', '--debug', help='enable debug mode', action="store_true")
+    parser.add_argument('--backend', help='set array backend', type=str.lower, choices=list_backends())
+    parser.add_argument('--device', help='set backend device (torch only)', type=str)
     parser.add_argument('filename', nargs='?', help='filename to be run if no flags are specified')
 
     args = parser.parse_args(main_args[1:])
 
+    # Default to torch backend if available and not explicitly set
+    if args.backend is None:
+        available_backends = list_backends()
+        if 'torch' in available_backends:
+            args.backend = 'torch'
+
     if args.debug:
         print("args: ", args)
 
     if args.expr:
-        klong = KlongInterpreter()
+        try:
+            klong = KlongInterpreter(backend=args.backend, device=args.device)
+        except ValueError as exc:
+            parser.error(str(exc))
         result = klong(args.expr)
         if result is not None:
             print(kg_write(result, klong._backend, display=False))
         return
 
-    klong, loops = create_repl(debug=args.debug)
+    try:
+        klong, loops = create_repl(debug=args.debug, backend=args.backend, device=args.device)
+    except ValueError as exc:
+        parser.error(str(exc))
     io_loop, _, _, klong_loop, _, _ = loops
     shutdown_event = klong['.system']['closeEvent']
@@ -361,9 +398,15 @@
     if run_repl:
         exit_state = {"code": None}
         colorama.init(autoreset=True)
-        show_repl_header(args.server)
+        backend_name = klong._backend.name
+        device_name = str(klong._backend.device) if hasattr(klong._backend, 'device') else None
+        show_repl_header(backend_name, device_name, args.server)
         console_loop.create_task(ConsoleInputHandler.input_producer(console_loop, klong_loop, klong, args.verbose, exit_state))
-        console_loop.run_forever()
+        try:
+            console_loop.run_forever()
+        except KeyboardInterrupt:
+            exit_state["code"] = 130
+            console_loop.stop()
         console_loop.close()
         if exit_state["code"] is not None:
             exit_code = exit_state["code"]