symbolicai 0.20.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. symai/__init__.py +96 -64
  2. symai/backend/base.py +93 -80
  3. symai/backend/engines/drawing/engine_bfl.py +12 -11
  4. symai/backend/engines/drawing/engine_gpt_image.py +108 -87
  5. symai/backend/engines/embedding/engine_llama_cpp.py +25 -28
  6. symai/backend/engines/embedding/engine_openai.py +3 -5
  7. symai/backend/engines/execute/engine_python.py +6 -5
  8. symai/backend/engines/files/engine_io.py +74 -67
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +3 -3
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +54 -38
  11. symai/backend/engines/index/engine_pinecone.py +23 -24
  12. symai/backend/engines/index/engine_vectordb.py +16 -14
  13. symai/backend/engines/lean/engine_lean4.py +38 -34
  14. symai/backend/engines/neurosymbolic/__init__.py +41 -13
  15. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +262 -182
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +263 -191
  17. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +53 -49
  18. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +212 -211
  19. symai/backend/engines/neurosymbolic/engine_groq.py +87 -63
  20. symai/backend/engines/neurosymbolic/engine_huggingface.py +21 -24
  21. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +117 -48
  22. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +256 -229
  23. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +270 -150
  24. symai/backend/engines/ocr/engine_apilayer.py +6 -8
  25. symai/backend/engines/output/engine_stdout.py +1 -4
  26. symai/backend/engines/search/engine_openai.py +7 -7
  27. symai/backend/engines/search/engine_perplexity.py +5 -5
  28. symai/backend/engines/search/engine_serpapi.py +12 -14
  29. symai/backend/engines/speech_to_text/engine_local_whisper.py +20 -27
  30. symai/backend/engines/symbolic/engine_wolframalpha.py +3 -3
  31. symai/backend/engines/text_to_speech/engine_openai.py +5 -7
  32. symai/backend/engines/text_vision/engine_clip.py +7 -11
  33. symai/backend/engines/userinput/engine_console.py +3 -3
  34. symai/backend/engines/webscraping/engine_requests.py +81 -48
  35. symai/backend/mixin/__init__.py +13 -0
  36. symai/backend/mixin/anthropic.py +4 -2
  37. symai/backend/mixin/deepseek.py +2 -0
  38. symai/backend/mixin/google.py +2 -0
  39. symai/backend/mixin/openai.py +11 -3
  40. symai/backend/settings.py +83 -16
  41. symai/chat.py +101 -78
  42. symai/collect/__init__.py +7 -1
  43. symai/collect/dynamic.py +77 -69
  44. symai/collect/pipeline.py +35 -27
  45. symai/collect/stats.py +75 -63
  46. symai/components.py +198 -169
  47. symai/constraints.py +15 -12
  48. symai/core.py +698 -359
  49. symai/core_ext.py +32 -34
  50. symai/endpoints/api.py +80 -73
  51. symai/extended/.DS_Store +0 -0
  52. symai/extended/__init__.py +46 -12
  53. symai/extended/api_builder.py +11 -8
  54. symai/extended/arxiv_pdf_parser.py +13 -12
  55. symai/extended/bibtex_parser.py +2 -3
  56. symai/extended/conversation.py +101 -90
  57. symai/extended/document.py +17 -10
  58. symai/extended/file_merger.py +18 -13
  59. symai/extended/graph.py +18 -13
  60. symai/extended/html_style_template.py +2 -4
  61. symai/extended/interfaces/blip_2.py +1 -2
  62. symai/extended/interfaces/clip.py +1 -2
  63. symai/extended/interfaces/console.py +7 -1
  64. symai/extended/interfaces/dall_e.py +1 -1
  65. symai/extended/interfaces/flux.py +1 -1
  66. symai/extended/interfaces/gpt_image.py +1 -1
  67. symai/extended/interfaces/input.py +1 -1
  68. symai/extended/interfaces/llava.py +0 -1
  69. symai/extended/interfaces/naive_vectordb.py +7 -8
  70. symai/extended/interfaces/naive_webscraping.py +1 -1
  71. symai/extended/interfaces/ocr.py +1 -1
  72. symai/extended/interfaces/pinecone.py +6 -5
  73. symai/extended/interfaces/serpapi.py +1 -1
  74. symai/extended/interfaces/terminal.py +2 -3
  75. symai/extended/interfaces/tts.py +1 -1
  76. symai/extended/interfaces/whisper.py +1 -1
  77. symai/extended/interfaces/wolframalpha.py +1 -1
  78. symai/extended/metrics/__init__.py +11 -1
  79. symai/extended/metrics/similarity.py +11 -13
  80. symai/extended/os_command.py +17 -16
  81. symai/extended/packages/__init__.py +29 -3
  82. symai/extended/packages/symdev.py +19 -16
  83. symai/extended/packages/sympkg.py +12 -9
  84. symai/extended/packages/symrun.py +21 -19
  85. symai/extended/repo_cloner.py +11 -10
  86. symai/extended/seo_query_optimizer.py +1 -2
  87. symai/extended/solver.py +20 -23
  88. symai/extended/summarizer.py +4 -3
  89. symai/extended/taypan_interpreter.py +10 -12
  90. symai/extended/vectordb.py +99 -82
  91. symai/formatter/__init__.py +9 -1
  92. symai/formatter/formatter.py +12 -16
  93. symai/formatter/regex.py +62 -63
  94. symai/functional.py +176 -122
  95. symai/imports.py +136 -127
  96. symai/interfaces.py +56 -27
  97. symai/memory.py +14 -13
  98. symai/misc/console.py +49 -39
  99. symai/misc/loader.py +5 -3
  100. symai/models/__init__.py +17 -1
  101. symai/models/base.py +269 -181
  102. symai/models/errors.py +0 -1
  103. symai/ops/__init__.py +32 -22
  104. symai/ops/measures.py +11 -15
  105. symai/ops/primitives.py +348 -228
  106. symai/post_processors.py +32 -28
  107. symai/pre_processors.py +39 -41
  108. symai/processor.py +6 -4
  109. symai/prompts.py +59 -45
  110. symai/server/huggingface_server.py +23 -20
  111. symai/server/llama_cpp_server.py +7 -5
  112. symai/shell.py +3 -4
  113. symai/shellsv.py +499 -375
  114. symai/strategy.py +517 -287
  115. symai/symbol.py +111 -116
  116. symai/utils.py +42 -36
  117. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/METADATA +4 -2
  118. symbolicai-1.0.0.dist-info/RECORD +163 -0
  119. symbolicai-0.20.2.dist-info/RECORD +0 -162
  120. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/WHEEL +0 -0
  121. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/entry_points.txt +0 -0
  122. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/licenses/LICENSE +0 -0
  123. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/top_level.txt +0 -0
symai/ops/primitives.py CHANGED
@@ -1,19 +1,18 @@
1
1
  import ast
2
2
  import json
3
- import os
3
+ import numbers
4
4
  import pickle
5
5
  import uuid
6
- from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
7
- Optional, Tuple, Type, Union)
6
+ from collections.abc import Callable, Iterable
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, Any, Literal, Union
8
9
 
9
10
  import numpy as np
10
11
  import torch
11
- from pydantic import ValidationError
12
12
 
13
13
  from .. import core, core_ext
14
- from ..models import CustomConstraint, LengthConstraint, LLMDataModel
15
14
  from ..prompts import Prompt
16
- from ..utils import CustomUserWarning
15
+ from ..utils import UserMessage
17
16
  from .measures import calculate_frechet_distance, calculate_mmd
18
17
 
19
18
  if TYPE_CHECKING:
@@ -41,20 +40,21 @@ class Primitive:
41
40
 
42
41
 
43
42
  class OperatorPrimitives(Primitive):
44
- def __try_type_specific_func(self, other, func, op: str = None):
43
+ __hash__ = None
44
+
45
+ def __try_type_specific_func(self, other, func, op: str | None = None):
45
46
  if not isinstance(other, self._symbol_type):
46
47
  other = self._to_type(other)
47
48
  # None shortcut
48
- if not self.__disable_none_shortcut__:
49
- if self.value is None or other.value is None:
50
- CustomUserWarning(f"unsupported {self._symbol_type.__class__} value operand type(s) for {op}: '{type(self.value)}' and '{type(other.value)}'", raise_with=TypeError)
49
+ if not self.__disable_none_shortcut__ and (self.value is None or other.value is None):
50
+ UserMessage(f"unsupported {self._symbol_type.__class__} value operand type(s) for {op}: '{type(self.value)}' and '{type(other.value)}'", raise_with=TypeError)
51
51
  # try type specific function
52
52
  try:
