griptape-nodes 0.59.2__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/common/macro_parser/__init__.py +28 -0
- griptape_nodes/common/macro_parser/core.py +230 -0
- griptape_nodes/common/macro_parser/exceptions.py +23 -0
- griptape_nodes/common/macro_parser/formats.py +170 -0
- griptape_nodes/common/macro_parser/matching.py +134 -0
- griptape_nodes/common/macro_parser/parsing.py +172 -0
- griptape_nodes/common/macro_parser/resolution.py +168 -0
- griptape_nodes/common/macro_parser/segments.py +42 -0
- griptape_nodes/exe_types/core_types.py +241 -4
- griptape_nodes/exe_types/node_types.py +7 -1
- griptape_nodes/exe_types/param_components/huggingface/__init__.py +1 -0
- griptape_nodes/exe_types/param_components/huggingface/huggingface_model_parameter.py +168 -0
- griptape_nodes/exe_types/param_components/huggingface/huggingface_repo_file_parameter.py +38 -0
- griptape_nodes/exe_types/param_components/huggingface/huggingface_repo_parameter.py +33 -0
- griptape_nodes/exe_types/param_components/huggingface/huggingface_utils.py +136 -0
- griptape_nodes/exe_types/param_components/log_parameter.py +136 -0
- griptape_nodes/exe_types/param_components/seed_parameter.py +59 -0
- griptape_nodes/exe_types/param_types/__init__.py +1 -0
- griptape_nodes/exe_types/param_types/parameter_bool.py +221 -0
- griptape_nodes/exe_types/param_types/parameter_float.py +179 -0
- griptape_nodes/exe_types/param_types/parameter_int.py +183 -0
- griptape_nodes/exe_types/param_types/parameter_number.py +380 -0
- griptape_nodes/exe_types/param_types/parameter_string.py +232 -0
- griptape_nodes/node_library/library_registry.py +2 -1
- griptape_nodes/retained_mode/events/app_events.py +21 -0
- griptape_nodes/retained_mode/events/os_events.py +142 -6
- griptape_nodes/retained_mode/events/parameter_events.py +2 -0
- griptape_nodes/retained_mode/griptape_nodes.py +14 -0
- griptape_nodes/retained_mode/managers/agent_manager.py +5 -3
- griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +19 -1
- griptape_nodes/retained_mode/managers/library_manager.py +27 -32
- griptape_nodes/retained_mode/managers/node_manager.py +14 -1
- griptape_nodes/retained_mode/managers/os_manager.py +403 -124
- griptape_nodes/retained_mode/managers/user_manager.py +120 -0
- griptape_nodes/retained_mode/managers/workflow_manager.py +44 -34
- griptape_nodes/traits/multi_options.py +26 -2
- griptape_nodes/utils/huggingface_utils.py +136 -0
- {griptape_nodes-0.59.2.dist-info → griptape_nodes-0.60.0.dist-info}/METADATA +1 -1
- {griptape_nodes-0.59.2.dist-info → griptape_nodes-0.60.0.dist-info}/RECORD +41 -18
- {griptape_nodes-0.59.2.dist-info → griptape_nodes-0.60.0.dist-info}/WHEEL +1 -1
- {griptape_nodes-0.59.2.dist-info → griptape_nodes-0.60.0.dist-info}/entry_points.txt +0 -0
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import uuid
|
|
5
|
+
import warnings
|
|
5
6
|
from abc import ABC, abstractmethod
|
|
6
7
|
from collections.abc import Callable
|
|
7
8
|
from copy import deepcopy
|
|
@@ -923,6 +924,9 @@ class ParameterBase(BaseNodeElement, ABC):
|
|
|
923
924
|
|
|
924
925
|
|
|
925
926
|
class Parameter(BaseNodeElement, UIOptionsMixin):
|
|
927
|
+
# Maximum number of input types to show in tooltip before truncating
|
|
928
|
+
_MAX_TOOLTIP_INPUT_TYPES = 3
|
|
929
|
+
|
|
926
930
|
# This is the list of types that the Parameter can accept, either externally or when internally treated as a property.
|
|
927
931
|
# Today, we can accept multiple types for input, but only a single output type.
|
|
928
932
|
tooltip: str | list[dict] # Default tooltip, can be string or list of dicts
|
|
@@ -956,11 +960,12 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
|
|
|
956
960
|
next: Parameter | None = None
|
|
957
961
|
prev: Parameter | None = None
|
|
958
962
|
parent_container_name: str | None = None
|
|
963
|
+
parent_element_name: str | None = None
|
|
959
964
|
|
|
960
|
-
def __init__( # noqa: PLR0913,
|
|
965
|
+
def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
|
|
961
966
|
self,
|
|
962
967
|
name: str,
|
|
963
|
-
tooltip: str | list[dict],
|
|
968
|
+
tooltip: str | list[dict] | None = None,
|
|
964
969
|
type: str | None = None, # noqa: A002
|
|
965
970
|
input_types: list[str] | None = None,
|
|
966
971
|
output_type: str | None = None,
|
|
@@ -974,12 +979,19 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
|
|
|
974
979
|
traits: set[Trait.__class__ | Trait] | None = None, # We are going to make these children.
|
|
975
980
|
ui_options: dict | None = None,
|
|
976
981
|
*,
|
|
982
|
+
hide: bool = False,
|
|
983
|
+
hide_label: bool = False,
|
|
984
|
+
hide_property: bool = False,
|
|
985
|
+
allow_input: bool = True,
|
|
986
|
+
allow_property: bool = True,
|
|
987
|
+
allow_output: bool = True,
|
|
977
988
|
settable: bool = True,
|
|
978
989
|
serializable: bool = True,
|
|
979
990
|
user_defined: bool = False,
|
|
980
991
|
element_id: str | None = None,
|
|
981
992
|
element_type: str | None = None,
|
|
982
993
|
parent_container_name: str | None = None,
|
|
994
|
+
parent_element_name: str | None = None,
|
|
983
995
|
):
|
|
984
996
|
if not element_id:
|
|
985
997
|
element_id = str(uuid.uuid4().hex)
|
|
@@ -987,6 +999,11 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
|
|
|
987
999
|
element_type = self.__class__.__name__
|
|
988
1000
|
super().__init__(element_id=element_id, element_type=element_type)
|
|
989
1001
|
self.name = name
|
|
1002
|
+
|
|
1003
|
+
# Generate default tooltip if none provided
|
|
1004
|
+
if not tooltip:
|
|
1005
|
+
tooltip = self._generate_default_tooltip(name, type, input_types, output_type)
|
|
1006
|
+
|
|
990
1007
|
self.tooltip = tooltip
|
|
991
1008
|
self.default_value = default_value
|
|
992
1009
|
self.tooltip_as_input = tooltip_as_input
|
|
@@ -995,11 +1012,37 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
|
|
|
995
1012
|
self._settable = settable
|
|
996
1013
|
self.serializable = serializable
|
|
997
1014
|
self.user_defined = user_defined
|
|
1015
|
+
|
|
1016
|
+
# Process allowed_modes - use convenience parameters if allowed_modes not explicitly set
|
|
998
1017
|
if allowed_modes is None:
|
|
999
|
-
self._allowed_modes =
|
|
1018
|
+
self._allowed_modes = set()
|
|
1019
|
+
if allow_input:
|
|
1020
|
+
self._allowed_modes.add(ParameterMode.INPUT)
|
|
1021
|
+
if allow_property:
|
|
1022
|
+
self._allowed_modes.add(ParameterMode.PROPERTY)
|
|
1023
|
+
if allow_output:
|
|
1024
|
+
self._allowed_modes.add(ParameterMode.OUTPUT)
|
|
1000
1025
|
else:
|
|
1001
1026
|
self._allowed_modes = allowed_modes
|
|
1002
1027
|
|
|
1028
|
+
# Warn if both allowed_modes and convenience parameters are set
|
|
1029
|
+
convenience_params_used = []
|
|
1030
|
+
if not allow_input:
|
|
1031
|
+
convenience_params_used.append("allow_input=False")
|
|
1032
|
+
if not allow_property:
|
|
1033
|
+
convenience_params_used.append("allow_property=False")
|
|
1034
|
+
if not allow_output:
|
|
1035
|
+
convenience_params_used.append("allow_output=False")
|
|
1036
|
+
|
|
1037
|
+
if convenience_params_used:
|
|
1038
|
+
warnings.warn(
|
|
1039
|
+
f"Parameter '{name}': Both 'allowed_modes' and convenience parameters "
|
|
1040
|
+
f"({', '.join(convenience_params_used)}) are set. Using 'allowed_modes' "
|
|
1041
|
+
f"and ignoring convenience parameters.",
|
|
1042
|
+
UserWarning,
|
|
1043
|
+
stacklevel=2,
|
|
1044
|
+
)
|
|
1045
|
+
|
|
1003
1046
|
if converters is None:
|
|
1004
1047
|
self._converters = []
|
|
1005
1048
|
else:
|
|
@@ -1009,10 +1052,20 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
|
|
|
1009
1052
|
self._validators = []
|
|
1010
1053
|
else:
|
|
1011
1054
|
self._validators = validators
|
|
1055
|
+
|
|
1056
|
+
# Process common UI options from constructor parameters
|
|
1012
1057
|
if ui_options is None:
|
|
1013
1058
|
self._ui_options = {}
|
|
1014
1059
|
else:
|
|
1015
|
-
self._ui_options = ui_options
|
|
1060
|
+
self._ui_options = ui_options.copy()
|
|
1061
|
+
|
|
1062
|
+
# Add common UI options if they have truthy values
|
|
1063
|
+
if hide:
|
|
1064
|
+
self._ui_options["hide"] = hide
|
|
1065
|
+
if hide_label:
|
|
1066
|
+
self._ui_options["hide_label"] = hide_label
|
|
1067
|
+
if hide_property:
|
|
1068
|
+
self._ui_options["hide_property"] = hide_property
|
|
1016
1069
|
if traits:
|
|
1017
1070
|
for trait in traits:
|
|
1018
1071
|
if not isinstance(trait, Trait):
|
|
@@ -1026,6 +1079,62 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
|
|
|
1026
1079
|
self.input_types = input_types
|
|
1027
1080
|
self.output_type = output_type
|
|
1028
1081
|
self.parent_container_name = parent_container_name
|
|
1082
|
+
self.parent_element_name = parent_element_name
|
|
1083
|
+
|
|
1084
|
+
def _generate_default_tooltip(
|
|
1085
|
+
self,
|
|
1086
|
+
name: str,
|
|
1087
|
+
type: str | None, # noqa: A002
|
|
1088
|
+
input_types: list[str] | None,
|
|
1089
|
+
output_type: str | None,
|
|
1090
|
+
) -> str:
|
|
1091
|
+
"""Generate a default tooltip describing the parameter type and usage.
|
|
1092
|
+
|
|
1093
|
+
Args:
|
|
1094
|
+
name: The parameter name
|
|
1095
|
+
type: The parameter type
|
|
1096
|
+
input_types: List of accepted input types
|
|
1097
|
+
output_type: The output type
|
|
1098
|
+
|
|
1099
|
+
Returns:
|
|
1100
|
+
A descriptive tooltip string
|
|
1101
|
+
"""
|
|
1102
|
+
# Determine the primary type to describe
|
|
1103
|
+
primary_type = type
|
|
1104
|
+
if not primary_type and input_types:
|
|
1105
|
+
primary_type = input_types[0]
|
|
1106
|
+
if not primary_type and output_type:
|
|
1107
|
+
primary_type = output_type
|
|
1108
|
+
if not primary_type:
|
|
1109
|
+
primary_type = "any"
|
|
1110
|
+
|
|
1111
|
+
# Create a human-readable description
|
|
1112
|
+
type_descriptions = {
|
|
1113
|
+
"str": "text/string",
|
|
1114
|
+
"bool": "boolean (true/false)",
|
|
1115
|
+
"int": "integer number",
|
|
1116
|
+
"float": "decimal number",
|
|
1117
|
+
"any": "any type of data",
|
|
1118
|
+
"list": "list/array",
|
|
1119
|
+
"dict": "dictionary/object",
|
|
1120
|
+
"parametercontroltype": "control flow",
|
|
1121
|
+
}
|
|
1122
|
+
|
|
1123
|
+
type_desc = type_descriptions.get(primary_type.lower(), primary_type)
|
|
1124
|
+
|
|
1125
|
+
# Build the tooltip
|
|
1126
|
+
tooltip_parts = [f"Enter {type_desc} for {name}"]
|
|
1127
|
+
|
|
1128
|
+
# Add input type info if different from primary type
|
|
1129
|
+
if input_types and len(input_types) > 1:
|
|
1130
|
+
input_desc = ", ".join(
|
|
1131
|
+
type_descriptions.get(t.lower(), t) for t in input_types[: self._MAX_TOOLTIP_INPUT_TYPES]
|
|
1132
|
+
)
|
|
1133
|
+
if len(input_types) > self._MAX_TOOLTIP_INPUT_TYPES:
|
|
1134
|
+
input_desc += f" or {len(input_types) - self._MAX_TOOLTIP_INPUT_TYPES} other types"
|
|
1135
|
+
tooltip_parts.append(f"Accepts: {input_desc}")
|
|
1136
|
+
|
|
1137
|
+
return ". ".join(tooltip_parts) + "."
|
|
1029
1138
|
|
|
1030
1139
|
def to_dict(self) -> dict[str, Any]:
|
|
1031
1140
|
"""Returns a nested dictionary representation of this node and its children."""
|
|
@@ -1055,6 +1164,8 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
|
|
|
1055
1164
|
our_dict["mode_allowed_property"] = allows_property
|
|
1056
1165
|
our_dict["mode_allowed_output"] = allows_output
|
|
1057
1166
|
our_dict["parent_container_name"] = self.parent_container_name
|
|
1167
|
+
our_dict["parent_element_name"] = self.parent_element_name
|
|
1168
|
+
our_dict["parent_group_name"] = self.parent_group_name
|
|
1058
1169
|
|
|
1059
1170
|
return our_dict
|
|
1060
1171
|
|
|
@@ -1156,6 +1267,132 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
|
|
|
1156
1267
|
def ui_options(self, value: dict) -> None:
|
|
1157
1268
|
self._ui_options = value
|
|
1158
1269
|
|
|
1270
|
+
@property
def hide(self) -> bool:
    """Whether the entire parameter is hidden in the UI.

    Returns:
        The value stored under the ``"hide"`` key of ``ui_options``, or False
        when that key is absent.
    """
    options = self.ui_options
    return options["hide"] if "hide" in options else False

