griptape-nodes 0.59.2__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. griptape_nodes/common/macro_parser/__init__.py +28 -0
  2. griptape_nodes/common/macro_parser/core.py +230 -0
  3. griptape_nodes/common/macro_parser/exceptions.py +23 -0
  4. griptape_nodes/common/macro_parser/formats.py +170 -0
  5. griptape_nodes/common/macro_parser/matching.py +134 -0
  6. griptape_nodes/common/macro_parser/parsing.py +172 -0
  7. griptape_nodes/common/macro_parser/resolution.py +168 -0
  8. griptape_nodes/common/macro_parser/segments.py +42 -0
  9. griptape_nodes/exe_types/core_types.py +241 -4
  10. griptape_nodes/exe_types/node_types.py +7 -1
  11. griptape_nodes/exe_types/param_components/huggingface/__init__.py +1 -0
  12. griptape_nodes/exe_types/param_components/huggingface/huggingface_model_parameter.py +168 -0
  13. griptape_nodes/exe_types/param_components/huggingface/huggingface_repo_file_parameter.py +38 -0
  14. griptape_nodes/exe_types/param_components/huggingface/huggingface_repo_parameter.py +33 -0
  15. griptape_nodes/exe_types/param_components/huggingface/huggingface_utils.py +136 -0
  16. griptape_nodes/exe_types/param_components/log_parameter.py +136 -0
  17. griptape_nodes/exe_types/param_components/seed_parameter.py +59 -0
  18. griptape_nodes/exe_types/param_types/__init__.py +1 -0
  19. griptape_nodes/exe_types/param_types/parameter_bool.py +221 -0
  20. griptape_nodes/exe_types/param_types/parameter_float.py +179 -0
  21. griptape_nodes/exe_types/param_types/parameter_int.py +183 -0
  22. griptape_nodes/exe_types/param_types/parameter_number.py +380 -0
  23. griptape_nodes/exe_types/param_types/parameter_string.py +232 -0
  24. griptape_nodes/node_library/library_registry.py +2 -1
  25. griptape_nodes/retained_mode/events/app_events.py +21 -0
  26. griptape_nodes/retained_mode/events/os_events.py +142 -6
  27. griptape_nodes/retained_mode/events/parameter_events.py +2 -0
  28. griptape_nodes/retained_mode/griptape_nodes.py +14 -0
  29. griptape_nodes/retained_mode/managers/agent_manager.py +5 -3
  30. griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +19 -1
  31. griptape_nodes/retained_mode/managers/library_manager.py +27 -32
  32. griptape_nodes/retained_mode/managers/node_manager.py +14 -1
  33. griptape_nodes/retained_mode/managers/os_manager.py +403 -124
  34. griptape_nodes/retained_mode/managers/user_manager.py +120 -0
  35. griptape_nodes/retained_mode/managers/workflow_manager.py +44 -34
  36. griptape_nodes/traits/multi_options.py +26 -2
  37. griptape_nodes/utils/huggingface_utils.py +136 -0
  38. {griptape_nodes-0.59.2.dist-info → griptape_nodes-0.60.0.dist-info}/METADATA +1 -1
  39. {griptape_nodes-0.59.2.dist-info → griptape_nodes-0.60.0.dist-info}/RECORD +41 -18
  40. {griptape_nodes-0.59.2.dist-info → griptape_nodes-0.60.0.dist-info}/WHEEL +1 -1
  41. {griptape_nodes-0.59.2.dist-info → griptape_nodes-0.60.0.dist-info}/entry_points.txt +0 -0
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import uuid
5
+ import warnings
5
6
  from abc import ABC, abstractmethod
6
7
  from collections.abc import Callable
7
8
  from copy import deepcopy
@@ -923,6 +924,9 @@ class ParameterBase(BaseNodeElement, ABC):
923
924
 
924
925
 
925
926
  class Parameter(BaseNodeElement, UIOptionsMixin):
927
+ # Maximum number of input types to show in tooltip before truncating
928
+ _MAX_TOOLTIP_INPUT_TYPES = 3
929
+
926
930
  # This is the list of types that the Parameter can accept, either externally or when internally treated as a property.
927
931
  # Today, we can accept multiple types for input, but only a single output type.
928
932
  tooltip: str | list[dict] # Default tooltip, can be string or list of dicts
@@ -956,11 +960,12 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
956
960
  next: Parameter | None = None
957
961
  prev: Parameter | None = None
958
962
  parent_container_name: str | None = None
963
+ parent_element_name: str | None = None
959
964
 
960
- def __init__( # noqa: PLR0913,PLR0912
965
+ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
961
966
  self,
962
967
  name: str,
963
- tooltip: str | list[dict],
968
+ tooltip: str | list[dict] | None = None,
964
969
  type: str | None = None, # noqa: A002
965
970
  input_types: list[str] | None = None,
966
971
  output_type: str | None = None,
@@ -974,12 +979,19 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
974
979
  traits: set[Trait.__class__ | Trait] | None = None, # We are going to make these children.
975
980
  ui_options: dict | None = None,
976
981
  *,
982
+ hide: bool = False,
983
+ hide_label: bool = False,
984
+ hide_property: bool = False,
985
+ allow_input: bool = True,
986
+ allow_property: bool = True,
987
+ allow_output: bool = True,
977
988
  settable: bool = True,
978
989
  serializable: bool = True,
979
990
  user_defined: bool = False,
980
991
  element_id: str | None = None,
981
992
  element_type: str | None = None,
982
993
  parent_container_name: str | None = None,
994
+ parent_element_name: str | None = None,
983
995
  ):
984
996
  if not element_id:
985
997
  element_id = str(uuid.uuid4().hex)
@@ -987,6 +999,11 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
987
999
  element_type = self.__class__.__name__
988
1000
  super().__init__(element_id=element_id, element_type=element_type)
989
1001
  self.name = name
1002
+
1003
+ # Generate default tooltip if none provided
1004
+ if not tooltip:
1005
+ tooltip = self._generate_default_tooltip(name, type, input_types, output_type)
1006
+
990
1007
  self.tooltip = tooltip
991
1008
  self.default_value = default_value
992
1009
  self.tooltip_as_input = tooltip_as_input
@@ -995,11 +1012,37 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
995
1012
  self._settable = settable
996
1013
  self.serializable = serializable
997
1014
  self.user_defined = user_defined
1015
+
1016
+ # Process allowed_modes - use convenience parameters if allowed_modes not explicitly set
998
1017
  if allowed_modes is None:
999
- self._allowed_modes = {ParameterMode.INPUT, ParameterMode.OUTPUT, ParameterMode.PROPERTY}
1018
+ self._allowed_modes = set()
1019
+ if allow_input:
1020
+ self._allowed_modes.add(ParameterMode.INPUT)
1021
+ if allow_property:
1022
+ self._allowed_modes.add(ParameterMode.PROPERTY)
1023
+ if allow_output:
1024
+ self._allowed_modes.add(ParameterMode.OUTPUT)
1000
1025
  else:
1001
1026
  self._allowed_modes = allowed_modes
1002
1027
 
1028
+ # Warn if both allowed_modes and convenience parameters are set
1029
+ convenience_params_used = []
1030
+ if not allow_input:
1031
+ convenience_params_used.append("allow_input=False")
1032
+ if not allow_property:
1033
+ convenience_params_used.append("allow_property=False")
1034
+ if not allow_output:
1035
+ convenience_params_used.append("allow_output=False")
1036
+
1037
+ if convenience_params_used:
1038
+ warnings.warn(
1039
+ f"Parameter '{name}': Both 'allowed_modes' and convenience parameters "
1040
+ f"({', '.join(convenience_params_used)}) are set. Using 'allowed_modes' "
1041
+ f"and ignoring convenience parameters.",
1042
+ UserWarning,
1043
+ stacklevel=2,
1044
+ )
1045
+
1003
1046
  if converters is None:
1004
1047
  self._converters = []
1005
1048
  else:
@@ -1009,10 +1052,20 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
1009
1052
  self._validators = []
1010
1053
  else:
1011
1054
  self._validators = validators
1055
+
1056
+ # Process common UI options from constructor parameters
1012
1057
  if ui_options is None:
1013
1058
  self._ui_options = {}
1014
1059
  else:
1015
- self._ui_options = ui_options
1060
+ self._ui_options = ui_options.copy()
1061
+
1062
+ # Add common UI options if they have truthy values
1063
+ if hide:
1064
+ self._ui_options["hide"] = hide
1065
+ if hide_label:
1066
+ self._ui_options["hide_label"] = hide_label
1067
+ if hide_property:
1068
+ self._ui_options["hide_property"] = hide_property
1016
1069
  if traits:
1017
1070
  for trait in traits:
1018
1071
  if not isinstance(trait, Trait):
@@ -1026,6 +1079,62 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
1026
1079
  self.input_types = input_types
1027
1080
  self.output_type = output_type
1028
1081
  self.parent_container_name = parent_container_name
1082
+ self.parent_element_name = parent_element_name
1083
+
1084
def _generate_default_tooltip(
    self,
    name: str,
    type: str | None,  # noqa: A002
    input_types: list[str] | None,
    output_type: str | None,
) -> str:
    """Build a fallback tooltip from a parameter's name and declared types.

    The "primary" type is the first truthy value among ``type``, the first entry of
    ``input_types`` and ``output_type``; when none of them is set, ``"any"`` is used.
    Known type names are mapped (case-insensitively) to friendlier wording, and
    unknown names are shown verbatim. When more than one input type is accepted, a
    second "Accepts: ..." sentence lists at most ``self._MAX_TOOLTIP_INPUT_TYPES``
    of them and states how many more were left out.

    Args:
        name: Parameter name, embedded in the tooltip text.
        type: Declared parameter type, if any.
        input_types: Types accepted on input, if any.
        output_type: Declared output type, if any.

    Returns:
        One or two sentences joined by ". " and terminated with a period.
    """
    friendly_names = {
        "str": "text/string",
        "bool": "boolean (true/false)",
        "int": "integer number",
        "float": "decimal number",
        "any": "any type of data",
        "list": "list/array",
        "dict": "dictionary/object",
        "parametercontroltype": "control flow",
    }

    def describe(type_name: str) -> str:
        # Fall back to the raw type name when no friendly wording is known.
        return friendly_names.get(type_name.lower(), type_name)

    first_input = input_types[0] if input_types else None
    primary = next((candidate for candidate in (type, first_input, output_type) if candidate), "any")

    sentences = [f"Enter {describe(primary)} for {name}"]

    accepted = input_types or []
    if len(accepted) > 1:
        limit = self._MAX_TOOLTIP_INPUT_TYPES
        listed = ", ".join(map(describe, accepted[:limit]))
        omitted = len(accepted) - limit
        if omitted > 0:
            listed += f" or {omitted} other types"
        sentences.append(f"Accepts: {listed}")

    return ". ".join(sentences) + "."
1029
1138
 
1030
1139
  def to_dict(self) -> dict[str, Any]:
1031
1140
  """Returns a nested dictionary representation of this node and its children."""
@@ -1055,6 +1164,8 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
1055
1164
  our_dict["mode_allowed_property"] = allows_property
1056
1165
  our_dict["mode_allowed_output"] = allows_output
1057
1166
  our_dict["parent_container_name"] = self.parent_container_name
1167
+ our_dict["parent_element_name"] = self.parent_element_name
1168
+ our_dict["parent_group_name"] = self.parent_group_name
1058
1169
 
1059
1170
  return our_dict
1060
1171
 
@@ -1156,6 +1267,132 @@ class Parameter(BaseNodeElement, UIOptionsMixin):
1156
1267
  def ui_options(self, value: dict) -> None:
1157
1268
  self._ui_options = value
1158
1269
 
1270
@property
def hide(self) -> bool:
    """Whether the whole parameter is hidden in the UI.

    Backed by the ``"hide"`` entry of ``ui_options``; a missing entry reads as ``False``.
    """
    options = self.ui_options
    return options.get("hide", False)

@hide.setter
@BaseNodeElement.emits_update_on_write
def hide(self, value: bool) -> None:
    """Store ``value`` in the ``"hide"`` entry of ``ui_options`` via ``update_ui_options_key``."""
    self.update_ui_options_key("hide", value)
1288
+
1289
@property
def hide_label(self) -> bool:
    """Whether the parameter's label is hidden in the UI.

    Backed by the ``"hide_label"`` entry of ``ui_options``; a missing entry reads as ``False``.
    """
    options = self.ui_options
    return options.get("hide_label", False)

@hide_label.setter
@BaseNodeElement.emits_update_on_write
def hide_label(self, value: bool) -> None:
    """Store ``value`` in the ``"hide_label"`` entry of ``ui_options`` via ``update_ui_options_key``."""
    self.update_ui_options_key("hide_label", value)
1307
+
1308
@property
def hide_property(self) -> bool:
    """Whether the parameter is hidden while shown in property mode.

    Backed by the ``"hide_property"`` entry of ``ui_options``; a missing entry reads as ``False``.
    """
    options = self.ui_options
    return options.get("hide_property", False)

@hide_property.setter
@BaseNodeElement.emits_update_on_write
def hide_property(self, value: bool) -> None:
    """Store ``value`` in the ``"hide_property"`` entry of ``ui_options`` via ``update_ui_options_key``."""
    self.update_ui_options_key("hide_property", value)
1326
+
1327
def _set_allowed_mode_flag(self, mode: ParameterMode, allowed: bool) -> None:
    """Add ``mode`` to, or remove it from, this parameter's ``allowed_modes``.

    Works on a copy and assigns the result back through the ``allowed_modes``
    property, so the change goes through that property's setter instead of
    mutating the current set in place.

    Args:
        mode: The mode to toggle.
        allowed: True to allow ``mode``, False to disallow it (a no-op if already absent).
    """
    updated_modes = self.allowed_modes.copy()
    if allowed:
        updated_modes.add(mode)
    else:
        updated_modes.discard(mode)
    self.allowed_modes = updated_modes

@property
def allow_input(self) -> bool:
    """Get whether the parameter allows INPUT mode.

    Returns:
        True if ``ParameterMode.INPUT`` is in ``allowed_modes``, False otherwise.
    """
    return ParameterMode.INPUT in self.allowed_modes

@allow_input.setter
def allow_input(self, value: bool) -> None:
    """Set whether to allow INPUT mode.

    Args:
        value: True to allow INPUT mode, False to disallow it.
    """
    self._set_allowed_mode_flag(ParameterMode.INPUT, value)

@property
def allow_property(self) -> bool:
    """Get whether the parameter allows PROPERTY mode.

    Returns:
        True if ``ParameterMode.PROPERTY`` is in ``allowed_modes``, False otherwise.
    """
    return ParameterMode.PROPERTY in self.allowed_modes

@allow_property.setter
def allow_property(self, value: bool) -> None:
    """Set whether to allow PROPERTY mode.

    Args:
        value: True to allow PROPERTY mode, False to disallow it.
    """
    self._set_allowed_mode_flag(ParameterMode.PROPERTY, value)

@property
def allow_output(self) -> bool:
    """Get whether the parameter allows OUTPUT mode.

    Returns:
        True if ``ParameterMode.OUTPUT`` is in ``allowed_modes``, False otherwise.
    """
    return ParameterMode.OUTPUT in self.allowed_modes

@allow_output.setter
def allow_output(self, value: bool) -> None:
    """Set whether to allow OUTPUT mode.

    Args:
        value: True to allow OUTPUT mode, False to disallow it.
    """
    self._set_allowed_mode_flag(ParameterMode.OUTPUT, value)
1395
+
1159
1396
  @property
1160
1397
  def input_types(self) -> list[str]:
1161
1398
  return self._custom_getter_for_property_input_types()
@@ -458,7 +458,13 @@ class BaseNode(ABC):
458
458
  if self.does_name_exist(param.name):
459
459
  msg = f"Cannot have duplicate names on parameters. Encountered two instances of '{param.name}'."
460
460
  raise ValueError(msg)
461
- self.add_node_element(param)
461
+ parameter_group = (
462
+ self.get_group_by_name_or_element_id(param.parent_element_name) if param.parent_element_name else None
463
+ )
464
+ if parameter_group is not None:
465
+ parameter_group.add_child(param)
466
+ else:
467
+ self.add_node_element(param)
462
468
  self._emit_parameter_lifecycle_event(param)
463
469
 
464
470
  def remove_parameter_element_by_name(self, element_name: str) -> None:
@@ -0,0 +1 @@
1
+ """Reusable HuggingFace parameters."""
@@ -0,0 +1,168 @@
1
+ import logging
2
+ import re
3
+ from abc import ABC, abstractmethod
4
+
5
+ from griptape_nodes.exe_types.core_types import Parameter, ParameterMessage, ParameterMode
6
+ from griptape_nodes.exe_types.node_types import BaseNode
7
+ from griptape_nodes.traits.options import Options
8
+
9
+ logger = logging.getLogger("griptape_nodes")
10
+
11
+
12
class HuggingFaceModelParameter(ABC):
    """Base class for a node parameter that selects a locally available HuggingFace model.

    Subclasses report which ``(repo_id, revision)`` pairs to offer via
    :meth:`fetch_repo_revisions` and which models to suggest downloading via
    :meth:`get_download_models`. This class turns those pairs into the choices of an
    ``Options`` dropdown on the owning node and maps a selected choice back to a
    ``(repo_id, revision)`` pair.
    """

    # Matches keys built by _repo_revision_to_key: "<repo_id> (<40 lowercase hex chars>)".
    _HASHED_KEY_PATTERN = re.compile(r"^(.+) \(([a-f0-9]{40})\)$")

    @classmethod
    def _repo_revision_to_key(cls, repo_revision: tuple[str, str]) -> str:
        """Format a ``(repo_id, revision)`` pair as the dropdown key ``"repo_id (revision)"``."""
        return f"{repo_revision[0]} ({repo_revision[1]})"

    @classmethod
    def _key_to_repo_revision(cls, key: str) -> tuple[str, str]:
        """Parse a dropdown key back into ``(repo_id, revision)``.

        A key of the form ``"repo_id (<40 lowercase hex chars>)"`` is split into its two
        parts; any other key is treated as a bare repo id and paired with ``""``.
        """
        match = cls._HASHED_KEY_PATTERN.match(key)
        if match:
            return match.group(1), match.group(2)

        # Key is just the model name (no hash)
        return key, ""

    def __init__(self, node: BaseNode, parameter_name: str):
        """Bind this helper to ``node`` and the name of the parameter it manages."""
        self._node = node
        self._parameter_name = parameter_name
        # Most recent fetch_repo_revisions() result; refreshed by get_choices().
        self._repo_revisions: list[tuple[str, str]] = []

    def refresh_parameters(self) -> None:
        """Re-fetch the available models and update the parameter's dropdown choices.

        Keeps the current value when it is still a valid choice, otherwise falls back to
        the first choice. Does nothing when the parameter is not on the node, or when no
        models are available.
        """
        parameter = self._node.get_parameter_by_name(self._parameter_name)
        if parameter is None:
            logger.debug(
                "Parameter '%s' not found on node '%s'; cannot refresh choices.",
                self._parameter_name,
                self._node.name,
            )
            return

        choices = self.get_choices()
        if not choices:
            # Every previously available model is gone, so there is no valid default to
            # fall back to (choices[0] would raise IndexError). Leave the parameter as-is;
            # get_repo_revision() then raises for a bare repo id that is no longer listed.
            logger.debug(
                "No models available for parameter '%s' on node '%s'; leaving choices unchanged.",
                self._parameter_name,
                self._node.name,
            )
            return

        current_value = self._node.get_parameter_value(self._parameter_name)
        default_value = current_value if current_value in choices else choices[0]

        if parameter.find_elements_by_type(Options):
            self._node._update_option_choices(self._parameter_name, choices, default_value)
        else:
            parameter.add_trait(Options(choices=choices))

    def add_input_parameters(self) -> None:
        """Add the model dropdown to the node, or a download-required message if there are no choices."""
        choices = self.get_choices()

        if not choices:
            download_models = self.get_download_models()
            # Pre-fill the Model Management search with the first suggested model when there
            # is one; indexing [0] unconditionally raised IndexError for an empty list.
            search_query = f"?search={download_models[0]}" if download_models else ""
            self._node.add_node_element(
                ParameterMessage(
                    name=f"huggingface_repo_parameter_message_{self._parameter_name}",
                    title="Huggingface Model Download Required",
                    variant="warning",
                    value=self.get_help_message(),
                    button_link=f"#model-management{search_query}",
                    button_text="Model Management",
                    button_icon="hard-drive",
                )
            )
            return

        self._node.add_parameter(
            Parameter(
                name=self._parameter_name,
                default_value=choices[0],
                input_types=["str"],
                type="str",
                ui_options={"display_name": self._parameter_name, "show_search": True},
                traits={
                    Options(
                        choices=choices,
                    )
                },
                tooltip=self._parameter_name,
                allowed_modes={ParameterMode.PROPERTY},
            )
        )

    def remove_input_parameters(self) -> None:
        """Remove both the dropdown parameter and the download-required message, whichever exists."""
        self._node.remove_parameter_element_by_name(self._parameter_name)
        self._node.remove_parameter_element_by_name(f"huggingface_repo_parameter_message_{self._parameter_name}")

    def get_choices(self) -> list[str]:
        """Re-fetch the repo revisions and return the dropdown choices.

        A repo with a single revision is listed by its bare repo id. A repo with several
        revisions gets one ``"repo_id (revision)"`` entry per revision, so the entries can
        be told apart.
        """
        # Ensure the latest repo revisions are fetched
        self._repo_revisions = self.fetch_repo_revisions()
        repo_revisions = self.list_repo_revisions()

        revision_counts: dict[str, int] = {}
        for repo_id, _ in repo_revisions:
            revision_counts[repo_id] = revision_counts.get(repo_id, 0) + 1

        choices = [
            self._repo_revision_to_key((repo_id, revision)) if revision_counts[repo_id] > 1 else repo_id
            for repo_id, revision in repo_revisions
        ]
        logger.debug("Available choices for parameter '%s': %s", self._parameter_name, choices)
        return choices

    def validate_before_node_run(self) -> list[Exception] | None:
        """Refresh the choices, then return any error from resolving the selection (``None`` if valid)."""
        self.refresh_parameters()
        try:
            self.get_repo_revision()
        except Exception as e:
            # Deliberately broad: every failure is reported to the caller as a validation error.
            return [e]

        return None

    def list_repo_revisions(self) -> list[tuple[str, str]]:
        """Return the repo revisions cached by the last :meth:`get_choices` call."""
        return self._repo_revisions

    def get_repo_revision(self) -> tuple[str, str]:
        """Resolve the parameter's current value into a ``(repo_id, revision)`` pair.

        Raises:
            RuntimeError: If no value is set, or if a bare repo id is not among the
                revisions cached by the last :meth:`get_choices` call.
        """
        value = self._node.get_parameter_value(self._parameter_name)
        if value is None:
            msg = "Model download required!"
            raise RuntimeError(msg)

        repo_id, revision = self._key_to_repo_revision(value)

        # A bare repo id carries no revision: use the first cached revision for that repo.
        if not revision:
            for stored_repo_id, stored_revision in self._repo_revisions:
                if stored_repo_id == repo_id:
                    logger.debug("Using revision '%s' for model '%s'", stored_revision, repo_id)
                    return stored_repo_id, stored_revision
            msg = f"Model '{repo_id}' not found in available models!"
            raise RuntimeError(msg)

        # If revision was provided, return it directly
        return repo_id, revision

    def get_help_message(self) -> str:
        """Return user instructions for downloading the models from :meth:`get_download_models`."""
        download_models = "\n".join([f" {model}" for model in self.get_download_models()])

        return (
            "Model download required to continue.\n\n"
            "To download models:\n\n"
            "1. Navigate to Settings -> Model Management\n\n"
            "2. Search for the model(s) you need and click the download button:\n"
            f"{download_models}\n\n"
            "After completing these steps, a dropdown menu with available models will appear."
        )

    @abstractmethod
    def fetch_repo_revisions(self) -> list[tuple[str, str]]:
        """Return the ``(repo_id, revision)`` pairs to offer as choices."""

    @abstractmethod
    def get_download_commands(self) -> list[str]:
        """Return command lines for downloading the required models (not used by this base class)."""

    @abstractmethod
    def get_download_models(self) -> list[str]:
        """Return the model ids to suggest downloading when there are no choices."""
@@ -0,0 +1,38 @@
1
+ import logging
2
+
3
+ from griptape_nodes.exe_types.node_types import BaseNode
4
+ from griptape_nodes.exe_types.param_components.huggingface.huggingface_model_parameter import HuggingFaceModelParameter
5
+ from griptape_nodes.exe_types.param_components.huggingface.huggingface_utils import (
6
+ list_repo_revisions_with_file_in_cache,
7
+ )
8
+
9
+ logger = logging.getLogger("griptape_nodes")
10
+
11
+
12
class HuggingFaceRepoFileParameter(HuggingFaceModelParameter):
    """Model parameter whose choices are revisions of specific files inside HuggingFace repos.

    Args:
        node: Node that owns the parameter.
        repo_files: ``(repo_id, filename)`` pairs to look up.
        parameter_name: Name of the parameter on the node.
    """

    def __init__(self, node: BaseNode, repo_files: list[tuple[str, str]], parameter_name: str = "model"):
        super().__init__(node, parameter_name)
        self._repo_files = repo_files
        self.refresh_parameters()

    def fetch_repo_revisions(self) -> list[tuple[str, str]]:
        """Collect the revisions reported by ``list_repo_revisions_with_file_in_cache`` for every pair, in order."""
        revisions: list[tuple[str, str]] = []
        for repo_id, filename in self._repo_files:
            revisions.extend(list_repo_revisions_with_file_in_cache(repo_id, filename))
        return revisions

    def get_download_commands(self) -> list[str]:
        """Build one ``huggingface-cli download`` command line per ``(repo_id, filename)`` pair."""
        commands: list[str] = []
        for repo_id, filename in self._repo_files:
            commands.append(f'huggingface-cli download "{repo_id}" "{filename}"')
        return commands

    def get_download_models(self) -> list[str]:
        """Returns a list of model names that should be downloaded."""
        return [repo_id for repo_id, _ in self._repo_files]

    def get_repo_filename(self) -> str:
        """Return the filename paired with the currently selected repo.

        Raises:
            ValueError: If the selected repo is not in ``repo_files``.
        """
        selected_repo, _ = self.get_repo_revision()
        for candidate_repo, filename in self._repo_files:
            if candidate_repo != selected_repo:
                continue
            return filename
        msg = f"File not found for repository {selected_repo}"
        raise ValueError(msg)
@@ -0,0 +1,33 @@
1
+ import logging
2
+
3
+ from griptape_nodes.exe_types.node_types import BaseNode
4
+ from griptape_nodes.exe_types.param_components.huggingface.huggingface_model_parameter import HuggingFaceModelParameter
5
+ from griptape_nodes.exe_types.param_components.huggingface.huggingface_utils import (
6
+ list_all_repo_revisions_in_cache,
7
+ list_repo_revisions_in_cache,
8
+ )
9
+
10
+ logger = logging.getLogger("griptape_nodes")
11
+
12
+
13
class HuggingFaceRepoParameter(HuggingFaceModelParameter):
    """Model parameter whose choices are revisions of whole HuggingFace repos.

    Args:
        node: Node that owns the parameter.
        repo_ids: Repo ids to offer, and to suggest downloading when none are available.
        parameter_name: Name of the parameter on the node.
        list_all_models: When True, offer every revision returned by
            ``list_all_repo_revisions_in_cache()``, with revisions of ``repo_ids`` sorted first.
    """

    def __init__(
        self, node: BaseNode, repo_ids: list[str], parameter_name: str = "model", *, list_all_models: bool = False
    ):
        super().__init__(node, parameter_name)
        # Keep our own copy so later mutation of the caller's list can't silently change the choices.
        self._repo_ids = list(repo_ids)
        self._list_all_models = list_all_models
        self.refresh_parameters()

    def fetch_repo_revisions(self) -> list[tuple[str, str]]:
        """Return the repo revisions to offer as choices."""
        if self._list_all_models:
            preferred = set(self._repo_ids)
            # sorted() is stable: preferred repos come first (False sorts before True) and
            # each group keeps the order returned by the cache listing.
            return sorted(
                list_all_repo_revisions_in_cache(),
                key=lambda repo_revision: repo_revision[0] not in preferred,
            )
        return [repo_revision for repo in self._repo_ids for repo_revision in list_repo_revisions_in_cache(repo)]

    def get_download_commands(self) -> list[str]:
        """Build one ``huggingface-cli download`` command line per configured repo id."""
        return [f'huggingface-cli download "{repo}"' for repo in self._repo_ids]

    def get_download_models(self) -> list[str]:
        """Returns a list of model names that should be downloaded."""
        # Return a copy so callers cannot mutate this parameter's configured repo ids.
        return list(self._repo_ids)