53
53
  # try type specific function
54
54
  value = func(self, other)
55
55
  if value is NotImplemented:
56
56
  operation = '' if op is None else op
57
- CustomUserWarning(f"unsupported {self._symbol_type.__class__} value operand type(s) for {operation}: '{type(self.value)}' and '{type(other.value)}'", raise_with=TypeError)
57
+ UserMessage(f"unsupported {self._symbol_type.__class__} value operand type(s) for {operation}: '{type(self.value)}' and '{type(other.value)}'", raise_with=TypeError)
58
58
  return value
59
59
  except Exception as ex:
60
60
  self._metadata._error = ex
@@ -66,7 +66,7 @@ class OperatorPrimitives(Primitive):
66
66
  This function raises an error if the neuro-symbolic engine is disabled.
67
67
  '''
68
68
  if self.__disable_nesy_engine__:
69
- CustomUserWarning(f"unsupported {self.__class__} value operand type(s) for {func.__name__}: '{type(self.value)}'", raise_with=TypeError)
69
+ UserMessage(f"unsupported {self.__class__} value operand type(s) for {func.__name__}: '{type(self.value)}'", raise_with=TypeError)
70
70
 
71
71
  def __bool__(self) -> bool:
72
72
  '''
@@ -80,7 +80,7 @@ class OperatorPrimitives(Primitive):
80
80
  if isinstance(self.value, bool):
81
81
  val = self.value
82
82
  elif self.value is not None:
83
- val = True if self.value else False
83
+ val = bool(self.value)
84
84
 
85
85
  return val
86
86
 
@@ -680,8 +680,8 @@ class OperatorPrimitives(Primitive):
680
680
  Returns:
681
681
  Symbol: A new symbol with the result of the OR operation.
682
682
  '''
683
- # exclude the evaluation for the Aggregator class
684
- from ..collect.stats import Aggregator
683
+ # Exclude the evaluation for the Aggregator class; keep import local to avoid ops.primitives <-> collect.stats cycle.
684
+ from ..collect.stats import Aggregator # noqa
685
685
  if isinstance(other, Aggregator):
686
686
  return NotImplemented
687
687
 
@@ -709,8 +709,8 @@ class OperatorPrimitives(Primitive):
709
709
  Returns:
710
710
  Symbol: A new Symbol object with the concatenated value.
711
711
  '''
712
- # exclude the evaluation for the Aggregator class
713
- from ..collect.stats import Aggregator
712
+ # Exclude the evaluation for the Aggregator class; keep import local to avoid ops.primitives <-> collect.stats cycle.
713
+ from ..collect.stats import Aggregator # noqa
714
714
  if isinstance(other, Aggregator):
715
715
  return NotImplemented
716
716
 
@@ -739,8 +739,8 @@ class OperatorPrimitives(Primitive):
739
739
  Returns:
740
740
  Symbol: A new Symbol object with the concatenated value.
741
741
  '''
742
- # exclude the evaluation for the Aggregator class
743
- from ..collect.stats import Aggregator
742
+ # Exclude the evaluation for the Aggregator class; keep import local to avoid ops.primitives <-> collect.stats cycle.
743
+ from ..collect.stats import Aggregator # noqa
744
744
  if isinstance(other, Aggregator):
745
745
  return NotImplemented
746
746
 
@@ -848,11 +848,12 @@ class OperatorPrimitives(Primitive):
848
848
  Returns:
849
849
  Symbol: A new Symbol object with the concatenated value.
850
850
  '''
851
- if isinstance(self.value, str) and isinstance(other, str) or \
852
- isinstance(self.value, str) and isinstance(other, self._symbol_type) and isinstance(other.value, str):
851
+ if (isinstance(self.value, str) and isinstance(other, str)) or \
852
+ (isinstance(self.value, str) and isinstance(other, self._symbol_type) and isinstance(other.value, str)):
853
853
  other = self._to_type(other)
854
854
  return self._to_type(f'{self.value}{other.value}')
855
- CustomUserWarning(f'This method is only supported for string concatenation! Got {type(self.value)} and {type(other)} instead.', raise_with=TypeError)
855
+ UserMessage(f'This method is only supported for string concatenation! Got {type(self.value)} and {type(other)} instead.', raise_with=TypeError)
856
+ return None
856
857
 
857
858
  def __rmatmul__(self, other: Any) -> 'Symbol':
858
859
  '''
@@ -917,7 +918,8 @@ class OperatorPrimitives(Primitive):
917
918
  if result is not None:
918
919
  return self._to_type(result)
919
920
 
920
- CustomUserWarning('Division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
921
+ UserMessage('Division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
922
+ return None
921
923
 
922
924
 
923
925
  def __itruediv__(self, other: Any) -> 'Symbol':
@@ -936,7 +938,8 @@ class OperatorPrimitives(Primitive):
936
938
  self._value = result
937
939
  return self
938
940
 
939
- CustomUserWarning('Division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
941
+ UserMessage('Division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
942
+ return None
940
943
 
941
944
 
942
945
  def __floordiv__(self, other: Any) -> 'Symbol':
@@ -954,7 +957,8 @@ class OperatorPrimitives(Primitive):
954
957
  if result is not None:
955
958
  return self._to_type(result)
956
959
 
957
- CustomUserWarning('Floor division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
960
+ UserMessage('Floor division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
961
+ return None
958
962
 
959
963
  def __rfloordiv__(self, other: Any) -> 'Symbol':
960
964
  '''
@@ -971,7 +975,8 @@ class OperatorPrimitives(Primitive):
971
975
  if result is not None:
972
976
  return self._to_type(result)
973
977
 