@hide.setter
@BaseNodeElement.emits_update_on_write
def hide(self, value: bool) -> None:
    """Show or hide the entire parameter in the UI.

    Args:
        value: True hides the parameter, False shows it. Stored under the
            ``"hide"`` key of ``ui_options`` via ``update_ui_options_key``.
    """
    self.update_ui_options_key("hide", value)
|
|
1288
|
+
|
|
1289
|
+
@property
def hide_label(self) -> bool:
    """Whether the parameter's label is hidden in the UI.

    Returns:
        The value stored under the ``"hide_label"`` key of ``ui_options``, or
        False when that key is absent.
    """
    options = self.ui_options
    return options["hide_label"] if "hide_label" in options else False

@hide_label.setter
@BaseNodeElement.emits_update_on_write
def hide_label(self, value: bool) -> None:
    """Show or hide the parameter's label in the UI.

    Args:
        value: True hides the label, False shows it. Stored under the
            ``"hide_label"`` key of ``ui_options`` via ``update_ui_options_key``.
    """
    self.update_ui_options_key("hide_label", value)
|
|
1307
|
+
|
|
1308
|
+
@property
def hide_property(self) -> bool:
    """Whether the parameter is hidden while in property mode.

    Returns:
        The value stored under the ``"hide_property"`` key of ``ui_options``,
        or False when that key is absent.
    """
    options = self.ui_options
    return options["hide_property"] if "hide_property" in options else False

@hide_property.setter
@BaseNodeElement.emits_update_on_write
def hide_property(self, value: bool) -> None:
    """Show or hide the parameter while in property mode.

    Args:
        value: True hides it in property mode, False shows it. Stored under the
            ``"hide_property"`` key of ``ui_options`` via ``update_ui_options_key``.
    """
    self.update_ui_options_key("hide_property", value)
|
|
1326
|
+
|
|
1327
|
+
@property
def allow_input(self) -> bool:
    """Whether ``ParameterMode.INPUT`` is one of this parameter's allowed modes."""
    allowed = self.allowed_modes
    return ParameterMode.INPUT in allowed

@allow_input.setter
def allow_input(self, value: bool) -> None:
    """Add or remove ``ParameterMode.INPUT`` from the allowed modes.

    Args:
        value: True to allow INPUT mode, False to disallow it.
    """
    updated_modes = self.allowed_modes.copy()
    apply_change = updated_modes.add if value else updated_modes.discard
    apply_change(ParameterMode.INPUT)
    # Reassign a modified copy rather than mutating in place, presumably so any
    # logic on the allowed_modes setter sees the change.
    self.allowed_modes = updated_modes
|
|
1349
|
+
|
|
1350
|
+
@property
def allow_property(self) -> bool:
    """Whether ``ParameterMode.PROPERTY`` is one of this parameter's allowed modes."""
    allowed = self.allowed_modes
    return ParameterMode.PROPERTY in allowed

@allow_property.setter
def allow_property(self, value: bool) -> None:
    """Add or remove ``ParameterMode.PROPERTY`` from the allowed modes.

    Args:
        value: True to allow PROPERTY mode, False to disallow it.
    """
    updated_modes = self.allowed_modes.copy()
    apply_change = updated_modes.add if value else updated_modes.discard
    apply_change(ParameterMode.PROPERTY)
    # Reassign a modified copy rather than mutating in place, presumably so any
    # logic on the allowed_modes setter sees the change.
    self.allowed_modes = updated_modes
|
|
1372
|
+
|
|
1373
|
+
@property
def allow_output(self) -> bool:
    """Whether ``ParameterMode.OUTPUT`` is one of this parameter's allowed modes."""
    allowed = self.allowed_modes
    return ParameterMode.OUTPUT in allowed

@allow_output.setter
def allow_output(self, value: bool) -> None:
    """Add or remove ``ParameterMode.OUTPUT`` from the allowed modes.

    Args:
        value: True to allow OUTPUT mode, False to disallow it.
    """
    updated_modes = self.allowed_modes.copy()
    apply_change = updated_modes.add if value else updated_modes.discard
    apply_change(ParameterMode.OUTPUT)
    # Reassign a modified copy rather than mutating in place, presumably so any
    # logic on the allowed_modes setter sees the change.
    self.allowed_modes = updated_modes
|
|
1395
|
+
|
|
1159
1396
|
@property
|
|
1160
1397
|
def input_types(self) -> list[str]:
|
|
1161
1398
|
return self._custom_getter_for_property_input_types()
|
|
@@ -458,7 +458,13 @@ class BaseNode(ABC):
|
|
|
458
458
|
if self.does_name_exist(param.name):
|
|
459
459
|
msg = f"Cannot have duplicate names on parameters. Encountered two instances of '{param.name}'."
|
|
460
460
|
raise ValueError(msg)
|
|
461
|
-
|
|
461
|
+
parameter_group = (
|
|
462
|
+
self.get_group_by_name_or_element_id(param.parent_element_name) if param.parent_element_name else None
|
|
463
|
+
)
|
|
464
|
+
if parameter_group is not None:
|
|
465
|
+
parameter_group.add_child(param)
|
|
466
|
+
else:
|
|
467
|
+
self.add_node_element(param)
|
|
462
468
|
self._emit_parameter_lifecycle_event(param)
|
|
463
469
|
|
|
464
470
|
def remove_parameter_element_by_name(self, element_name: str) -> None:
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Reusable HuggingFace parameters."""
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import re
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
from griptape_nodes.exe_types.core_types import Parameter, ParameterMessage, ParameterMode
|
|
6
|
+
from griptape_nodes.exe_types.node_types import BaseNode
|
|
7
|
+
from griptape_nodes.traits.options import Options
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger("griptape_nodes")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class HuggingFaceModelParameter(ABC):
|
|
13
|
+
@classmethod
|
|
14
|
+
def _repo_revision_to_key(cls, repo_revision: tuple[str, str]) -> str:
|
|
15
|
+
return f"{repo_revision[0]} ({repo_revision[1]})"
|
|
16
|
+
|
|
17
|
+
@classmethod
|
|
18
|
+
def _key_to_repo_revision(cls, key: str) -> tuple[str, str]:
|
|
19
|
+
# Check if key has hash format using regex
|
|
20
|
+
hash_pattern = r"^(.+) \(([a-f0-9]{40})\)$"
|
|
21
|
+
match = re.match(hash_pattern, key)
|
|
22
|
+
if match:
|
|
23
|
+
return match.group(1), match.group(2)
|
|
24
|
+
|
|
25
|
+
# Key is just the model name (no hash)
|
|
26
|
+
return key, ""
|
|
27
|
+
|
|
28
|
+
def __init__(self, node: BaseNode, parameter_name: str):
|
|
29
|
+
self._node = node
|
|
30
|
+
self._parameter_name = parameter_name
|
|
31
|
+
self._repo_revisions = []
|
|
32
|
+
|
|
33
|
+
def refresh_parameters(self) -> None:
|
|
34
|
+
parameter = self._node.get_parameter_by_name(self._parameter_name)
|
|
35
|
+
if parameter is None:
|
|
36
|
+
logger.debug(
|
|
37
|
+
"Parameter '%s' not found on node '%s'; cannot refresh choices.",
|
|
38
|
+
self._parameter_name,
|
|
39
|
+
self._node.name,
|
|
40
|
+
)
|
|
41
|
+
return
|
|
42
|
+
|
|
43
|
+
choices = self.get_choices()
|
|
44
|
+
|
|
45
|
+
current_value = self._node.get_parameter_value(self._parameter_name)
|
|
46
|
+
if current_value in choices:
|
|
47
|
+
default_value = current_value
|
|
48
|
+
else:
|
|
49
|
+
default_value = choices[0]
|
|
50
|
+
|
|
51
|
+
if parameter.find_elements_by_type(Options):
|
|
52
|
+
self._node._update_option_choices(self._parameter_name, choices, default_value)
|
|
53
|
+
else:
|
|
54
|
+
parameter.add_trait(Options(choices=choices))
|
|
55
|
+
|
|
56
|
+
def add_input_parameters(self) -> None:
|
|
57
|
+
choices = self.get_choices()
|
|
58
|
+
|
|
59
|
+
if not choices:
|
|
60
|
+
self._node.add_node_element(
|
|
61
|
+
ParameterMessage(
|
|
62
|
+
name=f"huggingface_repo_parameter_message_{self._parameter_name}",
|
|
63
|
+
title="Huggingface Model Download Required",
|
|
64
|
+
variant="warning",
|
|
65
|
+
value=self.get_help_message(),
|
|
66
|
+
button_link=f"#model-management?search={self.get_download_models()[0]}",
|
|
67
|
+
button_text="Model Management",
|
|
68
|
+
button_icon="hard-drive",
|
|
69
|
+
)
|
|
70
|
+
)
|
|
71
|
+
return
|
|
72
|
+
|
|
73
|
+
self._node.add_parameter(
|
|
74
|
+
Parameter(
|
|
75
|
+
name=self._parameter_name,
|
|
76
|
+
default_value=choices[0] if choices else None,
|
|
77
|
+
input_types=["str"],
|
|
78
|
+
type="str",
|
|
79
|
+
ui_options={"display_name": self._parameter_name, "show_search": True},
|
|
80
|
+
traits={
|
|
81
|
+
Options(
|
|
82
|
+
choices=choices,
|
|
83
|
+
)
|
|
84
|
+
},
|
|
85
|
+
tooltip=self._parameter_name,
|
|
86
|
+
allowed_modes={ParameterMode.PROPERTY},
|
|
87
|
+
)
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
def remove_input_parameters(self) -> None:
|
|
91
|
+
self._node.remove_parameter_element_by_name(self._parameter_name)
|
|
92
|
+
self._node.remove_parameter_element_by_name(f"huggingface_repo_parameter_message_{self._parameter_name}")
|
|
93
|
+
|
|
94
|
+
def get_choices(self) -> list[str]:
|
|
95
|
+
# Ensure the latest repo revisions are fetched
|
|
96
|
+
self._repo_revisions = self.fetch_repo_revisions()
|
|
97
|
+
# Count occurrences of each model name
|
|
98
|
+
model_counts = {}
|
|
99
|
+
for repo_id, _ in self.list_repo_revisions():
|
|
100
|
+
model_counts[repo_id] = model_counts.get(repo_id, 0) + 1
|
|
101
|
+
|
|
102
|
+
# Generate choices with hash only when there are duplicates
|
|
103
|
+
choices = []
|
|
104
|
+
for repo_revision in self.list_repo_revisions():
|
|
105
|
+
repo_id, _ = repo_revision
|
|
106
|
+
if model_counts[repo_id] > 1:
|
|
107
|
+
# Multiple versions exist, show hash for disambiguation
|
|
108
|
+
choices.append(self._repo_revision_to_key(repo_revision))
|
|
109
|
+
else:
|
|
110
|
+
# Only one version, show just the model name
|
|
111
|
+
choices.append(repo_id)
|
|
112
|
+
logger.debug("Available choices for parameter '%s': %s", self._parameter_name, choices)
|
|
113
|
+
return choices
|
|
114
|
+
|
|
115
|
+
def validate_before_node_run(self) -> list[Exception] | None:
|
|
116
|
+
self.refresh_parameters()
|
|
117
|
+
try:
|
|
118
|
+
self.get_repo_revision()
|
|
119
|
+
except Exception as e:
|
|
120
|
+
return [e]
|
|
121
|
+
|
|
122
|
+
return None
|
|
123
|
+
|
|
124
|
+
def list_repo_revisions(self) -> list[tuple[str, str]]:
|
|
125
|
+
return self._repo_revisions
|
|
126
|
+
|
|
127
|
+
def get_repo_revision(self) -> tuple[str, str]:
|
|
128
|
+
value = self._node.get_parameter_value(self._parameter_name)
|
|
129
|
+
if value is None:
|
|
130
|
+
msg = "Model download required!"
|
|
131
|
+
raise RuntimeError(msg)
|
|
132
|
+
|
|
133
|
+
# Parse the value using _key_to_repo_revision
|
|
134
|
+
repo_id, revision = self._key_to_repo_revision(value)
|
|
135
|
+
|
|
136
|
+
# If revision is empty (just model name), find it in our stored list
|
|
137
|
+
if not revision:
|
|
138
|
+
for stored_repo_id, stored_revision in self._repo_revisions:
|
|
139
|
+
if stored_repo_id == repo_id:
|
|
140
|
+
logger.debug("Using revision '%s' for model '%s'", stored_revision, repo_id)
|
|
141
|
+
return stored_repo_id, stored_revision
|
|
142
|
+
# If not found, raise an error
|
|
143
|
+
msg = f"Model '{repo_id}' not found in available models!"
|
|
144
|
+
raise RuntimeError(msg)
|
|
145
|
+
|
|
146
|
+
# If revision was provided, return it directly
|
|
147
|
+
return repo_id, revision
|
|
148
|
+
|
|
149
|
+
def get_help_message(self) -> str:
|
|
150
|
+
download_models = "\n".join([f" {model}" for model in self.get_download_models()])
|
|
151
|
+
|
|
152
|
+
return (
|
|
153
|
+
"Model download required to continue.\n\n"
|
|
154
|
+
"To download models:\n\n"
|
|
155
|
+
"1. Navigate to Settings -> Model Management\n\n"
|
|
156
|
+
"2. Search for the model(s) you need and click the download button:\n"
|
|
157
|
+
f"{download_models}\n\n"
|
|
158
|
+
"After completing these steps, a dropdown menu with available models will appear."
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
@abstractmethod
|
|
162
|
+
def fetch_repo_revisions(self) -> list[tuple[str, str]]: ...
|
|
163
|
+
|
|
164
|
+
@abstractmethod
|
|
165
|
+
def get_download_commands(self) -> list[str]: ...
|
|
166
|
+
|
|
167
|
+
@abstractmethod
|
|
168
|
+
def get_download_models(self) -> list[str]: ...
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from griptape_nodes.exe_types.node_types import BaseNode
|
|
4
|
+
from griptape_nodes.exe_types.param_components.huggingface.huggingface_model_parameter import HuggingFaceModelParameter
|
|
5
|
+
from griptape_nodes.exe_types.param_components.huggingface.huggingface_utils import (
|
|
6
|
+
list_repo_revisions_with_file_in_cache,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger("griptape_nodes")


class HuggingFaceRepoFileParameter(HuggingFaceModelParameter):
    """Model dropdown whose choices come from specific files inside Hugging Face repos.

    Each ``(repo, file)`` entry contributes the revisions returned by
    ``list_repo_revisions_with_file_in_cache(repo, file)``.
    """

    def __init__(self, node: BaseNode, repo_files: list[tuple[str, str]], parameter_name: str = "model"):
        """Create the helper and immediately refresh its choices.

        Args:
            node: The node that owns (or will own) the dropdown parameter.
            repo_files: ``(repo_id, filename)`` pairs that may back the dropdown.
            parameter_name: Name of the dropdown parameter on ``node``.
        """
        super().__init__(node, parameter_name)
        self._repo_files = repo_files
        self.refresh_parameters()

    def fetch_repo_revisions(self) -> list[tuple[str, str]]:
        """Return the cached revisions for every configured ``(repo, file)`` pair, in config order."""
        return [
            repo_revision
            for repo, file in self._repo_files
            for repo_revision in list_repo_revisions_with_file_in_cache(repo, file)
        ]

    def get_download_commands(self) -> list[str]:
        """Return one ``huggingface-cli download`` command per configured ``(repo, file)`` pair."""
        return [f'huggingface-cli download "{repo}" "{file}"' for repo, file in self._repo_files]

    def get_download_models(self) -> list[str]:
        """Returns a list of model names that should be downloaded.

        Several files may come from the same repo; each repo is listed once, in
        first-seen order, so the help message does not repeat it.
        """
        return list(dict.fromkeys(repo for repo, _ in self._repo_files))

    def get_repo_filename(self) -> str:
        """Return the configured filename for the repo currently selected in the dropdown.

        If a repo is configured with several files, the first one listed wins.

        Raises:
            ValueError: If the selected repo has no configured file.
        """
        repo_id, _ = self.get_repo_revision()
        for repo, file in self._repo_files:
            if repo == repo_id:
                return file
        msg = f"File not found for repository {repo_id}"
        raise ValueError(msg)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from griptape_nodes.exe_types.node_types import BaseNode
|
|
4
|
+
from griptape_nodes.exe_types.param_components.huggingface.huggingface_model_parameter import HuggingFaceModelParameter
|
|
5
|
+
from griptape_nodes.exe_types.param_components.huggingface.huggingface_utils import (
|
|
6
|
+
list_all_repo_revisions_in_cache,
|
|
7
|
+
list_repo_revisions_in_cache,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger("griptape_nodes")


class HuggingFaceRepoParameter(HuggingFaceModelParameter):
    """Model dropdown whose choices are cached revisions of whole Hugging Face repos.

    By default only the repos in ``repo_ids`` are offered. With
    ``list_all_models=True``, every revision from
    ``list_all_repo_revisions_in_cache()`` is offered, with the repos from
    ``repo_ids`` placed first.
    """

    def __init__(
        self, node: BaseNode, repo_ids: list[str], parameter_name: str = "model", *, list_all_models: bool = False
    ):
        """Create the helper and immediately refresh its choices.

        Args:
            node: The node that owns (or will own) the dropdown parameter.
            repo_ids: Preferred repos, most preferred first.
            parameter_name: Name of the dropdown parameter on ``node``.
            list_all_models: Offer every repo revision reported by the cache,
                not just ``repo_ids``.
        """
        super().__init__(node, parameter_name)
        self._repo_ids = repo_ids
        self._list_all_models = list_all_models
        self.refresh_parameters()

    def fetch_repo_revisions(self) -> list[tuple[str, str]]:
        """Return the available ``(repo_id, revision)`` pairs, preferred repos first.

        Preferred repos are ordered by their position in ``repo_ids`` in both
        modes, so the first choice (the dropdown default) is always the
        most-preferred available repo.
        """
        if not self._list_all_models:
            return [repo_revision for repo in self._repo_ids for repo_revision in list_repo_revisions_in_cache(repo)]

        # Rank each preferred repo by its first position in repo_ids; any other repo ranks after all of them.
        preference: dict[str, int] = {}
        for position, repo_id in enumerate(self._repo_ids):
            preference.setdefault(repo_id, position)
        not_preferred = len(self._repo_ids)

        # sorted() is stable, so revisions with equal rank keep the cache's order.
        return sorted(
            list_all_repo_revisions_in_cache(),
            key=lambda repo_revision: preference.get(repo_revision[0], not_preferred),
        )

    def get_download_commands(self) -> list[str]:
        """Return one ``huggingface-cli download`` command per preferred repo."""
        return [f'huggingface-cli download "{repo}"' for repo in self._repo_ids]

    def get_download_models(self) -> list[str]:
        """Returns a list of model names that should be downloaded.

        A copy is returned so callers cannot mutate this helper's configuration.
        """
        return list(self._repo_ids)