symbolicai 0.20.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symai/__init__.py +96 -64
- symai/backend/base.py +93 -80
- symai/backend/engines/drawing/engine_bfl.py +12 -11
- symai/backend/engines/drawing/engine_gpt_image.py +108 -87
- symai/backend/engines/embedding/engine_llama_cpp.py +25 -28
- symai/backend/engines/embedding/engine_openai.py +3 -5
- symai/backend/engines/execute/engine_python.py +6 -5
- symai/backend/engines/files/engine_io.py +74 -67
- symai/backend/engines/imagecaptioning/engine_blip2.py +3 -3
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +54 -38
- symai/backend/engines/index/engine_pinecone.py +23 -24
- symai/backend/engines/index/engine_vectordb.py +16 -14
- symai/backend/engines/lean/engine_lean4.py +38 -34
- symai/backend/engines/neurosymbolic/__init__.py +41 -13
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +262 -182
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +263 -191
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +53 -49
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +212 -211
- symai/backend/engines/neurosymbolic/engine_groq.py +87 -63
- symai/backend/engines/neurosymbolic/engine_huggingface.py +21 -24
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +117 -48
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +256 -229
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +270 -150
- symai/backend/engines/ocr/engine_apilayer.py +6 -8
- symai/backend/engines/output/engine_stdout.py +1 -4
- symai/backend/engines/search/engine_openai.py +7 -7
- symai/backend/engines/search/engine_perplexity.py +5 -5
- symai/backend/engines/search/engine_serpapi.py +12 -14
- symai/backend/engines/speech_to_text/engine_local_whisper.py +20 -27
- symai/backend/engines/symbolic/engine_wolframalpha.py +3 -3
- symai/backend/engines/text_to_speech/engine_openai.py +5 -7
- symai/backend/engines/text_vision/engine_clip.py +7 -11
- symai/backend/engines/userinput/engine_console.py +3 -3
- symai/backend/engines/webscraping/engine_requests.py +81 -48
- symai/backend/mixin/__init__.py +13 -0
- symai/backend/mixin/anthropic.py +4 -2
- symai/backend/mixin/deepseek.py +2 -0
- symai/backend/mixin/google.py +2 -0
- symai/backend/mixin/openai.py +11 -3
- symai/backend/settings.py +83 -16
- symai/chat.py +101 -78
- symai/collect/__init__.py +7 -1
- symai/collect/dynamic.py +77 -69
- symai/collect/pipeline.py +35 -27
- symai/collect/stats.py +75 -63
- symai/components.py +198 -169
- symai/constraints.py +15 -12
- symai/core.py +698 -359
- symai/core_ext.py +32 -34
- symai/endpoints/api.py +80 -73
- symai/extended/.DS_Store +0 -0
- symai/extended/__init__.py +46 -12
- symai/extended/api_builder.py +11 -8
- symai/extended/arxiv_pdf_parser.py +13 -12
- symai/extended/bibtex_parser.py +2 -3
- symai/extended/conversation.py +101 -90
- symai/extended/document.py +17 -10
- symai/extended/file_merger.py +18 -13
- symai/extended/graph.py +18 -13
- symai/extended/html_style_template.py +2 -4
- symai/extended/interfaces/blip_2.py +1 -2
- symai/extended/interfaces/clip.py +1 -2
- symai/extended/interfaces/console.py +7 -1
- symai/extended/interfaces/dall_e.py +1 -1
- symai/extended/interfaces/flux.py +1 -1
- symai/extended/interfaces/gpt_image.py +1 -1
- symai/extended/interfaces/input.py +1 -1
- symai/extended/interfaces/llava.py +0 -1
- symai/extended/interfaces/naive_vectordb.py +7 -8
- symai/extended/interfaces/naive_webscraping.py +1 -1
- symai/extended/interfaces/ocr.py +1 -1
- symai/extended/interfaces/pinecone.py +6 -5
- symai/extended/interfaces/serpapi.py +1 -1
- symai/extended/interfaces/terminal.py +2 -3
- symai/extended/interfaces/tts.py +1 -1
- symai/extended/interfaces/whisper.py +1 -1
- symai/extended/interfaces/wolframalpha.py +1 -1
- symai/extended/metrics/__init__.py +11 -1
- symai/extended/metrics/similarity.py +11 -13
- symai/extended/os_command.py +17 -16
- symai/extended/packages/__init__.py +29 -3
- symai/extended/packages/symdev.py +19 -16
- symai/extended/packages/sympkg.py +12 -9
- symai/extended/packages/symrun.py +21 -19
- symai/extended/repo_cloner.py +11 -10
- symai/extended/seo_query_optimizer.py +1 -2
- symai/extended/solver.py +20 -23
- symai/extended/summarizer.py +4 -3
- symai/extended/taypan_interpreter.py +10 -12
- symai/extended/vectordb.py +99 -82
- symai/formatter/__init__.py +9 -1
- symai/formatter/formatter.py +12 -16
- symai/formatter/regex.py +62 -63
- symai/functional.py +176 -122
- symai/imports.py +136 -127
- symai/interfaces.py +56 -27
- symai/memory.py +14 -13
- symai/misc/console.py +49 -39
- symai/misc/loader.py +5 -3
- symai/models/__init__.py +17 -1
- symai/models/base.py +269 -181
- symai/models/errors.py +0 -1
- symai/ops/__init__.py +32 -22
- symai/ops/measures.py +11 -15
- symai/ops/primitives.py +348 -228
- symai/post_processors.py +32 -28
- symai/pre_processors.py +39 -41
- symai/processor.py +6 -4
- symai/prompts.py +59 -45
- symai/server/huggingface_server.py +23 -20
- symai/server/llama_cpp_server.py +7 -5
- symai/shell.py +3 -4
- symai/shellsv.py +499 -375
- symai/strategy.py +517 -287
- symai/symbol.py +111 -116
- symai/utils.py +42 -36
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/METADATA +4 -2
- symbolicai-1.0.0.dist-info/RECORD +163 -0
- symbolicai-0.20.2.dist-info/RECORD +0 -162
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/WHEEL +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/top_level.txt +0 -0
symai/ops/primitives.py
CHANGED
|
@@ -1,19 +1,18 @@
|
|
|
1
1
|
import ast
|
|
2
2
|
import json
|
|
3
|
-
import
|
|
3
|
+
import numbers
|
|
4
4
|
import pickle
|
|
5
5
|
import uuid
|
|
6
|
-
from
|
|
7
|
-
|
|
6
|
+
from collections.abc import Callable, Iterable
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Literal, Union
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
10
11
|
import torch
|
|
11
|
-
from pydantic import ValidationError
|
|
12
12
|
|
|
13
13
|
from .. import core, core_ext
|
|
14
|
-
from ..models import CustomConstraint, LengthConstraint, LLMDataModel
|
|
15
14
|
from ..prompts import Prompt
|
|
16
|
-
from ..utils import
|
|
15
|
+
from ..utils import UserMessage
|
|
17
16
|
from .measures import calculate_frechet_distance, calculate_mmd
|
|
18
17
|
|
|
19
18
|
if TYPE_CHECKING:
|
|
@@ -41,20 +40,21 @@ class Primitive:
|
|
|
41
40
|
|
|
42
41
|
|
|
43
42
|
class OperatorPrimitives(Primitive):
|
|
44
|
-
|
|
43
|
+
__hash__ = None
|
|
44
|
+
|
|
45
|
+
def __try_type_specific_func(self, other, func, op: str | None = None):
|
|
45
46
|
if not isinstance(other, self._symbol_type):
|
|
46
47
|
other = self._to_type(other)
|
|
47
48
|
# None shortcut
|
|
48
|
-
if not self.__disable_none_shortcut__:
|
|
49
|
-
|
|
50
|
-
CustomUserWarning(f"unsupported {self._symbol_type.__class__} value operand type(s) for {op}: '{type(self.value)}' and '{type(other.value)}'", raise_with=TypeError)
|
|
49
|
+
if not self.__disable_none_shortcut__ and (self.value is None or other.value is None):
|
|
50
|
+
UserMessage(f"unsupported {self._symbol_type.__class__} value operand type(s) for {op}: '{type(self.value)}' and '{type(other.value)}'", raise_with=TypeError)
|
|
51
51
|
# try type specific function
|
|
52
52
|
try:
|
|
53
53
|
# try type specific function
|
|
54
54
|
value = func(self, other)
|
|
55
55
|
if value is NotImplemented:
|
|
56
56
|
operation = '' if op is None else op
|
|
57
|
-
|
|
57
|
+
UserMessage(f"unsupported {self._symbol_type.__class__} value operand type(s) for {operation}: '{type(self.value)}' and '{type(other.value)}'", raise_with=TypeError)
|
|
58
58
|
return value
|
|
59
59
|
except Exception as ex:
|
|
60
60
|
self._metadata._error = ex
|
|
@@ -66,7 +66,7 @@ class OperatorPrimitives(Primitive):
|
|
|
66
66
|
This function raises an error if the neuro-symbolic engine is disabled.
|
|
67
67
|
'''
|
|
68
68
|
if self.__disable_nesy_engine__:
|
|
69
|
-
|
|
69
|
+
UserMessage(f"unsupported {self.__class__} value operand type(s) for {func.__name__}: '{type(self.value)}'", raise_with=TypeError)
|
|
70
70
|
|
|
71
71
|
def __bool__(self) -> bool:
|
|
72
72
|
'''
|
|
@@ -80,7 +80,7 @@ class OperatorPrimitives(Primitive):
|
|
|
80
80
|
if isinstance(self.value, bool):
|
|
81
81
|
val = self.value
|
|
82
82
|
elif self.value is not None:
|
|
83
|
-
val =
|
|
83
|
+
val = bool(self.value)
|
|
84
84
|
|
|
85
85
|
return val
|
|
86
86
|
|
|
@@ -680,8 +680,8 @@ class OperatorPrimitives(Primitive):
|
|
|
680
680
|
Returns:
|
|
681
681
|
Symbol: A new symbol with the result of the OR operation.
|
|
682
682
|
'''
|
|
683
|
-
#
|
|
684
|
-
from ..collect.stats import Aggregator
|
|
683
|
+
# Exclude the evaluation for the Aggregator class; keep import local to avoid ops.primitives <-> collect.stats cycle.
|
|
684
|
+
from ..collect.stats import Aggregator # noqa
|
|
685
685
|
if isinstance(other, Aggregator):
|
|
686
686
|
return NotImplemented
|
|
687
687
|
|
|
@@ -709,8 +709,8 @@ class OperatorPrimitives(Primitive):
|
|
|
709
709
|
Returns:
|
|
710
710
|
Symbol: A new Symbol object with the concatenated value.
|
|
711
711
|
'''
|
|
712
|
-
#
|
|
713
|
-
from ..collect.stats import Aggregator
|
|
712
|
+
# Exclude the evaluation for the Aggregator class; keep import local to avoid ops.primitives <-> collect.stats cycle.
|
|
713
|
+
from ..collect.stats import Aggregator # noqa
|
|
714
714
|
if isinstance(other, Aggregator):
|
|
715
715
|
return NotImplemented
|
|
716
716
|
|
|
@@ -739,8 +739,8 @@ class OperatorPrimitives(Primitive):
|
|
|
739
739
|
Returns:
|
|
740
740
|
Symbol: A new Symbol object with the concatenated value.
|
|
741
741
|
'''
|
|
742
|
-
#
|
|
743
|
-
from ..collect.stats import Aggregator
|
|
742
|
+
# Exclude the evaluation for the Aggregator class; keep import local to avoid ops.primitives <-> collect.stats cycle.
|
|
743
|
+
from ..collect.stats import Aggregator # noqa
|
|
744
744
|
if isinstance(other, Aggregator):
|
|
745
745
|
return NotImplemented
|
|
746
746
|
|
|
@@ -848,11 +848,12 @@ class OperatorPrimitives(Primitive):
|
|
|
848
848
|
Returns:
|
|
849
849
|
Symbol: A new Symbol object with the concatenated value.
|
|
850
850
|
'''
|
|
851
|
-
if isinstance(self.value, str) and isinstance(other, str) or \
|
|
852
|
-
isinstance(self.value, str) and isinstance(other, self._symbol_type) and isinstance(other.value, str):
|
|
851
|
+
if (isinstance(self.value, str) and isinstance(other, str)) or \
|
|
852
|
+
(isinstance(self.value, str) and isinstance(other, self._symbol_type) and isinstance(other.value, str)):
|
|
853
853
|
other = self._to_type(other)
|
|
854
854
|
return self._to_type(f'{self.value}{other.value}')
|
|
855
|
-
|
|
855
|
+
UserMessage(f'This method is only supported for string concatenation! Got {type(self.value)} and {type(other)} instead.', raise_with=TypeError)
|
|
856
|
+
return None
|
|
856
857
|
|
|
857
858
|
def __rmatmul__(self, other: Any) -> 'Symbol':
|
|
858
859
|
'''
|
|
@@ -917,7 +918,8 @@ class OperatorPrimitives(Primitive):
|
|
|
917
918
|
if result is not None:
|
|
918
919
|
return self._to_type(result)
|
|
919
920
|
|
|
920
|
-
|
|
921
|
+
UserMessage('Division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
922
|
+
return None
|
|
921
923
|
|
|
922
924
|
|
|
923
925
|
def __itruediv__(self, other: Any) -> 'Symbol':
|
|
@@ -936,7 +938,8 @@ class OperatorPrimitives(Primitive):
|
|
|
936
938
|
self._value = result
|
|
937
939
|
return self
|
|
938
940
|
|
|
939
|
-
|
|
941
|
+
UserMessage('Division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
942
|
+
return None
|
|
940
943
|
|
|
941
944
|
|
|
942
945
|
def __floordiv__(self, other: Any) -> 'Symbol':
|
|
@@ -954,7 +957,8 @@ class OperatorPrimitives(Primitive):
|
|
|
954
957
|
if result is not None:
|
|
955
958
|
return self._to_type(result)
|
|
956
959
|
|
|
957
|
-
|
|
960
|
+
UserMessage('Floor division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
961
|
+
return None
|
|
958
962
|
|
|
959
963
|
def __rfloordiv__(self, other: Any) -> 'Symbol':
|
|
960
964
|
'''
|
|
@@ -971,7 +975,8 @@ class OperatorPrimitives(Primitive):
|
|
|
971
975
|
if result is not None:
|
|
972
976
|
return self._to_type(result)
|
|
973
977
|
|
|
974
|
-
|
|
978
|
+
UserMessage('Floor division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
979
|
+
return None
|
|
975
980
|
|
|
976
981
|
def __ifloordiv__(self, other: Any) -> 'Symbol':
|
|
977
982
|
'''
|
|
@@ -989,7 +994,8 @@ class OperatorPrimitives(Primitive):
|
|
|
989
994
|
self._value = result
|
|
990
995
|
return self
|
|
991
996
|
|
|
992
|
-
|
|
997
|
+
UserMessage('Floor division operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
998
|
+
return None
|
|
993
999
|
|
|
994
1000
|
def __pow__(self, other: Any) -> 'Symbol':
|
|
995
1001
|
'''
|
|
@@ -1006,7 +1012,8 @@ class OperatorPrimitives(Primitive):
|
|
|
1006
1012
|
if result is not None:
|
|
1007
1013
|
return self._to_type(result)
|
|
1008
1014
|
|
|
1009
|
-
|
|
1015
|
+
UserMessage('Power operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
1016
|
+
return None
|
|
1010
1017
|
|
|
1011
1018
|
|
|
1012
1019
|
def __rpow__(self, other: Any) -> 'Symbol':
|
|
@@ -1024,7 +1031,8 @@ class OperatorPrimitives(Primitive):
|
|
|
1024
1031
|
if result is not None:
|
|
1025
1032
|
return self._to_type(result)
|
|
1026
1033
|
|
|
1027
|
-
|
|
1034
|
+
UserMessage('Power operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
1035
|
+
return None
|
|
1028
1036
|
|
|
1029
1037
|
def __ipow__(self, other: Any) -> 'Symbol':
|
|
1030
1038
|
'''
|
|
@@ -1042,7 +1050,8 @@ class OperatorPrimitives(Primitive):
|
|
|
1042
1050
|
self._value = result
|
|
1043
1051
|
return self
|
|
1044
1052
|
|
|
1045
|
-
|
|
1053
|
+
UserMessage('Power operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
1054
|
+
return None
|
|
1046
1055
|
|
|
1047
1056
|
def __mod__(self, other: Any) -> 'Symbol':
|
|
1048
1057
|
'''
|
|
@@ -1059,7 +1068,8 @@ class OperatorPrimitives(Primitive):
|
|
|
1059
1068
|
if result is not None:
|
|
1060
1069
|
return self._to_type(result)
|
|
1061
1070
|
|
|
1062
|
-
|
|
1071
|
+
UserMessage('Modulo operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
1072
|
+
return None
|
|
1063
1073
|
|
|
1064
1074
|
def __rmod__(self, other: Any) -> 'Symbol':
|
|
1065
1075
|
'''
|
|
@@ -1076,7 +1086,9 @@ class OperatorPrimitives(Primitive):
|
|
|
1076
1086
|
if result is not None:
|
|
1077
1087
|
return self._to_type(result)
|
|
1078
1088
|
|
|
1079
|
-
|
|
1089
|
+
msg = 'Modulo operation not supported! Might change in the future.'
|
|
1090
|
+
UserMessage(msg)
|
|
1091
|
+
raise NotImplementedError(msg) from self._metadata._error
|
|
1080
1092
|
|
|
1081
1093
|
def __imod__(self, other: Any) -> 'Symbol':
|
|
1082
1094
|
'''
|
|
@@ -1094,7 +1106,8 @@ class OperatorPrimitives(Primitive):
|
|
|
1094
1106
|
self._value = result
|
|
1095
1107
|
return self
|
|
1096
1108
|
|
|
1097
|
-
|
|
1109
|
+
UserMessage('Modulo operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
1110
|
+
return None
|
|
1098
1111
|
|
|
1099
1112
|
def __mul__(self, other: Any) -> 'Symbol':
|
|
1100
1113
|
'''
|
|
@@ -1111,7 +1124,8 @@ class OperatorPrimitives(Primitive):
|
|
|
1111
1124
|
if result is not None:
|
|
1112
1125
|
return self._to_type(result)
|
|
1113
1126
|
|
|
1114
|
-
|
|
1127
|
+
UserMessage('Multiply operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
1128
|
+
return None
|
|
1115
1129
|
|
|
1116
1130
|
def __rmul__(self, other: Any) -> 'Symbol':
|
|
1117
1131
|
'''
|
|
@@ -1128,7 +1142,8 @@ class OperatorPrimitives(Primitive):
|
|
|
1128
1142
|
if result is not None:
|
|
1129
1143
|
return self._to_type(result)
|
|
1130
1144
|
|
|
1131
|
-
|
|
1145
|
+
UserMessage('Multiply operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
1146
|
+
return None
|
|
1132
1147
|
|
|
1133
1148
|
def __imul__(self, other: Any) -> 'Symbol':
|
|
1134
1149
|
'''
|
|
@@ -1146,7 +1161,8 @@ class OperatorPrimitives(Primitive):
|
|
|
1146
1161
|
self._value = result
|
|
1147
1162
|
return self
|
|
1148
1163
|
|
|
1149
|
-
|
|
1164
|
+
UserMessage('Multiply operation not supported semantically! Might change in the future.', raise_with=NotImplementedError)
|
|
1165
|
+
return None
|
|
1150
1166
|
|
|
1151
1167
|
|
|
1152
1168
|
class CastingPrimitives(Primitive):
|
|
@@ -1165,14 +1181,14 @@ class CastingPrimitives(Primitive):
|
|
|
1165
1181
|
@property
|
|
1166
1182
|
def sem(self) -> "Symbol":
|
|
1167
1183
|
"""
|
|
1168
|
-
Return a semantic view of this Symbol.
|
|
1184
|
+
Return a semantic view of this Symbol.
|
|
1169
1185
|
(Useful after calling `.syn` in a chain.)
|
|
1170
1186
|
"""
|
|
1171
1187
|
if getattr(self, "__semantic__", False):
|
|
1172
1188
|
return self
|
|
1173
1189
|
return self._to_type(self.value, semantic=True)
|
|
1174
1190
|
|
|
1175
|
-
def cast(self, as_type:
|
|
1191
|
+
def cast(self, as_type: type) -> Any:
|
|
1176
1192
|
'''
|
|
1177
1193
|
Cast the Symbol's value to a specific type.
|
|
1178
1194
|
|
|
@@ -1184,7 +1200,7 @@ class CastingPrimitives(Primitive):
|
|
|
1184
1200
|
'''
|
|
1185
1201
|
return as_type(self.value)
|
|
1186
1202
|
|
|
1187
|
-
def to(self, as_type:
|
|
1203
|
+
def to(self, as_type: type) -> Any:
|
|
1188
1204
|
'''
|
|
1189
1205
|
Cast the Symbol's value to a specific type.
|
|
1190
1206
|
|
|
@@ -1250,7 +1266,7 @@ class IterationPrimitives(Primitive):
|
|
|
1250
1266
|
This mixin contains functions that perform iteration operations on symbols or symbol values.
|
|
1251
1267
|
The functions in this mixin are bound to the 'neurosymbolic' engine for evaluation.
|
|
1252
1268
|
'''
|
|
1253
|
-
def __getitem__(self, key:
|
|
1269
|
+
def __getitem__(self, key: str | int | slice) -> 'Symbol':
|
|
1254
1270
|
'''
|
|
1255
1271
|
Get the item of the Symbol value with the specified key or index.
|
|
1256
1272
|
If the Symbol value is a list, tuple, or numpy array, the key can be an integer or slice.
|
|
@@ -1270,7 +1286,7 @@ class IterationPrimitives(Primitive):
|
|
|
1270
1286
|
try:
|
|
1271
1287
|
return self.value[key]
|
|
1272
1288
|
except Exception:
|
|
1273
|
-
|
|
1289
|
+
UserMessage(f'Key {key} not found in {self.value}', raise_with=Exception)
|
|
1274
1290
|
|
|
1275
1291
|
@core.getitem()
|
|
1276
1292
|
def _func(_, index: str):
|
|
@@ -1278,7 +1294,7 @@ class IterationPrimitives(Primitive):
|
|
|
1278
1294
|
|
|
1279
1295
|
return self._to_type(_func(self, key))
|
|
1280
1296
|
|
|
1281
|
-
def __setitem__(self, key:
|
|
1297
|
+
def __setitem__(self, key: str | int | slice, value: Any) -> None:
|
|
1282
1298
|
'''
|
|
1283
1299
|
Set the item of the Symbol value with the specified key or index to the given value.
|
|
1284
1300
|
If the Symbol value is a list, the key can be an integer or slice.
|
|
@@ -1292,17 +1308,18 @@ class IterationPrimitives(Primitive):
|
|
|
1292
1308
|
Raises:
|
|
1293
1309
|
KeyError: If the key or index is not found in the Symbol value.
|
|
1294
1310
|
'''
|
|
1295
|
-
|
|
1311
|
+
# Local import avoids ops.primitives -> post_processors -> symbol -> ops circular load.
|
|
1312
|
+
from ..post_processors import ASTPostProcessor # noqa
|
|
1296
1313
|
|
|
1297
1314
|
if not isinstance(self.value, (str, dict, list)):
|
|
1298
|
-
|
|
1315
|
+
UserMessage(f'Setting item is not supported for {type(self.value)}. Supported types are str, dict, and list.', raise_with=TypeError)
|
|
1299
1316
|
|
|
1300
1317
|
if not self.__semantic__:
|
|
1301
1318
|
try:
|
|
1302
1319
|
self.value[key] = value
|
|
1303
1320
|
return
|
|
1304
1321
|
except Exception:
|
|
1305
|
-
|
|
1322
|
+
UserMessage(f'Key {key} not found in {self.value}', raise_with=Exception)
|
|
1306
1323
|
|
|
1307
1324
|
@core.setitem()
|
|
1308
1325
|
def _func(_, index: str, value: str):
|
|
@@ -1314,7 +1331,7 @@ class IterationPrimitives(Primitive):
|
|
|
1314
1331
|
except Exception:
|
|
1315
1332
|
self._value = result # It was a string, or something failed (because } wasn't close, etc)
|
|
1316
1333
|
|
|
1317
|
-
def __delitem__(self, key:
|
|
1334
|
+
def __delitem__(self, key: str | int) -> None:
|
|
1318
1335
|
'''
|
|
1319
1336
|
Delete the item of the Symbol value with the specified key or index.
|
|
1320
1337
|
If the Symbol value is a dictionary, the key can be a string or an integer.
|
|
@@ -1326,17 +1343,18 @@ class IterationPrimitives(Primitive):
|
|
|
1326
1343
|
Raises:
|
|
1327
1344
|
KeyError: If the key or index is not found in the Symbol value.
|
|
1328
1345
|
'''
|
|
1329
|
-
|
|
1346
|
+
# Local import avoids ops.primitives -> post_processors -> symbol -> ops circular load.
|
|
1347
|
+
from ..post_processors import ASTPostProcessor # noqa
|
|
1330
1348
|
|
|
1331
1349
|
if not isinstance(self.value, (str, dict, list)):
|
|
1332
|
-
|
|
1350
|
+
UserMessage(f'Setting item is not supported for {type(self.value)}. Supported types are str, dict, and list.', raise_with=TypeError)
|
|
1333
1351
|
|
|
1334
1352
|
if not self.__semantic__:
|
|
1335
1353
|
try:
|
|
1336
1354
|
del self.value[key]
|
|
1337
1355
|
return
|
|
1338
1356
|
except Exception:
|
|
1339
|
-
|
|
1357
|
+
UserMessage(f'Key {key} not found in {self.value}', raise_with=Exception)
|
|
1340
1358
|
|
|
1341
1359
|
@core.delitem()
|
|
1342
1360
|
def _func(_, index: str):
|
|
@@ -1427,7 +1445,7 @@ class StringHelperPrimitives(Primitive):
|
|
|
1427
1445
|
'''
|
|
1428
1446
|
This mixin contains functions that provide additional help for symbols or their values.
|
|
1429
1447
|
'''
|
|
1430
|
-
def split(self, delimiter: str, **
|
|
1448
|
+
def split(self, delimiter: str, **_kwargs) -> 'Symbol':
|
|
1431
1449
|
'''
|
|
1432
1450
|
Splits the symbol value by a specified delimiter.
|
|
1433
1451
|
Uses the core.split decorator to create a _func method that splits the symbol value by the specified delimiter.
|
|
@@ -1442,7 +1460,7 @@ class StringHelperPrimitives(Primitive):
|
|
|
1442
1460
|
assert isinstance(self.value, str), f'self.value must be a string, got {type(self.value)}'
|
|
1443
1461
|
return self._to_type([*self.value.split(delimiter)])
|
|
1444
1462
|
|
|
1445
|
-
def join(self, delimiter: str = ' ', **
|
|
1463
|
+
def join(self, delimiter: str = ' ', **_kwargs) -> 'Symbol':
|
|
1446
1464
|
'''
|
|
1447
1465
|
Joins the symbol value with a specified delimiter.
|
|
1448
1466
|
|
|
@@ -1456,7 +1474,7 @@ class StringHelperPrimitives(Primitive):
|
|
|
1456
1474
|
assert isinstance(self.value, Iterable), f'value must be an iterable, got {type(self.value)}'
|
|
1457
1475
|
return self._to_type(delimiter.join(self.value))
|
|
1458
1476
|
|
|
1459
|
-
def startswith(self, prefix: str, **
|
|
1477
|
+
def startswith(self, prefix: str, **_kwargs) -> bool:
|
|
1460
1478
|
'''
|
|
1461
1479
|
Checks if the symbol value starts with a specified prefix.
|
|
1462
1480
|
Uses the core.startswith decorator to create a _func method that checks if the symbol value starts with the specified prefix.
|
|
@@ -1479,7 +1497,7 @@ class StringHelperPrimitives(Primitive):
|
|
|
1479
1497
|
|
|
1480
1498
|
return _func(self, prefix)
|
|
1481
1499
|
|
|
1482
|
-
def endswith(self, suffix: str, **
|
|
1500
|
+
def endswith(self, suffix: str, **_kwargs) -> bool:
|
|
1483
1501
|
'''
|
|
1484
1502
|
Checks if the symbol value ends with a specified suffix.
|
|
1485
1503
|
Uses the core.endswith decorator to create a _func method that checks if the symbol value ends with the specified suffix.
|
|
@@ -1574,7 +1592,7 @@ class ExpressionHandlingPrimitives(Primitive):
|
|
|
1574
1592
|
if not hasattr(self, '_accumulated_results'):
|
|
1575
1593
|
self._accumulated_results = []
|
|
1576
1594
|
|
|
1577
|
-
def get_results(self) ->
|
|
1595
|
+
def get_results(self) -> list['Symbol']:
|
|
1578
1596
|
'''
|
|
1579
1597
|
Retrieves accumulated results from previous interpretations.
|
|
1580
1598
|
|
|
@@ -1589,7 +1607,7 @@ class ExpressionHandlingPrimitives(Primitive):
|
|
|
1589
1607
|
self.init_results()
|
|
1590
1608
|
self._accumulated_results = []
|
|
1591
1609
|
|
|
1592
|
-
def interpret(self, prompt:
|
|
1610
|
+
def interpret(self, prompt: str | None = "Evaluate the symbolic expressions and return only the result:\n", accumulate: bool = False, **kwargs) -> 'Symbol':
|
|
1593
1611
|
'''
|
|
1594
1612
|
Evaluates simple symbolic expressions.
|
|
1595
1613
|
Uses the core.expression decorator to create a _func method that evaluates the given expression.
|
|
@@ -1639,7 +1657,7 @@ class DataHandlingPrimitives(Primitive):
|
|
|
1639
1657
|
|
|
1640
1658
|
return self._to_type(_func(self))
|
|
1641
1659
|
|
|
1642
|
-
def summarize(self, context:
|
|
1660
|
+
def summarize(self, context: str | None = None, **kwargs) -> 'Symbol':
|
|
1643
1661
|
'''
|
|
1644
1662
|
Summarizes the symbol value.
|
|
1645
1663
|
Uses the core.summarize decorator with an optional context to create a _func method that summarizes the symbol value.
|
|
@@ -1670,7 +1688,7 @@ class DataHandlingPrimitives(Primitive):
|
|
|
1670
1688
|
|
|
1671
1689
|
return self._to_type(_func(self))
|
|
1672
1690
|
|
|
1673
|
-
def filter(self, criteria: str, include:
|
|
1691
|
+
def filter(self, criteria: str, include: bool | None = False, **kwargs) -> 'Symbol':
|
|
1674
1692
|
'''
|
|
1675
1693
|
Filters the symbol value based on a specified criteria.
|
|
1676
1694
|
Uses the core.filtering decorator with the provided criteria and include flag to create a _func method to filter the symbol value.
|
|
@@ -1707,7 +1725,7 @@ class DataHandlingPrimitives(Primitive):
|
|
|
1707
1725
|
try:
|
|
1708
1726
|
iter(self.value)
|
|
1709
1727
|
except TypeError:
|
|
1710
|
-
|
|
1728
|
+
UserMessage('Map can only be applied to iterable objects', raise_with=AssertionError)
|
|
1711
1729
|
|
|
1712
1730
|
@core.map(instruction=instruction, **kwargs)
|
|
1713
1731
|
def _func(_):
|
|
@@ -1807,7 +1825,7 @@ class UniquenessPrimitives(Primitive):
|
|
|
1807
1825
|
This mixin includes functions that work with unique aspects of symbol values, like extracting unique information or composing new unique symbols.
|
|
1808
1826
|
Future functionalities might include finding duplicate information, defining levels of uniqueness, etc.
|
|
1809
1827
|
'''
|
|
1810
|
-
def unique(self, keys:
|
|
1828
|
+
def unique(self, keys: list[str] | None = None, **kwargs) -> 'Symbol':
|
|
1811
1829
|
'''
|
|
1812
1830
|
Extracts unique information from the symbol value, using provided keys.
|
|
1813
1831
|
Uses the core.unique decorator with a list of keys to create a _func method that extracts unique data from the symbol value.
|
|
@@ -1818,6 +1836,8 @@ class UniquenessPrimitives(Primitive):
|
|
|
1818
1836
|
Returns:
|
|
1819
1837
|
Symbol: A new symbol with the unique information.
|
|
1820
1838
|
'''
|
|
1839
|
+
if keys is None:
|
|
1840
|
+
keys = []
|
|
1821
1841
|
@core.unique(keys=keys, **kwargs)
|
|
1822
1842
|
def _func(_) -> str:
|
|
1823
1843
|
pass
|
|
@@ -1844,7 +1864,7 @@ class PatternMatchingPrimitives(Primitive):
|
|
|
1844
1864
|
This mixin houses functions that deal with ranking symbols, extracting details based on patterns, and correcting symbols.
|
|
1845
1865
|
It will house future functionalities that involve sorting, complex pattern detections, advanced correction techniques etc.
|
|
1846
1866
|
'''
|
|
1847
|
-
def rank(self, measure:
|
|
1867
|
+
def rank(self, measure: str | None = 'alphanumeric', order: str | None = 'desc', **kwargs) -> 'Symbol':
|
|
1848
1868
|
'''
|
|
1849
1869
|
Ranks the symbol value based on a measure and a sort order.
|
|
1850
1870
|
Uses the core.rank decorator with the specified measure and order to create a _func method that ranks the symbol value.
|
|
@@ -1899,7 +1919,7 @@ class PatternMatchingPrimitives(Primitive):
|
|
|
1899
1919
|
|
|
1900
1920
|
return self._to_type(_func(self))
|
|
1901
1921
|
|
|
1902
|
-
def translate(self, language:
|
|
1922
|
+
def translate(self, language: str | None = 'English', **kwargs) -> 'Symbol':
|
|
1903
1923
|
'''
|
|
1904
1924
|
Translates the symbol value to the specified language.
|
|
1905
1925
|
Uses the @core.translate decorator to translate the symbol's value to the specified language.
|
|
@@ -1917,7 +1937,7 @@ class PatternMatchingPrimitives(Primitive):
|
|
|
1917
1937
|
|
|
1918
1938
|
return self._to_type(_func(self))
|
|
1919
1939
|
|
|
1920
|
-
def choice(self, cases:
|
|
1940
|
+
def choice(self, cases: list[str], default: str, **kwargs) -> 'Symbol':
|
|
1921
1941
|
'''
|
|
1922
1942
|
Chooses one of the cases based on the symbol value.
|
|
1923
1943
|
Uses the @core.case decorator, selects one of the cases based on the symbol's value.
|
|
@@ -1942,7 +1962,7 @@ class QueryHandlingPrimitives(Primitive):
|
|
|
1942
1962
|
This mixin helps in transforming, preparing, and executing queries, and it is designed to be extendable as new ways of handling queries are developed.
|
|
1943
1963
|
Future methods could potentially include query optimization, enhanced query formatting, multi-level query execution, query error handling, etc.
|
|
1944
1964
|
'''
|
|
1945
|
-
def query(self, context: str, prompt:
|
|
1965
|
+
def query(self, context: str, prompt: str | None = None, examples: list[Prompt] | None = None, **kwargs) -> 'Symbol':
|
|
1946
1966
|
'''
|
|
1947
1967
|
Queries the symbol value based on a specified context.
|
|
1948
1968
|
Uses the @core.query decorator, queries based on the context, prompt, and examples.
|
|
@@ -2004,7 +2024,7 @@ class ExecutionControlPrimitives(Primitive):
|
|
|
2004
2024
|
This mixin represents the core methods for dealing with symbol execution.
|
|
2005
2025
|
Possible future methods could potentially include async execution, pipeline chaining, execution profiling, improved error handling, version management, embedding more complex execution control structures etc.
|
|
2006
2026
|
'''
|
|
2007
|
-
def analyze(self, exception: Exception, query:
|
|
2027
|
+
def analyze(self, exception: Exception, query: str | None = '', **kwargs) -> 'Symbol':
|
|
2008
2028
|
'''Uses the @core.analyze decorator, analyzes an exception and returns a symbol.
|
|
2009
2029
|
|
|
2010
2030
|
Args:
|
|
@@ -2120,7 +2140,7 @@ class ExecutionControlPrimitives(Primitive):
|
|
|
2120
2140
|
|
|
2121
2141
|
return self._to_type(_func(self))
|
|
2122
2142
|
|
|
2123
|
-
def stream(self, expr: 'Expression', token_ratio:
|
|
2143
|
+
def stream(self, expr: 'Expression', token_ratio: float | None = 0.6, **kwargs) -> 'Symbol':
|
|
2124
2144
|
'''
|
|
2125
2145
|
Streams the Symbol's value through an Expression object.
|
|
2126
2146
|
This method divides the Symbol's value into chunks and processes each chunk through the given Expression object.
|
|
@@ -2156,7 +2176,7 @@ class ExecutionControlPrimitives(Primitive):
|
|
|
2156
2176
|
else:
|
|
2157
2177
|
yield expr(self, **kwargs)
|
|
2158
2178
|
|
|
2159
|
-
def ftry(self, expr: 'Expression', retries:
|
|
2179
|
+
def ftry(self, expr: 'Expression', retries: int | None = 1, **kwargs) -> 'Symbol':
|
|
2160
2180
|
# TODO: find a way to pass on the constraints and behavior from the self.expr to the corrected code
|
|
2161
2181
|
'''
|
|
2162
2182
|
Tries to evaluate a Symbol using a given Expression.
|
|
@@ -2198,33 +2218,32 @@ class ExecutionControlPrimitives(Primitive):
|
|
|
2198
2218
|
retry_cnt += 1
|
|
2199
2219
|
if retry_cnt > retries:
|
|
2200
2220
|
raise e
|
|
2221
|
+
# analyze the error
|
|
2222
|
+
payload = f'[ORIGINAL_USER_PROMPT]\n{prompt["prompt_instruction"]}\n\n' if 'prompt_instruction' in prompt else ''
|
|
2223
|
+
payload = payload + f'[ORIGINAL_USER_DATA]\n{code}\n\n[ORIGINAL_GENERATED_OUTPUT]\n{prompt["out_msg"]}'
|
|
2224
|
+
probe = sym.analyze(query="What is the issue in this expression?", payload=payload, exception=e)
|
|
2225
|
+
# attempt to correct the error
|
|
2226
|
+
payload = f'[ORIGINAL_USER_PROMPT]\n{prompt["prompt_instruction"]}\n\n' if 'prompt_instruction' in prompt else ''
|
|
2227
|
+
payload = payload + f'[ANALYSIS]\n{probe}\n\n'
|
|
2228
|
+
context = f'Try to correct the error of the original user request based on the analysis above: \n [GENERATED_OUTPUT]\n{prompt["out_msg"]}\n\n'
|
|
2229
|
+
constraints = expr.constraints if hasattr(expr, 'constraints') else []
|
|
2230
|
+
|
|
2231
|
+
if hasattr(expr, 'post_processor'):
|
|
2232
|
+
post_processor = expr.post_processor
|
|
2233
|
+
sym = code.correct(
|
|
2234
|
+
context=context,
|
|
2235
|
+
exception=e,
|
|
2236
|
+
payload=payload,
|
|
2237
|
+
constraints=constraints,
|
|
2238
|
+
post_processor=post_processor
|
|
2239
|
+
)
|
|
2201
2240
|
else:
|
|
2202
|
-
|
|
2203
|
-
|
|
2204
|
-
|
|
2205
|
-
|
|
2206
|
-
|
|
2207
|
-
|
|
2208
|
-
payload = payload + f'[ANALYSIS]\n{probe}\n\n'
|
|
2209
|
-
context = f'Try to correct the error of the original user request based on the analysis above: \n [GENERATED_OUTPUT]\n{prompt["out_msg"]}\n\n'
|
|
2210
|
-
constraints = expr.constraints if hasattr(expr, 'constraints') else []
|
|
2211
|
-
|
|
2212
|
-
if hasattr(expr, 'post_processor'):
|
|
2213
|
-
post_processor = expr.post_processor
|
|
2214
|
-
sym = code.correct(
|
|
2215
|
-
context=context,
|
|
2216
|
-
exception=e,
|
|
2217
|
-
payload=payload,
|
|
2218
|
-
constraints=constraints,
|
|
2219
|
-
post_processor=post_processor
|
|
2220
|
-
)
|
|
2221
|
-
else:
|
|
2222
|
-
sym = code.correct(
|
|
2223
|
-
context=context,
|
|
2224
|
-
exception=e,
|
|
2225
|
-
payload=payload,
|
|
2226
|
-
constraints=constraints
|
|
2227
|
-
)
|
|
2241
|
+
sym = code.correct(
|
|
2242
|
+
context=context,
|
|
2243
|
+
exception=e,
|
|
2244
|
+
payload=payload,
|
|
2245
|
+
constraints=constraints
|
|
2246
|
+
)
|
|
2228
2247
|
|
|
2229
2248
|
|
|
2230
2249
|
class DictHandlingPrimitives(Primitive):
|
|
@@ -2257,7 +2276,7 @@ class TemplateStylingPrimitives(Primitive):
|
|
|
2257
2276
|
This mixin includes functionalities for stylizing symbols and applying templates.
|
|
2258
2277
|
Future functionalities might include a variety of new stylizing methods, application of more complex templates, etc.
|
|
2259
2278
|
'''
|
|
2260
|
-
def template(that, template: str, placeholder:
|
|
2279
|
+
def template(that, template: str, placeholder: str | None = '{{placeholder}}', **_kwargs) -> 'Symbol':
|
|
2261
2280
|
'''
|
|
2262
2281
|
Applies a template to the Symbol.
|
|
2263
2282
|
This method uses the @core.template decorator to apply the given template and placeholder to the Symbol.
|
|
@@ -2277,7 +2296,7 @@ class TemplateStylingPrimitives(Primitive):
|
|
|
2277
2296
|
|
|
2278
2297
|
return _func(that)
|
|
2279
2298
|
|
|
2280
|
-
def style(self, description: str, libraries:
|
|
2299
|
+
def style(self, description: str, libraries: list | None = None, **kwargs) -> 'Symbol':
|
|
2281
2300
|
'''
|
|
2282
2301
|
Applies a style to the Symbol.
|
|
2283
2302
|
This method uses the @core.style decorator to apply the given style description, libraries, and placeholder to the Symbol.
|
|
@@ -2291,6 +2310,8 @@ class TemplateStylingPrimitives(Primitive):
|
|
|
2291
2310
|
Returns:
|
|
2292
2311
|
Symbol: A Symbol object with the style applied.
|
|
2293
2312
|
'''
|
|
2313
|
+
if libraries is None:
|
|
2314
|
+
libraries = []
|
|
2294
2315
|
@core.style(description=description, libraries=libraries, **kwargs)
|
|
2295
2316
|
def _func(_):
|
|
2296
2317
|
pass
|
|
@@ -2362,16 +2383,17 @@ class EmbeddingPrimitives(Primitive):
|
|
|
2362
2383
|
'''
|
|
2363
2384
|
# if the embedding is not yet computed, compute it
|
|
2364
2385
|
if self._metadata.embedding is None:
|
|
2365
|
-
if (
|
|
2386
|
+
if (isinstance(self.value, (list, tuple)) and all(isinstance(x, (int, float, bool)) for x in self.value)) \
|
|
2366
2387
|
or isinstance(self.value, np.ndarray):
|
|
2367
|
-
if isinstance(self.value, list
|
|
2388
|
+
if isinstance(self.value, (list, tuple)):
|
|
2368
2389
|
assert len(self.value) > 0, 'Cannot compute embedding of empty list'
|
|
2369
|
-
|
|
2390
|
+
symbol_type = self._symbol_type
|
|
2391
|
+
if isinstance(self.value[0], symbol_type):
|
|
2370
2392
|
# convert each element to numpy array
|
|
2371
2393
|
self._metadata.embedding = np.asarray([x.embedding for x in self.value])
|
|
2372
2394
|
elif isinstance(self.value[0], str):
|
|
2373
2395
|
# embed each string
|
|
2374
|
-
self._metadata.embedding = np.asarray([
|
|
2396
|
+
self._metadata.embedding = np.asarray([symbol_type(x).embedding for x in self.value])
|
|
2375
2397
|
else:
|
|
2376
2398
|
# convert to numpy array
|
|
2377
2399
|
self._metadata.embedding = np.asarray(self.value)
|
|
@@ -2390,23 +2412,195 @@ class EmbeddingPrimitives(Primitive):
|
|
|
2390
2412
|
|
|
2391
2413
|
def _ensure_numpy_format(self, x, cast=False):
|
|
2392
2414
|
# if it is a Symbol, get its value
|
|
2393
|
-
if not isinstance(x, np.ndarray
|
|
2415
|
+
if not isinstance(x, (np.ndarray, torch.Tensor, list)):
|
|
2394
2416
|
if not isinstance(x, self._symbol_type): #@NOTE: enforce Symbol to avoid circular import
|
|
2395
2417
|
if not cast:
|
|
2396
|
-
|
|
2418
|
+
msg = f'Cannot compute similarity with type {type(x)}'
|
|
2419
|
+
UserMessage(msg)
|
|
2420
|
+
raise TypeError(msg)
|
|
2397
2421
|
x = self._symbol_type(x)
|
|
2398
2422
|
# evaluate the Symbol as an embedding
|
|
2399
2423
|
x = x.embedding
|
|
2400
2424
|
# if it is a list, convert it to numpy
|
|
2401
|
-
if isinstance(x, list
|
|
2425
|
+
if isinstance(x, (list, tuple)):
|
|
2402
2426
|
assert len(x) > 0, 'Cannot compute similarity with empty list'
|
|
2403
2427
|
x = np.asarray(x)
|
|
2404
2428
|
# if it is a tensor, convert it to numpy
|
|
2405
2429
|
elif isinstance(x, torch.Tensor):
|
|
2406
2430
|
x = x.detach().cpu().numpy()
|
|
2407
|
-
|
|
2431
|
+
else:
|
|
2432
|
+
x = np.asarray(x)
|
|
2408
2433
|
|
|
2409
|
-
|
|
2434
|
+
x = np.squeeze(x)
|
|
2435
|
+
if x.ndim == 0:
|
|
2436
|
+
x = x[None]
|
|
2437
|
+
|
|
2438
|
+
return x[:, None]
|
|
2439
|
+
|
|
2440
|
+
def _prepare_embedding_operand(self, operand):
|
|
2441
|
+
if isinstance(operand, (list, tuple)):
|
|
2442
|
+
if self._is_numeric_sequence(operand):
|
|
2443
|
+
return self._ensure_numpy_format(operand, cast=True)
|
|
2444
|
+
formatted = [
|
|
2445
|
+
self._ensure_numpy_format(item, cast=True) for item in operand
|
|
2446
|
+
]
|
|
2447
|
+
return np.concatenate(formatted, axis=1)
|
|
2448
|
+
return self._ensure_numpy_format(operand, cast=True)
|
|
2449
|
+
|
|
2450
|
+
def _is_numeric_sequence(self, operand: Iterable):
|
|
2451
|
+
for item in operand:
|
|
2452
|
+
if isinstance(item, (list, tuple, np.ndarray, torch.Tensor, self._symbol_type)):
|
|
2453
|
+
return False
|
|
2454
|
+
if isinstance(item, (numbers.Real, np.generic)):
|
|
2455
|
+
continue
|
|
2456
|
+
return False
|
|
2457
|
+
return True
|
|
2458
|
+
|
|
2459
|
+
def _get_similarity_handler(self, metric, eps, kwargs):
|
|
2460
|
+
def _cosine_similarity(lhs, rhs):
|
|
2461
|
+
return lhs.T@rhs / (np.sqrt(lhs.T@lhs) * np.sqrt(rhs.T@rhs) + eps)
|
|
2462
|
+
|
|
2463
|
+
def _angular_cosine_similarity(lhs, rhs):
|
|
2464
|
+
c = kwargs.get('c', 1)
|
|
2465
|
+
return 1 - (c * np.arccos(lhs.T@rhs / (np.sqrt(lhs.T@lhs) * np.sqrt(rhs.T@rhs) + eps)) / np.pi)
|
|
2466
|
+
|
|
2467
|
+
def _product_similarity(lhs, rhs):
|
|
2468
|
+
return lhs.T@rhs
|
|
2469
|
+
|
|
2470
|
+
def _manhattan_similarity(lhs, rhs):
|
|
2471
|
+
return np.abs(lhs - rhs).sum(axis=0, keepdims=True)
|
|
2472
|
+
|
|
2473
|
+
def _euclidean_similarity(lhs, rhs):
|
|
2474
|
+
return np.sqrt(np.sum((lhs - rhs)**2, axis=0, keepdims=True))
|
|
2475
|
+
|
|
2476
|
+
def _minkowski_similarity(lhs, rhs):
|
|
2477
|
+
p = kwargs.get('p', 3)
|
|
2478
|
+
return np.sum(np.abs(lhs - rhs)**p, axis=0, keepdims=True)**(1/p)
|
|
2479
|
+
|
|
2480
|
+
def _jaccard_similarity(lhs, rhs):
|
|
2481
|
+
intersection = np.minimum(lhs, rhs)
|
|
2482
|
+
union = np.maximum(lhs, rhs)
|
|
2483
|
+
return np.sum(intersection, axis=0, keepdims=True) / (np.sum(union, axis=0, keepdims=True) + eps)
|
|
2484
|
+
|
|
2485
|
+
metric_handlers = {
|
|
2486
|
+
'cosine': _cosine_similarity,
|
|
2487
|
+
'angular-cosine': _angular_cosine_similarity,
|
|
2488
|
+
'product': _product_similarity,
|
|
2489
|
+
'manhattan': _manhattan_similarity,
|
|
2490
|
+
'euclidean': _euclidean_similarity,
|
|
2491
|
+
'minkowski': _minkowski_similarity,
|
|
2492
|
+
'jaccard': _jaccard_similarity,
|
|
2493
|
+
}
|
|
2494
|
+
|
|
2495
|
+
handler = metric_handlers.get(metric)
|
|
2496
|
+
if handler is None:
|
|
2497
|
+
msg = (
|
|
2498
|
+
f"Similarity metric {metric} not implemented. Available metrics: "
|
|
2499
|
+
"'cosine', 'angular-cosine', 'product', 'manhattan', 'euclidean', 'minkowski', 'jaccard'"
|
|
2500
|
+
)
|
|
2501
|
+
UserMessage(msg)
|
|
2502
|
+
raise NotImplementedError(msg)
|
|
2503
|
+
return handler
|
|
2504
|
+
|
|
2505
|
+
def _get_kernel_handler(self, kernel):
|
|
2506
|
+
kernel_handlers = {
|
|
2507
|
+
'gaussian': self._kernel_gaussian,
|
|
2508
|
+
'rbf': self._kernel_rbf,
|
|
2509
|
+
'laplacian': self._kernel_laplacian,
|
|
2510
|
+
'polynomial': self._kernel_polynomial,
|
|
2511
|
+
'sigmoid': self._kernel_sigmoid,
|
|
2512
|
+
'linear': self._kernel_linear,
|
|
2513
|
+
'cauchy': self._kernel_cauchy,
|
|
2514
|
+
't-distribution': self._kernel_t_distribution,
|
|
2515
|
+
'inverse-multiquadric': self._kernel_inverse_multiquadric,
|
|
2516
|
+
'cosine': self._kernel_cosine,
|
|
2517
|
+
'angular-cosine': self._kernel_angular_cosine,
|
|
2518
|
+
'frechet': self._kernel_frechet,
|
|
2519
|
+
'mmd': self._kernel_mmd,
|
|
2520
|
+
}
|
|
2521
|
+
|
|
2522
|
+
handler = kernel_handlers.get(kernel)
|
|
2523
|
+
if handler is None:
|
|
2524
|
+
msg = "Kernel function {kernel} not implemented. Available functions: 'gaussian'"
|
|
2525
|
+
UserMessage(msg.format(kernel=kernel))
|
|
2526
|
+
raise NotImplementedError(msg.format(kernel=kernel))
|
|
2527
|
+
return handler
|
|
2528
|
+
|
|
2529
|
+
def _kernel_gaussian(self, lhs, rhs, _eps, kwargs):
|
|
2530
|
+
gamma = kwargs.get('gamma', 1)
|
|
2531
|
+
return np.exp(-gamma * np.sum((lhs - rhs)**2, axis=0))
|
|
2532
|
+
|
|
2533
|
+
def _kernel_rbf(self, lhs, rhs, _eps, kwargs):
|
|
2534
|
+
bandwidth = kwargs.get('bandwidth')
|
|
2535
|
+
gamma = kwargs.get('gamma', 1)
|
|
2536
|
+
distance_sq = np.sum((lhs - rhs)**2, axis=0)
|
|
2537
|
+
if bandwidth is not None:
|
|
2538
|
+
val = 0
|
|
2539
|
+
for a in bandwidth:
|
|
2540
|
+
gamma = 1.0 / (2 * a)
|
|
2541
|
+
val += np.exp(-gamma * distance_sq)
|
|
2542
|
+
return val
|
|
2543
|
+
return np.exp(-gamma * distance_sq)
|
|
2544
|
+
|
|
2545
|
+
def _kernel_laplacian(self, lhs, rhs, _eps, kwargs):
|
|
2546
|
+
gamma = kwargs.get('gamma', 1)
|
|
2547
|
+
return np.exp(-gamma * np.sum(np.abs(lhs - rhs), axis=0))
|
|
2548
|
+
|
|
2549
|
+
def _kernel_polynomial(self, lhs, rhs, _eps, kwargs):
|
|
2550
|
+
gamma = kwargs.get('gamma', 1)
|
|
2551
|
+
degree = kwargs.get('degree', 3)
|
|
2552
|
+
coef = kwargs.get('coef', 1)
|
|
2553
|
+
return (gamma * np.sum((lhs * rhs), axis=0) + coef)**degree
|
|
2554
|
+
|
|
2555
|
+
def _kernel_sigmoid(self, lhs, rhs, _eps, kwargs):
|
|
2556
|
+
gamma = kwargs.get('gamma', 1)
|
|
2557
|
+
coef = kwargs.get('coef', 1)
|
|
2558
|
+
return np.tanh(gamma * np.sum((lhs * rhs), axis=0) + coef)
|
|
2559
|
+
|
|
2560
|
+
def _kernel_linear(self, lhs, rhs, _eps, _kwargs):
|
|
2561
|
+
return np.sum((lhs * rhs), axis=0)
|
|
2562
|
+
|
|
2563
|
+
def _kernel_cauchy(self, lhs, rhs, _eps, kwargs):
|
|
2564
|
+
gamma = kwargs.get('gamma', 1)
|
|
2565
|
+
return 1 / (1 + np.sum((lhs - rhs)**2, axis=0) / gamma)
|
|
2566
|
+
|
|
2567
|
+
def _kernel_t_distribution(self, lhs, rhs, _eps, kwargs):
|
|
2568
|
+
gamma = kwargs.get('gamma', 1)
|
|
2569
|
+
degree = kwargs.get('degree', 1)
|
|
2570
|
+
return 1 / (1 + (np.sum((lhs - rhs)**2, axis=0) / (gamma * degree))**(degree + 1) / 2)
|
|
2571
|
+
|
|
2572
|
+
def _kernel_inverse_multiquadric(self, lhs, rhs, _eps, kwargs):
|
|
2573
|
+
gamma = kwargs.get('gamma', 1)
|
|
2574
|
+
return 1 / np.sqrt(np.sum((lhs - rhs)**2, axis=0) / gamma**2 + 1)
|
|
2575
|
+
|
|
2576
|
+
def _kernel_cosine(self, lhs, rhs, eps, _kwargs):
|
|
2577
|
+
numerator = np.sum(lhs * rhs, axis=0)
|
|
2578
|
+
denominator = np.sqrt(np.sum(lhs**2, axis=0)) * np.sqrt(np.sum(rhs**2, axis=0)) + eps
|
|
2579
|
+
return 1 - (numerator / denominator)
|
|
2580
|
+
|
|
2581
|
+
def _kernel_angular_cosine(self, lhs, rhs, eps, kwargs):
|
|
2582
|
+
c = kwargs.get('c', 1)
|
|
2583
|
+
numerator = np.sum(lhs * rhs, axis=0)
|
|
2584
|
+
denominator = np.sqrt(np.sum(lhs**2, axis=0)) * np.sqrt(np.sum(rhs**2, axis=0)) + eps
|
|
2585
|
+
return c * np.arccos(numerator / denominator) / np.pi
|
|
2586
|
+
|
|
2587
|
+
def _kernel_frechet(self, lhs, rhs, eps, kwargs):
|
|
2588
|
+
sigma1 = kwargs.get('sigma1')
|
|
2589
|
+
sigma2 = kwargs.get('sigma2')
|
|
2590
|
+
assert sigma1 is not None and sigma2 is not None, 'Frechet distance requires covariance matrices for both inputs'
|
|
2591
|
+
return calculate_frechet_distance(lhs.T, sigma1, rhs.T, sigma2, eps)
|
|
2592
|
+
|
|
2593
|
+
def _kernel_mmd(self, lhs, rhs, eps, _kwargs):
|
|
2594
|
+
return calculate_mmd(lhs.T, rhs.T, eps=eps)
|
|
2595
|
+
|
|
2596
|
+
def similarity(
|
|
2597
|
+
self,
|
|
2598
|
+
other: Union['Symbol', list, np.ndarray, torch.Tensor],
|
|
2599
|
+
metric: Literal['cosine', 'angular-cosine', 'product', 'manhattan', 'euclidean', 'minkowski', 'jaccard'] = 'cosine',
|
|
2600
|
+
eps: float = 1e-8,
|
|
2601
|
+
normalize: Callable | None = None,
|
|
2602
|
+
**kwargs,
|
|
2603
|
+
) -> float:
|
|
2410
2604
|
'''
|
|
2411
2605
|
Calculates the similarity between two Symbol objects using a specified metric.
|
|
2412
2606
|
This method compares the values of two Symbol objects and calculates their similarity according to the specified metric.
|
|
@@ -2427,45 +2621,29 @@ class EmbeddingPrimitives(Primitive):
|
|
|
2427
2621
|
NotImplementedError: If the given metric is not supported.
|
|
2428
2622
|
'''
|
|
2429
2623
|
v = self._ensure_numpy_format(self)
|
|
2430
|
-
|
|
2431
|
-
|
|
2432
|
-
|
|
2433
|
-
o.append(self._ensure_numpy_format(other[i], cast=True))
|
|
2434
|
-
o = np.concatenate(o, axis=1)
|
|
2435
|
-
else:
|
|
2436
|
-
o = self._ensure_numpy_format(other, cast=True)
|
|
2437
|
-
|
|
2438
|
-
if metric == 'cosine':
|
|
2439
|
-
val = v.T@o / (np.sqrt(v.T@v) * np.sqrt(o.T@o) + eps)
|
|
2440
|
-
elif metric == 'angular-cosine':
|
|
2441
|
-
c = kwargs.get('c', 1)
|
|
2442
|
-
val = 1 - (c * np.arccos((v.T@o / (np.sqrt(v.T@v) * np.sqrt(o.T@o) + eps))) / np.pi)
|
|
2443
|
-
elif metric == 'product':
|
|
2444
|
-
val = v.T@o
|
|
2445
|
-
elif metric == 'manhattan':
|
|
2446
|
-
val = np.abs(v - o).sum(axis=0, keepdims=True)
|
|
2447
|
-
elif metric == 'euclidean':
|
|
2448
|
-
val = np.sqrt(np.sum((v - o)**2, axis=0, keepdims=True))
|
|
2449
|
-
elif metric == 'minkowski':
|
|
2450
|
-
p = kwargs.get('p', 3)
|
|
2451
|
-
val = np.sum(np.abs(v - o)**p, axis=0, keepdims=True)**(1/p)
|
|
2452
|
-
elif metric == 'jaccard':
|
|
2453
|
-
intersection = np.minimum(v, o)
|
|
2454
|
-
union = np.maximum(v, o)
|
|
2455
|
-
val = np.sum(intersection, axis=0, keepdims=True) / (np.sum(union, axis=0, keepdims=True) + eps)
|
|
2456
|
-
else:
|
|
2457
|
-
raise NotImplementedError(f"Similarity metric {metric} not implemented. Available metrics: 'cosine', 'angular-cosine', 'product', 'manhattan', 'euclidean', 'minkowski', 'jaccard'")
|
|
2624
|
+
o = self._prepare_embedding_operand(other)
|
|
2625
|
+
handler = self._get_similarity_handler(metric, eps, kwargs)
|
|
2626
|
+
val = handler(v, o)
|
|
2458
2627
|
|
|
2459
2628
|
# get the similarity value(s)
|
|
2460
2629
|
shape = val.shape
|
|
2461
|
-
if len(shape) >= 2 and min(shape) > 1:
|
|
2462
|
-
|
|
2463
|
-
|
|
2464
|
-
|
|
2630
|
+
if len(shape) >= 2 and min(shape) > 1:
|
|
2631
|
+
val = val.diagonal()
|
|
2632
|
+
elif len(shape) < 1 or shape[0] <= 1:
|
|
2633
|
+
val = val.item()
|
|
2634
|
+
if normalize is not None:
|
|
2635
|
+
val = normalize(val)
|
|
2465
2636
|
|
|
2466
2637
|
return val
|
|
2467
2638
|
|
|
2468
|
-
def distance(
|
|
2639
|
+
def distance(
|
|
2640
|
+
self,
|
|
2641
|
+
other: Union['Symbol', list, np.ndarray, torch.Tensor],
|
|
2642
|
+
kernel: Literal['gaussian', 'rbf', 'laplacian', 'polynomial', 'sigmoid', 'linear', 'cauchy', 't-distribution', 'inverse-multiquadric', 'cosine', 'angular-cosine', 'frechet', 'mmd'] = 'gaussian',
|
|
2643
|
+
eps: float = 1e-8,
|
|
2644
|
+
normalize: Callable | None = None,
|
|
2645
|
+
**kwargs,
|
|
2646
|
+
) -> float:
|
|
2469
2647
|
'''
|
|
2470
2648
|
Calculates the kernel between two Symbol objects.
|
|
2471
2649
|
|
|
@@ -2483,82 +2661,18 @@ class EmbeddingPrimitives(Primitive):
|
|
|
2483
2661
|
NotImplementedError: If the given kernel is not supported.
|
|
2484
2662
|
'''
|
|
2485
2663
|
v = self._ensure_numpy_format(self)
|
|
2486
|
-
|
|
2487
|
-
|
|
2488
|
-
|
|
2489
|
-
o.append(self._ensure_numpy_format(other[i], cast=True))
|
|
2490
|
-
o = np.concatenate(o, axis=1)
|
|
2491
|
-
else:
|
|
2492
|
-
o = self._ensure_numpy_format(other, cast=True)
|
|
2493
|
-
|
|
2494
|
-
# compute the kernel value
|
|
2495
|
-
if kernel == 'gaussian':
|
|
2496
|
-
gamma = kwargs.get('gamma', 1)
|
|
2497
|
-
val = np.exp(-gamma * np.sum((v - o)**2, axis=0))
|
|
2498
|
-
elif kernel == 'rbf':
|
|
2499
|
-
# vectors are expected to be normalized
|
|
2500
|
-
bandwidth = kwargs.get('bandwidth', None)
|
|
2501
|
-
gamma = kwargs.get('gamma', 1)
|
|
2502
|
-
d = np.sum((v - o)**2, axis=0)
|
|
2503
|
-
if bandwidth is not None:
|
|
2504
|
-
val = 0
|
|
2505
|
-
for a in bandwidth:
|
|
2506
|
-
gamma = 1.0 / (2 * a)
|
|
2507
|
-
val += np.exp(-gamma * d)
|
|
2508
|
-
else:
|
|
2509
|
-
# if no bandwidth is given, default to the gaussian kernel
|
|
2510
|
-
val = np.exp(-gamma * d)
|
|
2511
|
-
elif kernel == 'laplacian':
|
|
2512
|
-
gamma = kwargs.get('gamma', 1)
|
|
2513
|
-
val = np.exp(-gamma * np.sum(np.abs(v - o), axis=0))
|
|
2514
|
-
elif kernel == 'polynomial':
|
|
2515
|
-
gamma = kwargs.get('gamma', 1)
|
|
2516
|
-
degree = kwargs.get('degree', 3)
|
|
2517
|
-
coef = kwargs.get('coef', 1)
|
|
2518
|
-
val = (gamma * np.sum((v * o), axis=0) + coef)**degree
|
|
2519
|
-
elif kernel == 'sigmoid':
|
|
2520
|
-
gamma = kwargs.get('gamma', 1)
|
|
2521
|
-
coef = kwargs.get('coef', 1)
|
|
2522
|
-
val = np.tanh(gamma * np.sum((v * o), axis=0) + coef)
|
|
2523
|
-
elif kernel == 'linear':
|
|
2524
|
-
val = np.sum((v * o), axis=0)
|
|
2525
|
-
elif kernel == 'cauchy':
|
|
2526
|
-
gamma = kwargs.get('gamma', 1)
|
|
2527
|
-
val = 1 / (1 + np.sum((v - o)**2, axis=0) / gamma)
|
|
2528
|
-
elif kernel == 't-distribution':
|
|
2529
|
-
gamma = kwargs.get('gamma', 1)
|
|
2530
|
-
degree = kwargs.get('degree', 1)
|
|
2531
|
-
val = 1 / (1 + (np.sum((v - o)**2, axis=0) / (gamma * degree))**(degree + 1) / 2)
|
|
2532
|
-
elif kernel == 'inverse-multiquadric':
|
|
2533
|
-
gamma = kwargs.get('gamma', 1)
|
|
2534
|
-
val = 1 / np.sqrt(np.sum((v - o)**2, axis=0) / gamma**2 + 1)
|
|
2535
|
-
elif kernel == 'cosine':
|
|
2536
|
-
val = 1 - (np.sum(v * o, axis=0) / (np.sqrt(np.sum(v**2, axis=0)) * np.sqrt(np.sum(o**2, axis=0)) + eps))
|
|
2537
|
-
elif kernel == 'angular-cosine':
|
|
2538
|
-
c = kwargs.get('c', 1)
|
|
2539
|
-
val = c * np.arccos((np.sum(v * o, axis=0) / (np.sqrt(np.sum(v**2, axis=0)) * np.sqrt(np.sum(o**2, axis=0)) + eps))) / np.pi
|
|
2540
|
-
elif kernel == 'frechet':
|
|
2541
|
-
sigma1 = kwargs.get('sigma1', None)
|
|
2542
|
-
sigma2 = kwargs.get('sigma2', None)
|
|
2543
|
-
assert sigma1 is not None and sigma2 is not None, 'Frechet distance requires covariance matrices for both inputs'
|
|
2544
|
-
v = v.T
|
|
2545
|
-
o = o.T
|
|
2546
|
-
val = calculate_frechet_distance(v, sigma1, o, sigma2, eps)
|
|
2547
|
-
elif kernel == 'mmd':
|
|
2548
|
-
v = v.T
|
|
2549
|
-
o = o.T
|
|
2550
|
-
val = calculate_mmd(v, o, eps=eps)
|
|
2551
|
-
else:
|
|
2552
|
-
raise NotImplementedError(f"Kernel function {kernel} not implemented. Available functions: 'gaussian'")
|
|
2664
|
+
o = self._prepare_embedding_operand(other)
|
|
2665
|
+
handler = self._get_kernel_handler(kernel)
|
|
2666
|
+
val = handler(v, o, eps, kwargs)
|
|
2553
2667
|
|
|
2554
2668
|
# get the kernel value(s)
|
|
2555
2669
|
shape = val.shape
|
|
2556
|
-
if len(shape) >= 1 and shape[0] > 1
|
|
2557
|
-
|
|
2558
|
-
|
|
2670
|
+
val = val if len(shape) >= 1 and shape[0] > 1 else val.item()
|
|
2671
|
+
if normalize is not None:
|
|
2672
|
+
val = normalize(val)
|
|
2559
2673
|
return val
|
|
2560
2674
|
|
|
2561
|
-
def zip(self, **kwargs) ->
|
|
2675
|
+
def zip(self, **kwargs) -> list[tuple[str, list, dict]]:
|
|
2562
2676
|
'''
|
|
2563
2677
|
Zips the Symbol's value with its embeddings and a query containing the value.
|
|
2564
2678
|
This method zips the Symbol's value along with its embeddings and a query containing the value.
|
|
@@ -2577,19 +2691,21 @@ class EmbeddingPrimitives(Primitive):
|
|
|
2577
2691
|
elif isinstance(self.value, list):
|
|
2578
2692
|
pass
|
|
2579
2693
|
else:
|
|
2580
|
-
|
|
2694
|
+
msg = f'Expected id to be a string, got {type(self.value)}'
|
|
2695
|
+
UserMessage(msg)
|
|
2696
|
+
raise ValueError(msg)
|
|
2581
2697
|
|
|
2582
2698
|
embeds = self.embed(**kwargs).value
|
|
2583
2699
|
idx = [str(uuid.uuid4()) for _ in range(len(self.value))]
|
|
2584
2700
|
query = [{'text': str(self.value[i])} for i in range(len(self.value))]
|
|
2585
2701
|
|
|
2586
2702
|
# convert embeds to list if it is a tensor or numpy array
|
|
2587
|
-
if
|
|
2703
|
+
if isinstance(embeds, np.ndarray):
|
|
2588
2704
|
embeds = embeds.tolist()
|
|
2589
|
-
elif
|
|
2705
|
+
elif isinstance(embeds, torch.Tensor):
|
|
2590
2706
|
embeds = embeds.cpu().numpy().tolist()
|
|
2591
2707
|
|
|
2592
|
-
return list(zip(idx, embeds, query))
|
|
2708
|
+
return list(zip(idx, embeds, query, strict=False))
|
|
2593
2709
|
|
|
2594
2710
|
|
|
2595
2711
|
class IOHandlingPrimitives(Primitive):
|
|
@@ -2630,7 +2746,7 @@ class IOHandlingPrimitives(Primitive):
|
|
|
2630
2746
|
return self.sym_return_type(self.value if condition else '') | res
|
|
2631
2747
|
return self._to_type(self.value if condition else '') | self._to_type(res)
|
|
2632
2748
|
|
|
2633
|
-
def open(self, path: str = None, **kwargs) -> 'Symbol':
|
|
2749
|
+
def open(self, path: str | None = None, **kwargs) -> 'Symbol':
|
|
2634
2750
|
'''
|
|
2635
2751
|
Open a file and store its content in an Expression object as a string.
|
|
2636
2752
|
|
|
@@ -2654,7 +2770,9 @@ class IOHandlingPrimitives(Primitive):
|
|
|
2654
2770
|
|
|
2655
2771
|
path = path if path is not None else self.value
|
|
2656
2772
|
if path is None:
|
|
2657
|
-
|
|
2773
|
+
msg = 'Path is not provided; either provide a path or set the value of the Symbol to the path'
|
|
2774
|
+
UserMessage(msg)
|
|
2775
|
+
raise ValueError(msg)
|
|
2658
2776
|
|
|
2659
2777
|
@core.opening(path=path, **kwargs)
|
|
2660
2778
|
def _func(_):
|
|
@@ -2726,7 +2844,7 @@ class PersistencePrimitives(Primitive):
|
|
|
2726
2844
|
|
|
2727
2845
|
return func_name
|
|
2728
2846
|
|
|
2729
|
-
def save(self, path: str, replace:
|
|
2847
|
+
def save(self, path: str, replace: bool | None = False, serialize: bool | None = True) -> None:
|
|
2730
2848
|
'''
|
|
2731
2849
|
Save the current Symbol to a file.
|
|
2732
2850
|
|
|
@@ -2738,22 +2856,24 @@ class PersistencePrimitives(Primitive):
|
|
|
2738
2856
|
Returns:
|
|
2739
2857
|
Symbol: The current Symbol.
|
|
2740
2858
|
'''
|
|
2741
|
-
file_path = path
|
|
2859
|
+
file_path = Path(path)
|
|
2742
2860
|
|
|
2743
2861
|
if not replace:
|
|
2744
2862
|
cnt = 0
|
|
2745
|
-
|
|
2746
|
-
|
|
2747
|
-
|
|
2863
|
+
candidate = file_path
|
|
2864
|
+
while candidate.exists():
|
|
2865
|
+
candidate = candidate.with_name(f'{file_path.stem}_{cnt}{file_path.suffix}')
|
|
2748
2866
|
cnt += 1
|
|
2867
|
+
file_path = candidate
|
|
2749
2868
|
|
|
2750
2869
|
if serialize:
|
|
2751
2870
|
# serialize the object via pickle instead of writing the string
|
|
2752
|
-
|
|
2753
|
-
|
|
2871
|
+
path_str = str(file_path)
|
|
2872
|
+
pickle_path = Path(path_str if path_str.endswith('.pkl') else f'{path_str}.pkl')
|
|
2873
|
+
with pickle_path.open('wb') as f:
|
|
2754
2874
|
pickle.dump(self, file=f)
|
|
2755
2875
|
else:
|
|
2756
|
-
with open(
|
|
2876
|
+
with file_path.open('w') as f:
|
|
2757
2877
|
f.write(str(self))
|
|
2758
2878
|
|
|
2759
2879
|
def load(self, path: str) -> Any:
|
|
@@ -2766,9 +2886,8 @@ class PersistencePrimitives(Primitive):
|
|
|
2766
2886
|
Returns:
|
|
2767
2887
|
Symbol: The loaded Symbol.
|
|
2768
2888
|
'''
|
|
2769
|
-
with open(
|
|
2770
|
-
|
|
2771
|
-
return obj
|
|
2889
|
+
with Path(path).open('rb') as f:
|
|
2890
|
+
return pickle.load(f)
|
|
2772
2891
|
|
|
2773
2892
|
|
|
2774
2893
|
class OutputHandlingPrimitives(Primitive):
|
|
@@ -2789,7 +2908,7 @@ class OutputHandlingPrimitives(Primitive):
|
|
|
2789
2908
|
Symbol: The resulting Symbol after the output operation.
|
|
2790
2909
|
'''
|
|
2791
2910
|
@core.output(**kwargs)
|
|
2792
|
-
def _func(_, *
|
|
2911
|
+
def _func(_, *_func_args, **_func_kwargs):
|
|
2793
2912
|
return self.value
|
|
2794
2913
|
|
|
2795
2914
|
return self._to_type(_func(self, self.value, *args))
|
|
@@ -2833,12 +2952,13 @@ class FineTuningPrimitives(Primitive):
|
|
|
2833
2952
|
# return tensor
|
|
2834
2953
|
return self._metadata.data
|
|
2835
2954
|
# if the data is a numpy array, convert it to tensor
|
|
2836
|
-
|
|
2955
|
+
if isinstance(self._metadata.data, np.ndarray):
|
|
2837
2956
|
# convert to tensor
|
|
2838
2957
|
self._metadata.data = torch.from_numpy(self._metadata.data)
|
|
2839
2958
|
return self._metadata.data
|
|
2840
|
-
|
|
2841
|
-
|
|
2959
|
+
msg = f'Expected data to be a tensor or numpy array, got {type(self._metadata.data)}'
|
|
2960
|
+
UserMessage(msg)
|
|
2961
|
+
raise TypeError(msg)
|
|
2842
2962
|
|
|
2843
2963
|
@data.setter
|
|
2844
2964
|
def data(self, data: torch.Tensor) -> None:
|