974
- CustomUserWarning('Floor division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
978
+ UserMessage('Floor division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
979
+ return None
975
980
 
976
981
  def __ifloordiv__(self, other: Any) -> 'Symbol':
977
982
  '''
@@ -989,7 +994,8 @@ class OperatorPrimitives(Primitive):
989
994
  self._value = result
990
995
  return self
991
996
 
992
- CustomUserWarning('Floor division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
997
+ UserMessage('Floor division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
998
+ return None
993
999
 
994
1000
  def __pow__(self, other: Any) -> 'Symbol':
995
1001
  '''
@@ -1006,7 +1012,8 @@ class OperatorPrimitives(Primitive):
1006
1012
  if result is not None:
1007
1013
  return self._to_type(result)
1008
1014
 
1009
- CustomUserWarning('Power operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1015
+ UserMessage('Power operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1016
+ return None
1010
1017
 
1011
1018
 
1012
1019
  def __rpow__(self, other: Any) -> 'Symbol':
@@ -1024,7 +1031,8 @@ class OperatorPrimitives(Primitive):
1024
1031
  if result is not None:
1025
1032
  return self._to_type(result)
1026
1033
 
1027
- CustomUserWarning('Power operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1034
+ UserMessage('Power operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1035
+ return None
1028
1036
 
1029
1037
  def __ipow__(self, other: Any) -> 'Symbol':
1030
1038
  '''
@@ -1042,7 +1050,8 @@ class OperatorPrimitives(Primitive):
1042
1050
  self._value = result
1043
1051
  return self
1044
1052
 
1045
- CustomUserWarning('Power operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1053
+ UserMessage('Power operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1054
+ return None
1046
1055
 
1047
1056
  def __mod__(self, other: Any) -> 'Symbol':
1048
1057
  '''
@@ -1059,7 +1068,8 @@ class OperatorPrimitives(Primitive):
1059
1068
  if result is not None:
1060
1069
  return self._to_type(result)
1061
1070
 
1062
- CustomUserWarning('Modulo operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1071
+ UserMessage('Modulo operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1072
+ return None
1063
1073
 
1064
1074
  def __rmod__(self, other: Any) -> 'Symbol':
1065
1075
  '''
@@ -1076,7 +1086,9 @@ class OperatorPrimitives(Primitive):
1076
1086
  if result is not None:
1077
1087
  return self._to_type(result)
1078
1088
 
1079
- raise NotImplementedError('Modulo operation not supported! Might change in the future.') from self._metadata._error
1089
+ msg = 'Modulo operation not supported! Might change in the future.'
1090
+ UserMessage(msg)
1091
+ raise NotImplementedError(msg) from self._metadata._error
1080
1092
 
1081
1093
  def __imod__(self, other: Any) -> 'Symbol':
1082
1094
  '''
@@ -1094,7 +1106,8 @@ class OperatorPrimitives(Primitive):
1094
1106
  self._value = result
1095
1107
  return self
1096
1108
 
1097
- CustomUserWarning('Modulo operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1109
+ UserMessage('Modulo operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1110
+ return None
1098
1111
 
1099
1112
  def __mul__(self, other: Any) -> 'Symbol':
1100
1113
  '''
@@ -1111,7 +1124,8 @@ class OperatorPrimitives(Primitive):
1111
1124
  if result is not None:
1112
1125
  return self._to_type(result)
1113
1126
 
1114
- CustomUserWarning('Multiply operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1127
+ UserMessage('Multiply operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1128
+ return None
1115
1129
 
1116
1130
  def __rmul__(self, other: Any) -> 'Symbol':
1117
1131
  '''
@@ -1128,7 +1142,8 @@ class OperatorPrimitives(Primitive):
1128
1142
  if result is not None:
1129
1143
  return self._to_type(result)
1130
1144
 
1131
- CustomUserWarning('Multiply operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1145
+ UserMessage('Multiply operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1146
+ return None
1132
1147
 
1133
1148
  def __imul__(self, other: Any) -> 'Symbol':
1134
1149
  '''
@@ -1146,7 +1161,8 @@ class OperatorPrimitives(Primitive):
1146
1161
  self._value = result
1147
1162
  return self
1148
1163
 
1149
- CustomUserWarning('Multiply operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1164
+ UserMessage('Multiply operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
1165
+ return None
1150
1166
 
1151
1167
 
1152
1168
  class CastingPrimitives(Primitive):
@@ -1165,14 +1181,14 @@ class CastingPrimitives(Primitive):
1165
1181
  @property
1166
1182
  def sem(self) -> "Symbol":
1167
1183
  """
1168
- Return a semantic view of this Symbol.
1184
+ Return a semantic view of this Symbol.
1169
1185
  (Useful after calling `.syn` in a chain.)
1170
1186
  """
1171
1187
  if getattr(self, "__semantic__", False):
1172
1188
  return self
1173
1189
  return self._to_type(self.value, semantic=True)
1174
1190
 
1175
- def cast(self, as_type: Type) -> Any:
1191
+ def cast(self, as_type: type) -> Any:
1176
1192
  '''
1177
1193
  Cast the Symbol's value to a specific type.
1178
1194
 
@@ -1184,7 +1200,7 @@ class CastingPrimitives(Primitive):
1184
1200
  '''
1185
1201
  return as_type(self.value)
1186
1202
 
1187
- def to(self, as_type: Type) -> Any:
1203
+ def to(self, as_type: type) -> Any:
1188
1204
  '''
1189
1205
  Cast the Symbol's value to a specific type.
1190
1206
 
@@ -1250,7 +1266,7 @@ class IterationPrimitives(Primitive):
1250
1266
  This mixin contains functions that perform iteration operations on symbols or symbol values.
1251
1267
  The functions in this mixin are bound to the 'neurosymbolic' engine for evaluation.
1252
1268
  '''
1253
- def __getitem__(self, key: Union[str, int, slice]) -> 'Symbol':
1269
+ def __getitem__(self, key: str | int | slice) -> 'Symbol':
1254
1270
  '''
1255
1271
  Get the item of the Symbol value with the specified key or index.
1256
1272
  If the Symbol value is a list, tuple, or numpy array, the key can be an integer or slice.
@@ -1270,7 +1286,7 @@ class IterationPrimitives(Primitive):
1270
1286
  try:
1271
1287
  return self.value[key]
1272
1288
  except Exception:
1273
- CustomUserWarning(f'Key {key} not found in {self.value}', raise_with=Exception)
1289
+ UserMessage(f'Key {key} not found in {self.value}', raise_with=Exception)
1274
1290
 
1275
1291
  @core.getitem()
1276
1292
  def _func(_, index: str):
@@ -1278,7 +1294,7 @@ class IterationPrimitives(Primitive):
1278
1294
 
1279
1295
  return self._to_type(_func(self, key))
1280
1296
 
1281
- def __setitem__(self, key: Union[str, int, slice], value: Any) -> None:
1297
+ def __setitem__(self, key: str | int | slice, value: Any) -> None:
1282
1298
  '''
1283
1299
  Set the item of the Symbol value with the specified key or index to the given value.
1284
1300
  If the Symbol value is a list, the key can be an integer or slice.
@@ -1292,17 +1308,18 @@ class IterationPrimitives(Primitive):
1292
1308
  Raises:
1293
1309
  KeyError: If the key or index is not found in the Symbol value.
1294
1310
  '''
1295
- from ..post_processors import ASTPostProcessor
1311
+ # Local import avoids ops.primitives -> post_processors -> symbol -> ops circular load.
1312
+ from ..post_processors import ASTPostProcessor # noqa
1296
1313
 
1297
1314
  if not isinstance(self.value, (str, dict, list)):
1298
- CustomUserWarning(f'Setting item is not supported for {type(self.value)}. Supported types are str, dict, and list.', raise_with=TypeError)
1315
+ UserMessage(f'Setting item is not supported for {type(self.value)}. Supported types are str, dict, and list.', raise_with=TypeError)
1299
1316
 
1300
1317
  if not self.__semantic__:
1301
1318
  try:
1302
1319
  self.value[key] = value
1303
1320
  return
1304
1321
  except Exception:
1305
- CustomUserWarning(f'Key {key} not found in {self.value}', raise_with=Exception)
1322
+ UserMessage(f'Key {key} not found in {self.value}', raise_with=Exception)
1306
1323
 
1307
1324
  @core.setitem()
1308
1325
  def _func(_, index: str, value: str):
@@ -1314,7 +1331,7 @@ class IterationPrimitives(Primitive):
1314
1331
  except Exception:
1315
1332
  self._value = result # It was a string, or something failed (because } wasn't close, etc)
1316
1333
 
1317
- def __delitem__(self, key: Union[str, int]) -> None:
1334
+ def __delitem__(self, key: str | int) -> None:
1318
1335
  '''
1319
1336
  Delete the item of the Symbol value with the specified key or index.
1320
1337
  If the Symbol value is a dictionary, the key can be a string or an integer.
@@ -1326,17 +1343,18 @@ class IterationPrimitives(Primitive):
1326
1343
  Raises:
1327
1344
  KeyError: If the key or index is not found in the Symbol value.
1328
1345
  '''
1329
- from ..post_processors import ASTPostProcessor
1346
+ # Local import avoids ops.primitives -> post_processors -> symbol -> ops circular load.
1347
+ from ..post_processors import ASTPostProcessor # noqa
1330
1348
 
1331
1349
  if not isinstance(self.value, (str, dict, list)):
1332
- CustomUserWarning(f'Setting item is not supported for {type(self.value)}. Supported types are str, dict, and list.', raise_with=TypeError)
1350
+ UserMessage(f'Setting item is not supported for {type(self.value)}. Supported types are str, dict, and list.', raise_with=TypeError)
1333
1351
 
1334
1352
  if not self.__semantic__:
1335
1353
  try:
1336
1354
  del self.value[key]
1337
1355
  return
1338
1356
  except Exception:
1339
- CustomUserWarning(f'Key {key} not found in {self.value}', raise_with=Exception)
1357
+ UserMessage(f'Key {key} not found in {self.value}', raise_with=Exception)
1340
1358
 
1341
1359
  @core.delitem()
1342
1360
  def _func(_, index: str):
@@ -1427,7 +1445,7 @@ class StringHelperPrimitives(Primitive):
1427
1445
  '''
1428
1446
  This mixin contains functions that provide additional help for symbols or their values.
1429
1447
  '''
1430
- def split(self, delimiter: str, **kwargs) -> 'Symbol':
1448
+ def split(self, delimiter: str, **_kwargs) -> 'Symbol':
1431
1449
  '''
1432
1450
  Splits the symbol value by a specified delimiter.
1433
1451
  Uses the core.split decorator to create a _func method that splits the symbol value by the specified delimiter.
@@ -1442,7 +1460,7 @@ class StringHelperPrimitives(Primitive):
1442
1460
  assert isinstance(self.value, str), f'self.value must be a string, got {type(self.value)}'
1443
1461
  return self._to_type([*self.value.split(delimiter)])
1444
1462
 
1445
- def join(self, delimiter: str = ' ', **kwargs) -> 'Symbol':
1463
+ def join(self, delimiter: str = ' ', **_kwargs) -> 'Symbol':
1446
1464
  '''
1447
1465
  Joins the symbol value with a specified delimiter.
1448
1466
 
@@ -1456,7 +1474,7 @@ class StringHelperPrimitives(Primitive):
1456
1474
  assert isinstance(self.value, Iterable), f'value must be an iterable, got {type(self.value)}'
1457
1475
  return self._to_type(delimiter.join(self.value))
1458
1476
 
1459
- def startswith(self, prefix: str, **kwargs) -> bool:
1477
+ def startswith(self, prefix: str, **_kwargs) -> bool:
1460
1478
  '''
1461
1479
  Checks if the symbol value starts with a specified prefix.
1462
1480
  Uses the core.startswith decorator to create a _func method that checks if the symbol value starts with the specified prefix.
@@ -1479,7 +1497,7 @@ class StringHelperPrimitives(Primitive):
1479
1497
 
1480
1498
  return _func(self, prefix)
1481
1499
 
1482
- def endswith(self, suffix: str, **kwargs) -> bool:
1500
+ def endswith(self, suffix: str, **_kwargs) -> bool:
1483
1501
  '''
1484
1502
  Checks if the symbol value ends with a specified suffix.
1485
1503
  Uses the core.endswith decorator to create a _func method that checks if the symbol value ends with the specified suffix.
@@ -1574,7 +1592,7 @@ class ExpressionHandlingPrimitives(Primitive):
1574
1592
  if not hasattr(self, '_accumulated_results'):
1575
1593
  self._accumulated_results = []
1576
1594
 
1577
- def get_results(self) -> List['Symbol']:
1595
+ def get_results(self) -> list['Symbol']:
1578
1596
  '''
1579
1597
  Retrieves accumulated results from previous interpretations.
1580
1598
 
@@ -1589,7 +1607,7 @@ class ExpressionHandlingPrimitives(Primitive):
1589
1607
  self.init_results()
1590
1608
  self._accumulated_results = []
1591
1609
 
1592
- def interpret(self, prompt: Optional[str] = "Evaluate the symbolic expressions and return only the result:\n", accumulate: bool = False, **kwargs) -> 'Symbol':
1610
+ def interpret(self, prompt: str | None = "Evaluate the symbolic expressions and return only the result:\n", accumulate: bool = False, **kwargs) -> 'Symbol':
1593
1611
  '''
1594
1612
  Evaluates simple symbolic expressions.
1595
1613
  Uses the core.expression decorator to create a _func method that evaluates the given expression.
@@ -1639,7 +1657,7 @@ class DataHandlingPrimitives(Primitive):
1639
1657
 
1640
1658
  return self._to_type(_func(self))
1641
1659
 
1642
- def summarize(self, context: Optional[str] = None, **kwargs) -> 'Symbol':
1660
+ def summarize(self, context: str | None = None, **kwargs) -> 'Symbol':
1643
1661
  '''
1644
1662
  Summarizes the symbol value.
1645
1663
  Uses the core.summarize decorator with an optional context to create a _func method that summarizes the symbol value.
@@ -1670,7 +1688,7 @@ class DataHandlingPrimitives(Primitive):
1670
1688
 
1671
1689
  return self._to_type(_func(self))
1672
1690
 
1673
- def filter(self, criteria: str, include: Optional[bool] = False, **kwargs) -> 'Symbol':
1691
+ def filter(self, criteria: str, include: bool | None = False, **kwargs) -> 'Symbol':
1674
1692
  '''
1675
1693
  Filters the symbol value based on a specified criteria.
1676
1694
  Uses the core.filtering decorator with the provided criteria and include flag to create a _func method to filter the symbol value.
@@ -1707,7 +1725,7 @@ class DataHandlingPrimitives(Primitive):
1707
1725
  try:
1708
1726
  iter(self.value)
1709
1727
  except TypeError:
1710
- CustomUserWarning('Map can only be applied to iterable objects', raise_with=AssertionError)
1728
+ UserMessage('Map can only be applied to iterable objects', raise_with=AssertionError)
1711
1729
 
1712
1730
  @core.map(instruction=instruction, **kwargs)
1713
1731
  def _func(_):
@@ -1807,7 +1825,7 @@ class UniquenessPrimitives(Primitive):
1807
1825
  This mixin includes functions that work with unique aspects of symbol values, like extracting unique information or composing new unique symbols.
1808
1826
  Future functionalities might include finding duplicate information, defining levels of uniqueness, etc.
1809
1827
  '''
1810
- def unique(self, keys: Optional[List[str]] = [], **kwargs) -> 'Symbol':
1828
+ def unique(self, keys: list[str] | None = None, **kwargs) -> 'Symbol':
1811
1829
  '''
1812
1830
  Extracts unique information from the symbol value, using provided keys.
1813
1831
  Uses the core.unique decorator with a list of keys to create a _func method that extracts unique data from the symbol value.
@@ -1818,6 +1836,8 @@ class UniquenessPrimitives(Primitive):
1818
1836
  Returns:
1819
1837
  Symbol: A new symbol with the unique information.
1820
1838
  '''
1839
+ if keys is None:
1840
+ keys = []
1821
1841
  @core.unique(keys=keys, **kwargs)
1822
1842
  def _func(_) -> str:
1823
1843
  pass
@@ -1844,7 +1864,7 @@ class PatternMatchingPrimitives(Primitive):
1844
1864
  This mixin houses functions that deal with ranking symbols, extracting details based on patterns, and correcting symbols.
1845
1865
  It will house future functionalities that involve sorting, complex pattern detections, advanced correction techniques etc.
1846
1866
  '''
1847
- def rank(self, measure: Optional[str] = 'alphanumeric', order: Optional[str] = 'desc', **kwargs) -> 'Symbol':
1867
+ def rank(self, measure: str | None = 'alphanumeric', order: str | None = 'desc', **kwargs) -> 'Symbol':
1848
1868
  '''
1849
1869
  Ranks the symbol value based on a measure and a sort order.
1850
1870
  Uses the core.rank decorator with the specified measure and order to create a _func method that ranks the symbol value.
@@ -1899,7 +1919,7 @@ class PatternMatchingPrimitives(Primitive):
1899
1919
 
1900
1920
  return self._to_type(_func(self))
1901
1921
 
1902
- def translate(self, language: Optional[str] = 'English', **kwargs) -> 'Symbol':
1922
+ def translate(self, language: str | None = 'English', **kwargs) -> 'Symbol':
1903
1923
  '''
1904
1924
  Translates the symbol value to the specified language.
1905
1925
  Uses the @core.translate decorator to translate the symbol's value to the specified language.
@@ -1917,7 +1937,7 @@ class PatternMatchingPrimitives(Primitive):
1917
1937
 
1918
1938
  return self._to_type(_func(self))
1919
1939
 
1920
- def choice(self, cases: List[str], default: str, **kwargs) -> 'Symbol':
1940
+ def choice(self, cases: list[str], default: str, **kwargs) -> 'Symbol':
1921
1941
  '''
1922
1942
  Chooses one of the cases based on the symbol value.
1923
1943
  Uses the @core.case decorator, selects one of the cases based on the symbol's value.
@@ -1942,7 +1962,7 @@ class QueryHandlingPrimitives(Primitive):
1942
1962
  This mixin helps in transforming, preparing, and executing queries, and it is designed to be extendable as new ways of handling queries are developed.
1943
1963
  Future methods could potentially include query optimization, enhanced query formatting, multi-level query execution, query error handling, etc.
1944
1964
  '''
1945
- def query(self, context: str, prompt: Optional[str] = None, examples: Optional[List[Prompt]] = None, **kwargs) -> 'Symbol':
1965
+ def query(self, context: str, prompt: str | None = None, examples: list[Prompt] | None = None, **kwargs) -> 'Symbol':
1946
1966
  '''
1947
1967
  Queries the symbol value based on a specified context.
1948
1968
  Uses the @core.query decorator, queries based on the context, prompt, and examples.
@@ -2004,7 +2024,7 @@ class ExecutionControlPrimitives(Primitive):
2004
2024
  This mixin represents the core methods for dealing with symbol execution.
2005
2025
  Possible future methods could potentially include async execution, pipeline chaining, execution profiling, improved error handling, version management, embedding more complex execution control structures etc.
2006
2026
  '''
2007
- def analyze(self, exception: Exception, query: Optional[str] = '', **kwargs) -> 'Symbol':
2027
+ def analyze(self, exception: Exception, query: str | None = '', **kwargs) -> 'Symbol':
2008
2028
  '''Uses the @core.analyze decorator, analyzes an exception and returns a symbol.
2009
2029
 
2010
2030
  Args:
@@ -2120,7 +2140,7 @@ class ExecutionControlPrimitives(Primitive):
2120
2140
 
2121
2141
  return self._to_type(_func(self))
2122
2142
 
2123
- def stream(self, expr: 'Expression', token_ratio: Optional[float] = 0.6, **kwargs) -> 'Symbol':
2143
+ def stream(self, expr: 'Expression', token_ratio: float | None = 0.6, **kwargs) -> 'Symbol':
2124
2144
  '''
2125
2145
  Streams the Symbol's value through an Expression object.
2126
2146
  This method divides the Symbol's value into chunks and processes each chunk through the given Expression object.
@@ -2156,7 +2176,7 @@ class ExecutionControlPrimitives(Primitive):
2156
2176
  else:
2157
2177
  yield expr(self, **kwargs)
2158
2178
 
2159
- def ftry(self, expr: 'Expression', retries: Optional[int] = 1, **kwargs) -> 'Symbol':
2179
+ def ftry(self, expr: 'Expression', retries: int | None = 1, **kwargs) -> 'Symbol':
2160
2180
  # TODO: find a way to pass on the constraints and behavior from the self.expr to the corrected code
2161
2181
  '''
2162
2182
  Tries to evaluate a Symbol using a given Expression.
@@ -2198,33 +2218,32 @@ class ExecutionControlPrimitives(Primitive):
2198
2218
  retry_cnt += 1
2199
2219
  if retry_cnt > retries:
2200
2220
  raise e
2221
+ # analyze the error
2222
+ payload = f'[ORIGINAL_USER_PROMPT]\n{prompt["prompt_instruction"]}\n\n' if 'prompt_instruction' in prompt else ''
2223
+ payload = payload + f'[ORIGINAL_USER_DATA]\n{code}\n\n[ORIGINAL_GENERATED_OUTPUT]\n{prompt["out_msg"]}'
2224
+ probe = sym.analyze(query="What is the issue in this expression?", payload=payload, exception=e)
2225
+ # attempt to correct the error
2226
+ payload = f'[ORIGINAL_USER_PROMPT]\n{prompt["prompt_instruction"]}\n\n' if 'prompt_instruction' in prompt else ''
2227
+ payload = payload + f'[ANALYSIS]\n{probe}\n\n'
2228
+ context = f'Try to correct the error of the original user request based on the analysis above: \n [GENERATED_OUTPUT]\n{prompt["out_msg"]}\n\n'
2229
+ constraints = expr.constraints if hasattr(expr, 'constraints') else []
2230
+
2231
+ if hasattr(expr, 'post_processor'):
2232
+ post_processor = expr.post_processor
2233
+ sym = code.correct(
2234
+ context=context,
2235
+ exception=e,
2236
+ payload=payload,
2237
+ constraints=constraints,
2238
+ post_processor=post_processor
2239
+ )
2201
2240
  else:
2202
- # analyze the error
2203
- payload = f'[ORIGINAL_USER_PROMPT]\n{prompt["prompt_instruction"]}\n\n' if 'prompt_instruction' in prompt else ''
2204
- payload = payload + f'[ORIGINAL_USER_DATA]\n{code}\n\n[ORIGINAL_GENERATED_OUTPUT]\n{prompt["out_msg"]}'
2205
- probe = sym.analyze(query="What is the issue in this expression?", payload=payload, exception=e)
2206
- # attempt to correct the error
2207
- payload = f'[ORIGINAL_USER_PROMPT]\n{prompt["prompt_instruction"]}\n\n' if 'prompt_instruction' in prompt else ''
2208
- payload = payload + f'[ANALYSIS]\n{probe}\n\n'
2209
- context = f'Try to correct the error of the original user request based on the analysis above: \n [GENERATED_OUTPUT]\n{prompt["out_msg"]}\n\n'
2210
- constraints = expr.constraints if hasattr(expr, 'constraints') else []
2211
-
2212
- if hasattr(expr, 'post_processor'):
2213
- post_processor = expr.post_processor
2214
- sym = code.correct(
2215
- context=context,
2216
- exception=e,
2217
- payload=payload,
2218
- constraints=constraints,
2219
- post_processor=post_processor
2220
- )
2221
- else:
2222
- sym = code.correct(
2223
- context=context,
2224
- exception=e,
2225
- payload=payload,
2226
- constraints=constraints
2227
- )
2241
+ sym = code.correct(
2242
+ context=context,
2243
+ exception=e,
2244
+ payload=payload,
2245
+ constraints=constraints
2246
+ )
2228
2247
 
2229
2248
 
2230
2249
  class DictHandlingPrimitives(Primitive):
@@ -2257,7 +2276,7 @@ class TemplateStylingPrimitives(Primitive):
2257
2276
  This mixin includes functionalities for stylizing symbols and applying templates.
2258
2277
  Future functionalities might include a variety of new stylizing methods, application of more complex templates, etc.
2259
2278
  '''
2260
- def template(that, template: str, placeholder: Optional[str] = '{{placeholder}}', **kwargs) -> 'Symbol':
2279
+ def template(that, template: str, placeholder: str | None = '{{placeholder}}', **_kwargs) -> 'Symbol':
2261
2280
  '''
2262
2281
  Applies a template to the Symbol.
2263
2282
  This method uses the @core.template decorator to apply the given template and placeholder to the Symbol.
@@ -2277,7 +2296,7 @@ class TemplateStylingPrimitives(Primitive):
2277
2296
 
2278
2297
  return _func(that)
2279
2298
 
2280
- def style(self, description: str, libraries: Optional[List] = [], **kwargs) -> 'Symbol':
2299
+ def style(self, description: str, libraries: list | None = None, **kwargs) -> 'Symbol':
2281
2300
  '''
2282
2301
  Applies a style to the Symbol.
2283
2302
  This method uses the @core.style decorator to apply the given style description, libraries, and placeholder to the Symbol.
@@ -2291,6 +2310,8 @@ class TemplateStylingPrimitives(Primitive):
2291
2310
  Returns:
2292
2311
  Symbol: A Symbol object with the style applied.
2293
2312
  '''
2313
+ if libraries is None:
2314
+ libraries = []
2294
2315
  @core.style(description=description, libraries=libraries, **kwargs)
2295
2316
  def _func(_):
2296
2317
  pass
@@ -2362,16 +2383,17 @@ class EmbeddingPrimitives(Primitive):
2362
2383
  '''
2363
2384
  # if the embedding is not yet computed, compute it
2364
2385
  if self._metadata.embedding is None:
2365
- if ((isinstance(self.value, list) or isinstance(self.value, tuple)) and all([type(x) == int or type(x) == float or type(x) == bool for x in self.value])) \
2386
+ if (isinstance(self.value, (list, tuple)) and all(isinstance(x, (int, float, bool)) for x in self.value)) \
2366
2387
  or isinstance(self.value, np.ndarray):
2367
- if isinstance(self.value, list) or isinstance(self.value, tuple):
2388
+ if isinstance(self.value, (list, tuple)):
2368
2389
  assert len(self.value) > 0, 'Cannot compute embedding of empty list'
2369
- if isinstance(self.value[0], Symbol):
2390
+ symbol_type = self._symbol_type
2391
+ if isinstance(self.value[0], symbol_type):
2370
2392
  # convert each element to numpy array
2371
2393
  self._metadata.embedding = np.asarray([x.embedding for x in self.value])
2372
2394
  elif isinstance(self.value[0], str):
2373
2395
  # embed each string
2374
- self._metadata.embedding = np.asarray([Symbol(x).embedding for x in self.value])
2396
+ self._metadata.embedding = np.asarray([symbol_type(x).embedding for x in self.value])
2375
2397
  else:
2376
2398
  # convert to numpy array
2377
2399
  self._metadata.embedding = np.asarray(self.value)
@@ -2390,23 +2412,195 @@ class EmbeddingPrimitives(Primitive):
2390
2412
 
2391
2413
  def _ensure_numpy_format(self, x, cast=False):
2392
2414
  # if it is a Symbol, get its value
2393
- if not isinstance(x, np.ndarray) or not isinstance(x, torch.Tensor) or not isinstance(x, list):
2415
+ if not isinstance(x, (np.ndarray, torch.Tensor, list)):
2394
2416
  if not isinstance(x, self._symbol_type): #@NOTE: enforce Symbol to avoid circular import
2395
2417
  if not cast:
2396
- raise TypeError(f'Cannot compute similarity with type {type(x)}')
2418
+ msg = f'Cannot compute similarity with type {type(x)}'
2419
+ UserMessage(msg)
2420
+ raise TypeError(msg)
2397
2421
  x = self._symbol_type(x)
2398
2422
  # evaluate the Symbol as an embedding
2399
2423
  x = x.embedding
2400
2424
  # if it is a list, convert it to numpy
2401
- if isinstance(x, list) or isinstance(x, tuple):
2425
+ if isinstance(x, (list, tuple)):
2402
2426
  assert len(x) > 0, 'Cannot compute similarity with empty list'
2403
2427
  x = np.asarray(x)
2404
2428
  # if it is a tensor, convert it to numpy
2405
2429
  elif isinstance(x, torch.Tensor):
2406
2430
  x = x.detach().cpu().numpy()
2407
- return x.squeeze()[:, None]
2431
+ else:
2432
+ x = np.asarray(x)
2408
2433
 
2409
- def similarity(self, other: Union['Symbol', list, np.ndarray, torch.Tensor], metric: Union['cosine', 'angular-cosine', 'product', 'manhattan', 'euclidean', 'minkowski', 'jaccard'] = 'cosine', eps: float = 1e-8, normalize: Optional[Callable] = None, **kwargs) -> float:
2434
+ x = np.squeeze(x)
2435
+ if x.ndim == 0:
2436
+ x = x[None]
2437
+
2438
+ return x[:, None]
2439
+
2440
+ def _prepare_embedding_operand(self, operand):
2441
+ if isinstance(operand, (list, tuple)):
2442
+ if self._is_numeric_sequence(operand):
2443
+ return self._ensure_numpy_format(operand, cast=True)
2444
+ formatted = [
2445
+ self._ensure_numpy_format(item, cast=True) for item in operand
2446
+ ]
2447
+ return np.concatenate(formatted, axis=1)
2448
+ return self._ensure_numpy_format(operand, cast=True)
2449
+
2450
+ def _is_numeric_sequence(self, operand: Iterable):
2451
+ for item in operand:
2452
+ if isinstance(item, (list, tuple, np.ndarray, torch.Tensor, self._symbol_type)):
2453
+ return False
2454
+ if isinstance(item, (numbers.Real, np.generic)):
2455
+ continue
2456
+ return False
2457
+ return True
2458
+
2459
+ def _get_similarity_handler(self, metric, eps, kwargs):
2460
+ def _cosine_similarity(lhs, rhs):
2461
+ return lhs.T@rhs / (np.sqrt(lhs.T@lhs) * np.sqrt(rhs.T@rhs) + eps)
2462
+
2463
+ def _angular_cosine_similarity(lhs, rhs):
2464
+ c = kwargs.get('c', 1)
2465
+ return 1 - (c * np.arccos(lhs.T@rhs / (np.sqrt(lhs.T@lhs) * np.sqrt(rhs.T@rhs) + eps)) / np.pi)
2466
+
2467
+ def _product_similarity(lhs, rhs):
2468
+ return lhs.T@rhs
2469
+
2470
+ def _manhattan_similarity(lhs, rhs):
2471
+ return np.abs(lhs - rhs).sum(axis=0, keepdims=True)
2472
+
2473
+ def _euclidean_similarity(lhs, rhs):
2474
+ return np.sqrt(np.sum((lhs - rhs)**2, axis=0, keepdims=True))
2475
+
2476
+ def _minkowski_similarity(lhs, rhs):
2477
+ p = kwargs.get('p', 3)
2478
+ return np.sum(np.abs(lhs - rhs)**p, axis=0, keepdims=True)**(1/p)
2479
+
2480
+ def _jaccard_similarity(lhs, rhs):
2481
+ intersection = np.minimum(lhs, rhs)
2482
+ union = np.maximum(lhs, rhs)
2483
+ return np.sum(intersection, axis=0, keepdims=True) / (np.sum(union, axis=0, keepdims=True) + eps)
2484
+
2485
+ metric_handlers = {
2486
+ 'cosine': _cosine_similarity,
2487
+ 'angular-cosine': _angular_cosine_similarity,
2488
+ 'product': _product_similarity,
2489
+ 'manhattan': _manhattan_similarity,
2490
+ 'euclidean': _euclidean_similarity,
2491
+ 'minkowski': _minkowski_similarity,
2492
+ 'jaccard': _jaccard_similarity,
2493
+ }
2494
+
2495
+ handler = metric_handlers.get(metric)
2496
+ if handler is None:
2497
+ msg = (
2498
+ f"Similarity metric {metric} not implemented. Available metrics: "
2499
+ "'cosine', 'angular-cosine', 'product', 'manhattan', 'euclidean', 'minkowski', 'jaccard'"
2500
+ )
2501
+ UserMessage(msg)
2502
+ raise NotImplementedError(msg)
2503
+ return handler
2504
+
2505
+ def _get_kernel_handler(self, kernel):
2506
+ kernel_handlers = {
2507
+ 'gaussian': self._kernel_gaussian,
2508
+ 'rbf': self._kernel_rbf,
2509
+ 'laplacian': self._kernel_laplacian,
2510
+ 'polynomial': self._kernel_polynomial,
2511
+ 'sigmoid': self._kernel_sigmoid,
2512
+ 'linear': self._kernel_linear,
2513
+ 'cauchy': self._kernel_cauchy,
2514
+ 't-distribution': self._kernel_t_distribution,
2515
+ 'inverse-multiquadric': self._kernel_inverse_multiquadric,
2516
+ 'cosine': self._kernel_cosine,
2517
+ 'angular-cosine': self._kernel_angular_cosine,
2518
+ 'frechet': self._kernel_frechet,
2519
+ 'mmd': self._kernel_mmd,
2520
+ }
2521
+
2522
+ handler = kernel_handlers.get(kernel)
2523
+ if handler is None:
2524
+ msg = "Kernel function {kernel} not implemented. Available functions: 'gaussian'"
2525
+ UserMessage(msg.format(kernel=kernel))
2526
+ raise NotImplementedError(msg.format(kernel=kernel))
2527
+ return handler
2528
+
2529
+ def _kernel_gaussian(self, lhs, rhs, _eps, kwargs):
2530
+ gamma = kwargs.get('gamma', 1)
2531
+ return np.exp(-gamma * np.sum((lhs - rhs)**2, axis=0))
2532
+
2533
+ def _kernel_rbf(self, lhs, rhs, _eps, kwargs):
2534
+ bandwidth = kwargs.get('bandwidth')
2535
+ gamma = kwargs.get('gamma', 1)
2536
+ distance_sq = np.sum((lhs - rhs)**2, axis=0)
2537
+ if bandwidth is not None:
2538
+ val = 0
2539
+ for a in bandwidth:
2540
+ gamma = 1.0 / (2 * a)
2541
+ val += np.exp(-gamma * distance_sq)
2542
+ return val
2543
+ return np.exp(-gamma * distance_sq)
2544
+
2545
+ def _kernel_laplacian(self, lhs, rhs, _eps, kwargs):
2546
+ gamma = kwargs.get('gamma', 1)
2547
+ return np.exp(-gamma * np.sum(np.abs(lhs - rhs), axis=0))
2548
+
2549
+ def _kernel_polynomial(self, lhs, rhs, _eps, kwargs):
2550
+ gamma = kwargs.get('gamma', 1)
2551
+ degree = kwargs.get('degree', 3)
2552
+ coef = kwargs.get('coef', 1)
2553
+ return (gamma * np.sum((lhs * rhs), axis=0) + coef)**degree
2554
+
2555
+ def _kernel_sigmoid(self, lhs, rhs, _eps, kwargs):
2556
+ gamma = kwargs.get('gamma', 1)
2557
+ coef = kwargs.get('coef', 1)
2558
+ return np.tanh(gamma * np.sum((lhs * rhs), axis=0) + coef)
2559
+
2560
+ def _kernel_linear(self, lhs, rhs, _eps, _kwargs):
2561
+ return np.sum((lhs * rhs), axis=0)
2562
+
2563
+ def _kernel_cauchy(self, lhs, rhs, _eps, kwargs):
2564
+ gamma = kwargs.get('gamma', 1)
2565
+ return 1 / (1 + np.sum((lhs - rhs)**2, axis=0) / gamma)
2566
+
2567
+ def _kernel_t_distribution(self, lhs, rhs, _eps, kwargs):
2568
+ gamma = kwargs.get('gamma', 1)
2569
+ degree = kwargs.get('degree', 1)
2570
+ return 1 / (1 + (np.sum((lhs - rhs)**2, axis=0) / (gamma * degree))**(degree + 1) / 2)
2571
+
2572
+ def _kernel_inverse_multiquadric(self, lhs, rhs, _eps, kwargs):
2573
+ gamma = kwargs.get('gamma', 1)
2574
+ return 1 / np.sqrt(np.sum((lhs - rhs)**2, axis=0) / gamma**2 + 1)
2575
+
2576
+ def _kernel_cosine(self, lhs, rhs, eps, _kwargs):
2577
+ numerator = np.sum(lhs * rhs, axis=0)
2578
+ denominator = np.sqrt(np.sum(lhs**2, axis=0)) * np.sqrt(np.sum(rhs**2, axis=0)) + eps
2579
+ return 1 - (numerator / denominator)
2580
+
2581
+ def _kernel_angular_cosine(self, lhs, rhs, eps, kwargs):
2582
+ c = kwargs.get('c', 1)
2583
+ numerator = np.sum(lhs * rhs, axis=0)
2584
+ denominator = np.sqrt(np.sum(lhs**2, axis=0)) * np.sqrt(np.sum(rhs**2, axis=0)) + eps
2585
+ return c * np.arccos(numerator / denominator) / np.pi
2586
+
2587
+ def _kernel_frechet(self, lhs, rhs, eps, kwargs):
2588
+ sigma1 = kwargs.get('sigma1')
2589
+ sigma2 = kwargs.get('sigma2')
2590
+ assert sigma1 is not None and sigma2 is not None, 'Frechet distance requires covariance matrices for both inputs'
2591
+ return calculate_frechet_distance(lhs.T, sigma1, rhs.T, sigma2, eps)
2592
+
2593
+ def _kernel_mmd(self, lhs, rhs, eps, _kwargs):
2594
+ return calculate_mmd(lhs.T, rhs.T, eps=eps)
2595
+
2596
+ def similarity(
2597
+ self,
2598
+ other: Union['Symbol', list, np.ndarray, torch.Tensor],
2599
+ metric: Literal['cosine', 'angular-cosine', 'product', 'manhattan', 'euclidean', 'minkowski', 'jaccard'] = 'cosine',
2600
+ eps: float = 1e-8,
2601
+ normalize: Callable | None = None,
2602
+ **kwargs,
2603
+ ) -> float:
2410
2604
  '''
2411
2605
  Calculates the similarity between two Symbol objects using a specified metric.
2412
2606
  This method compares the values of two Symbol objects and calculates their similarity according to the specified metric.
@@ -2427,45 +2621,29 @@ class EmbeddingPrimitives(Primitive):
2427
2621
  NotImplementedError: If the given metric is not supported.
2428
2622
  '''
2429
2623
  v = self._ensure_numpy_format(self)
2430
- if isinstance(other, list) or isinstance(other, tuple):
2431
- o = []
2432
- for i in range(len(other)):
2433
- o.append(self._ensure_numpy_format(other[i], cast=True))
2434
- o = np.concatenate(o, axis=1)
2435
- else:
2436
- o = self._ensure_numpy_format(other, cast=True)
2437
-
2438
- if metric == 'cosine':
2439
- val = v.T@o / (np.sqrt(v.T@v) * np.sqrt(o.T@o) + eps)
2440
- elif metric == 'angular-cosine':
2441
- c = kwargs.get('c', 1)
2442
- val = 1 - (c * np.arccos((v.T@o / (np.sqrt(v.T@v) * np.sqrt(o.T@o) + eps))) / np.pi)
2443
- elif metric == 'product':
2444
- val = v.T@o
2445
- elif metric == 'manhattan':
2446
- val = np.abs(v - o).sum(axis=0, keepdims=True)
2447
- elif metric == 'euclidean':
2448
- val = np.sqrt(np.sum((v - o)**2, axis=0, keepdims=True))
2449
- elif metric == 'minkowski':
2450
- p = kwargs.get('p', 3)
2451
- val = np.sum(np.abs(v - o)**p, axis=0, keepdims=True)**(1/p)
2452
- elif metric == 'jaccard':
2453
- intersection = np.minimum(v, o)
2454
- union = np.maximum(v, o)
2455
- val = np.sum(intersection, axis=0, keepdims=True) / (np.sum(union, axis=0, keepdims=True) + eps)
2456
- else:
2457
- raise NotImplementedError(f"Similarity metric {metric} not implemented. Available metrics: 'cosine', 'angular-cosine', 'product', 'manhattan', 'euclidean', 'minkowski', 'jaccard'")
2624
+ o = self._prepare_embedding_operand(other)
2625
+ handler = self._get_similarity_handler(metric, eps, kwargs)
2626
+ val = handler(v, o)
2458
2627
 
2459
2628
  # get the similarity value(s)
2460
2629
  shape = val.shape
2461
- if len(shape) >= 2 and min(shape) > 1: val = val.diagonal()
2462
- elif len(shape) >= 1 and shape[0] > 1: val = val
2463
- else: val = val.item()
2464
- if normalize is not None: val = normalize(val)
2630
+ if len(shape) >= 2 and min(shape) > 1:
2631
+ val = val.diagonal()
2632
+ elif len(shape) < 1 or shape[0] <= 1:
2633
+ val = val.item()
2634
+ if normalize is not None:
2635
+ val = normalize(val)
2465
2636
 
2466
2637
  return val
2467
2638
 
2468
- def distance(self, other: Union['Symbol', list, np.ndarray, torch.Tensor], kernel: Union['gaussian', 'rbf', 'laplacian', 'polynomial', 'sigmoid', 'linear', 'cauchy', 't-distribution', 'inverse-multiquadric', 'cosine', 'angular-cosine', 'frechet', 'mmd'] = 'gaussian', eps: float = 1e-8, normalize: Optional[Callable] = None, **kwargs) -> float:
2639
+ def distance(
2640
+ self,
2641
+ other: Union['Symbol', list, np.ndarray, torch.Tensor],
2642
+ kernel: Literal['gaussian', 'rbf', 'laplacian', 'polynomial', 'sigmoid', 'linear', 'cauchy', 't-distribution', 'inverse-multiquadric', 'cosine', 'angular-cosine', 'frechet', 'mmd'] = 'gaussian',
2643
+ eps: float = 1e-8,
2644
+ normalize: Callable | None = None,
2645
+ **kwargs,
2646
+ ) -> float:
2469
2647
  '''
2470
2648
  Calculates the kernel between two Symbol objects.
2471
2649
 
@@ -2483,82 +2661,18 @@ class EmbeddingPrimitives(Primitive):
2483
2661
  NotImplementedError: If the given kernel is not supported.
2484
2662
  '''
2485
2663
  v = self._ensure_numpy_format(self)
2486
- if isinstance(other, list) or isinstance(other, tuple):
2487
- o = []
2488
- for i in range(len(other)):
2489
- o.append(self._ensure_numpy_format(other[i], cast=True))
2490
- o = np.concatenate(o, axis=1)
2491
- else:
2492
- o = self._ensure_numpy_format(other, cast=True)
2493
-
2494
- # compute the kernel value
2495
- if kernel == 'gaussian':
2496
- gamma = kwargs.get('gamma', 1)
2497
- val = np.exp(-gamma * np.sum((v - o)**2, axis=0))
2498
- elif kernel == 'rbf':
2499
- # vectors are expected to be normalized
2500
- bandwidth = kwargs.get('bandwidth', None)
2501
- gamma = kwargs.get('gamma', 1)
2502
- d = np.sum((v - o)**2, axis=0)
2503
- if bandwidth is not None:
2504
- val = 0
2505
- for a in bandwidth:
2506
- gamma = 1.0 / (2 * a)
2507
- val += np.exp(-gamma * d)
2508
- else:
2509
- # if no bandwidth is given, default to the gaussian kernel
2510
- val = np.exp(-gamma * d)
2511
- elif kernel == 'laplacian':
2512
- gamma = kwargs.get('gamma', 1)
2513
- val = np.exp(-gamma * np.sum(np.abs(v - o), axis=0))
2514
- elif kernel == 'polynomial':
2515
- gamma = kwargs.get('gamma', 1)
2516
- degree = kwargs.get('degree', 3)
2517
- coef = kwargs.get('coef', 1)
2518
- val = (gamma * np.sum((v * o), axis=0) + coef)**degree
2519
- elif kernel == 'sigmoid':
2520
- gamma = kwargs.get('gamma', 1)
2521
- coef = kwargs.get('coef', 1)
2522
- val = np.tanh(gamma * np.sum((v * o), axis=0) + coef)
2523
- elif kernel == 'linear':
2524
- val = np.sum((v * o), axis=0)
2525
- elif kernel == 'cauchy':
2526
- gamma = kwargs.get('gamma', 1)
2527
- val = 1 / (1 + np.sum((v - o)**2, axis=0) / gamma)
2528
- elif kernel == 't-distribution':
2529
- gamma = kwargs.get('gamma', 1)
2530
- degree = kwargs.get('degree', 1)
2531
- val = 1 / (1 + (np.sum((v - o)**2, axis=0) / (gamma * degree))**(degree + 1) / 2)
2532
- elif kernel == 'inverse-multiquadric':
2533
- gamma = kwargs.get('gamma', 1)
2534
- val = 1 / np.sqrt(np.sum((v - o)**2, axis=0) / gamma**2 + 1)
2535
- elif kernel == 'cosine':
2536
- val = 1 - (np.sum(v * o, axis=0) / (np.sqrt(np.sum(v**2, axis=0)) * np.sqrt(np.sum(o**2, axis=0)) + eps))
2537
- elif kernel == 'angular-cosine':
2538
- c = kwargs.get('c', 1)
2539
- val = c * np.arccos((np.sum(v * o, axis=0) / (np.sqrt(np.sum(v**2, axis=0)) * np.sqrt(np.sum(o**2, axis=0)) + eps))) / np.pi
2540
- elif kernel == 'frechet':
2541
- sigma1 = kwargs.get('sigma1', None)
2542
- sigma2 = kwargs.get('sigma2', None)
2543
- assert sigma1 is not None and sigma2 is not None, 'Frechet distance requires covariance matrices for both inputs'
2544
- v = v.T
2545
- o = o.T
2546
- val = calculate_frechet_distance(v, sigma1, o, sigma2, eps)
2547
- elif kernel == 'mmd':
2548
- v = v.T
2549
- o = o.T
2550
- val = calculate_mmd(v, o, eps=eps)
2551
- else:
2552
- raise NotImplementedError(f"Kernel function {kernel} not implemented. Available functions: 'gaussian'")
2664
+ o = self._prepare_embedding_operand(other)
2665
+ handler = self._get_kernel_handler(kernel)
2666
+ val = handler(v, o, eps, kwargs)
2553
2667
 
2554
2668
  # get the kernel value(s)
2555
2669
  shape = val.shape
2556
- if len(shape) >= 1 and shape[0] > 1: val = val
2557
- else: val = val.item()
2558
- if normalize is not None: val = normalize(val)
2670
+ val = val if len(shape) >= 1 and shape[0] > 1 else val.item()
2671
+ if normalize is not None:
2672
+ val = normalize(val)
2559
2673
  return val
2560
2674
 
2561
- def zip(self, **kwargs) -> List[Tuple[str, List, Dict]]:
2675
+ def zip(self, **kwargs) -> list[tuple[str, list, dict]]:
2562
2676
  '''
2563
2677
  Zips the Symbol's value with its embeddings and a query containing the value.
2564
2678
  This method zips the Symbol's value along with its embeddings and a query containing the value.
@@ -2577,19 +2691,21 @@ class EmbeddingPrimitives(Primitive):
2577
2691
  elif isinstance(self.value, list):
2578
2692
  pass
2579
2693
  else:
2580
- raise ValueError(f'Expected id to be a string, got {type(self.value)}')
2694
+ msg = f'Expected id to be a string, got {type(self.value)}'
2695
+ UserMessage(msg)
2696
+ raise ValueError(msg)
2581
2697
 
2582
2698
  embeds = self.embed(**kwargs).value
2583
2699
  idx = [str(uuid.uuid4()) for _ in range(len(self.value))]
2584
2700
  query = [{'text': str(self.value[i])} for i in range(len(self.value))]
2585
2701
 
2586
2702
  # convert embeds to list if it is a tensor or numpy array
2587
- if type(embeds) == np.ndarray:
2703
+ if isinstance(embeds, np.ndarray):
2588
2704
  embeds = embeds.tolist()
2589
- elif type(embeds) == torch.Tensor:
2705
+ elif isinstance(embeds, torch.Tensor):
2590
2706
  embeds = embeds.cpu().numpy().tolist()
2591
2707
 
2592
- return list(zip(idx, embeds, query))
2708
+ return list(zip(idx, embeds, query, strict=False))
2593
2709
 
2594
2710
 
2595
2711
  class IOHandlingPrimitives(Primitive):
@@ -2630,7 +2746,7 @@ class IOHandlingPrimitives(Primitive):
2630
2746
  return self.sym_return_type(self.value if condition else '') | res
2631
2747
  return self._to_type(self.value if condition else '') | self._to_type(res)
2632
2748
 
2633
- def open(self, path: str = None, **kwargs) -> 'Symbol':
2749
+ def open(self, path: str | None = None, **kwargs) -> 'Symbol':
2634
2750
  '''
2635
2751
  Open a file and store its content in an Expression object as a string.
2636
2752
 
@@ -2654,7 +2770,9 @@ class IOHandlingPrimitives(Primitive):
2654
2770
 
2655
2771
  path = path if path is not None else self.value
2656
2772
  if path is None:
2657
- raise ValueError('Path is not provided; either provide a path or set the value of the Symbol to the path')
2773
+ msg = 'Path is not provided; either provide a path or set the value of the Symbol to the path'
2774
+ UserMessage(msg)
2775
+ raise ValueError(msg)
2658
2776
 
2659
2777
  @core.opening(path=path, **kwargs)
2660
2778
  def _func(_):
@@ -2726,7 +2844,7 @@ class PersistencePrimitives(Primitive):
2726
2844
 
2727
2845
  return func_name
2728
2846
 
2729
- def save(self, path: str, replace: Optional[bool] = False, serialize: Optional[bool] = True) -> None:
2847
+ def save(self, path: str, replace: bool | None = False, serialize: bool | None = True) -> None:
2730
2848
  '''
2731
2849
  Save the current Symbol to a file.
2732
2850
 
@@ -2738,22 +2856,24 @@ class PersistencePrimitives(Primitive):
2738
2856
  Returns:
2739
2857
  Symbol: The current Symbol.
2740
2858
  '''
2741
- file_path = path
2859
+ file_path = Path(path)
2742
2860
 
2743
2861
  if not replace:
2744
2862
  cnt = 0
2745
- while os.path.exists(file_path):
2746
- filename, file_extension = os.path.splitext(path)
2747
- file_path = f'{filename}_{cnt}{file_extension}'
2863
+ candidate = file_path
2864
+ while candidate.exists():
2865
+ candidate = candidate.with_name(f'{file_path.stem}_{cnt}{file_path.suffix}')
2748
2866
  cnt += 1
2867
+ file_path = candidate
2749
2868
 
2750
2869
  if serialize:
2751
2870
  # serialize the object via pickle instead of writing the string
2752
- path_ = str(file_path) + '.pkl' if not str(file_path).endswith('.pkl') else str(file_path)
2753
- with open(path_, 'wb') as f:
2871
+ path_str = str(file_path)
2872
+ pickle_path = Path(path_str if path_str.endswith('.pkl') else f'{path_str}.pkl')
2873
+ with pickle_path.open('wb') as f:
2754
2874
  pickle.dump(self, file=f)
2755
2875
  else:
2756
- with open(str(file_path), 'w') as f:
2876
+ with file_path.open('w') as f:
2757
2877
  f.write(str(self))
2758
2878
 
2759
2879
  def load(self, path: str) -> Any:
@@ -2766,9 +2886,8 @@ class PersistencePrimitives(Primitive):
2766
2886
  Returns:
2767
2887
  Symbol: The loaded Symbol.
2768
2888
  '''
2769
- with open(path, 'rb') as f:
2770
- obj = pickle.load(f)
2771
- return obj
2889
+ with Path(path).open('rb') as f:
2890
+ return pickle.load(f)
2772
2891
 
2773
2892
 
2774
2893
  class OutputHandlingPrimitives(Primitive):
@@ -2789,7 +2908,7 @@ class OutputHandlingPrimitives(Primitive):
2789
2908
  Symbol: The resulting Symbol after the output operation.
2790
2909
  '''
2791
2910
  @core.output(**kwargs)
2792
- def _func(_, *func_args, **func_kwargs):
2911
+ def _func(_, *_func_args, **_func_kwargs):
2793
2912
  return self.value
2794
2913
 
2795
2914
  return self._to_type(_func(self, self.value, *args))
@@ -2833,12 +2952,13 @@ class FineTuningPrimitives(Primitive):
2833
2952
  # return tensor
2834
2953
  return self._metadata.data
2835
2954
  # if the data is a numpy array, convert it to tensor
2836
- elif isinstance(self._metadata.data, np.ndarray):
2955
+ if isinstance(self._metadata.data, np.ndarray):
2837
2956
  # convert to tensor
2838
2957
  self._metadata.data = torch.from_numpy(self._metadata.data)
2839
2958
  return self._metadata.data
2840
- else:
2841
- raise TypeError(f'Expected data to be a tensor or numpy array, got {type(self._metadata.data)}')
2959
+ msg = f'Expected data to be a tensor or numpy array, got {type(self._metadata.data)}'
2960
+ UserMessage(msg)
2961
+ raise TypeError(msg)
2842
2962
 
2843
2963
  @data.setter
2844
2964
  def data(self, data: torch.Tensor) -> None: