optimum-rbln 0.7.2rc1__py3-none-any.whl → 0.7.3a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. optimum/rbln/__version__.py +9 -4
  2. optimum/rbln/diffusers/modeling_diffusers.py +18 -12
  3. optimum/rbln/modeling.py +1 -1
  4. optimum/rbln/modeling_base.py +15 -3
  5. optimum/rbln/ops/__init__.py +6 -2
  6. optimum/rbln/ops/attn.py +95 -7
  7. optimum/rbln/ops/flash_attn.py +43 -6
  8. optimum/rbln/transformers/modeling_generic.py +3 -3
  9. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -1
  10. optimum/rbln/transformers/models/bart/modeling_bart.py +1 -1
  11. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
  12. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +186 -78
  13. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +55 -17
  14. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -3
  15. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -3
  16. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +3 -3
  17. optimum/rbln/transformers/models/midm/midm_architecture.py +3 -3
  18. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -2
  19. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  20. optimum/rbln/transformers/models/t5/modeling_t5.py +1 -1
  21. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -1
  22. optimum/rbln/utils/import_utils.py +7 -0
  23. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3a0.dist-info}/METADATA +1 -1
  24. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3a0.dist-info}/RECORD +26 -26
  25. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3a0.dist-info}/WHEEL +0 -0
  26. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py CHANGED
@@ -1,8 +1,13 @@
-# file generated by setuptools_scm
+# file generated by setuptools-scm
 # don't change, don't track in version control
+
+__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
+
 TYPE_CHECKING = False
 if TYPE_CHECKING:
-    from typing import Tuple, Union
+    from typing import Tuple
+    from typing import Union
+
     VERSION_TUPLE = Tuple[Union[int, str], ...]
 else:
     VERSION_TUPLE = object
@@ -12,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.2rc1'
-__version_tuple__ = version_tuple = (0, 7, 2)
+__version__ = version = '0.7.3a0'
+__version_tuple__ = version_tuple = (0, 7, 3)
optimum/rbln/diffusers/modeling_diffusers.py CHANGED
@@ -71,13 +71,11 @@ class RBLNDiffusionMixin:
     _prefix = {}
 
     @classmethod
-    @property
-    def img2img_pipeline(cls):
+    def is_img2img_pipeline(cls):
         return "Img2Img" in cls.__name__
 
     @classmethod
-    @property
-    def inpaint_pipeline(cls):
+    def is_inpaint_pipeline(cls):
         return "Inpaint" in cls.__name__
 
     @classmethod
@@ -100,8 +98,8 @@ class RBLNDiffusionMixin:
            submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
            submodule_config.update(
                {
-                    "img2img_pipeline": cls.img2img_pipeline,
-                    "inpaint_pipeline": cls.inpaint_pipeline,
+                    "img2img_pipeline": cls.is_img2img_pipeline(),
+                    "inpaint_pipeline": cls.is_inpaint_pipeline(),
                }
            )
            submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
@@ -112,6 +110,11 @@ class RBLNDiffusionMixin:
            submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"{submodule_class_name}")
            prefix = cls._prefix.get(submodule_name, "")
            connected_submodules = cls._connected_classes.get(submodule_name)._submodules
+            pipe_global_config = {k: v for k, v in submodule_config.items() if k not in connected_submodules}
+            submodule_config = {k: v for k, v in submodule_config.items() if k in connected_submodules}
+            for key in submodule_config.keys():
+                submodule_config[key].update(pipe_global_config)
+
            for connected_submodule_name in connected_submodules:
                connected_submodule_config = rbln_config.pop(prefix + connected_submodule_name, {})
                if connected_submodule_name in submodule_config:
@@ -119,14 +122,17 @@ class RBLNDiffusionMixin:
                else:
                    submodule_config[connected_submodule_name] = connected_submodule_config
 
-            submodules = copy.deepcopy(cls._submodules)
-            submodules += [prefix + connected_submodule_name for connected_submodule_name in connected_submodules]
+            pipe_global_config = {
+                k: v for k, v in rbln_config.items() if k != submodule_class_name and not isinstance(v, dict)
+            }
 
-            pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
            for connected_submodule_name in connected_submodules:
-                submodule_config[connected_submodule_name].update(
-                    {k: v for k, v in pipe_global_config.items() if k not in submodule_config}
-                )
+                for k, v in pipe_global_config.items():
+                    if "guidance_scale" in k:
+                        if prefix + "guidance_scale" == k:
+                            submodule_config[connected_submodule_name]["guidance_scale"] = v
+                    else:
+                        submodule_config[connected_submodule_name][k] = v
            rbln_config[submodule_name] = submodule_config
        else:
            raise ValueError(f"submodule {submodule_name} isn't supported")
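Note: the `img2img_pipeline` / `inpaint_pipeline` accessors above were class-level properties (`@classmethod` stacked on `@property`), a combination deprecated in Python 3.11 and removed in 3.13; they are now plain classmethods, so call sites move from attribute access to an explicit call. A minimal sketch of the new call pattern (the class name below is a made-up stand-in, not from the package):

```python
class RBLNStableDiffusionImg2ImgPipelineLike:  # hypothetical stand-in class
    @classmethod
    def is_img2img_pipeline(cls) -> bool:
        # Same check as in the diff: the pipeline kind is inferred from the class name.
        return "Img2Img" in cls.__name__


# Old call sites read an attribute (cls.img2img_pipeline);
# new call sites invoke the classmethod explicitly:
assert RBLNStableDiffusionImg2ImgPipelineLike.is_img2img_pipeline() is True
```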
optimum/rbln/modeling.py CHANGED
@@ -196,7 +196,7 @@ class RBLNModel(RBLNBaseModel):
        **kwargs,
    ) -> "PreTrainedModel":
        kwargs = cls.update_kwargs(kwargs)
-        return cls.hf_class.from_pretrained(
+        return cls.get_hf_class().from_pretrained(
            model_id,
            subfolder=subfolder,
            revision=revision,
optimum/rbln/modeling_base.py CHANGED
@@ -389,8 +389,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
        return rbln_config
 
    @classmethod
-    @property
-    def hf_class(cls):
+    def get_hf_class(cls):
        """
        Lazily loads and caches the corresponding Hugging Face model class.
        Removes 'RBLN' prefix from the class name to get the original class name
@@ -416,7 +415,20 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
        return self.forward(*args, **kwargs)
 
    def __repr__(self):
-        return repr(self.model) + repr(self.rbln_submodules)
+        has_submodules = len(self.rbln_submodules) > 0
+        repr_str: str = f"<{self.__class__.__name__}>\n"
+        repr_str += f"- Total {len(self.model)} Runtimes"
+        repr_str += f" and {len(self.rbln_submodules)} Submodules\n" if has_submodules else "\n"
+        repr_str += "[Runtimes]\n"
+        repr_str += "\n".join([repr(model) for model in self.model])
+        repr_str += "\n"
+
+        if has_submodules > 0:
+            for i, submodule in enumerate(self.rbln_submodules):
+                repr_str += f"[Submodules {i} : {self._rbln_submodules[i]['name']}]\n"
+                repr_str += repr(submodule) + "\n"
+
+        return repr_str
 
    def __post_init__(self, **kwargs):
        pass
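Note: `hf_class` changes from a class-level property to an explicit `get_hf_class()` classmethod for the same Python-version reason. Per its docstring, it strips the `RBLN` prefix and resolves the original Hugging Face class; a standalone illustration of that lookup (a hypothetical helper assuming `transformers` is installed, not the package's actual implementation):

```python
import importlib


def resolve_hf_class(rbln_class_name: str):
    """Map an RBLN wrapper class name to the original Hugging Face class,
    mirroring what get_hf_class() is documented to do."""
    hf_name = rbln_class_name.removeprefix("RBLN")  # e.g. "RBLNBertModel" -> "BertModel"
    return getattr(importlib.import_module("transformers"), hf_name)


print(resolve_hf_class("RBLNBertModel"))  # <class 'transformers...BertModel'>
```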
optimum/rbln/ops/__init__.py CHANGED
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .attn import register_rbln_custom_attention, register_rbln_custom_attention_add_softmax
-from .flash_attn import register_rbln_custom_flash_attention
+from .attn import (
+    register_rbln_custom_attention_add_softmax,
+    register_rbln_custom_causal_masked_attention,
+    register_rbln_custom_masked_attention,
+)
+from .flash_attn import register_rbln_custom_flash_causal_masked_attention, register_rbln_custom_flash_masked_attention
 from .kv_cache_update import register_rbln_custom_cache_update
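Note: every `register_rbln_custom_*` helper exported here follows the same `torch.library` pattern shown in the hunks below: define a schema, attach a CPU body that the RBLN compiler treats as a fusion pattern, and register a fake (meta) kernel for tracing. A minimal self-contained sketch of that pattern with a made-up op name (assumes PyTorch ≥ 2.4 for `torch.library.register_fake`; older releases expose the same hook as `torch.library.impl_abstract`):

```python
from functools import lru_cache

import torch


@lru_cache
def register_demo_identity_op():
    # Hypothetical op, not part of optimum-rbln; it only demonstrates the
    # define / impl / register_fake registration pattern.
    torch.library.define("demo_ops::identity", "(Tensor x) -> Tensor")

    @torch.library.impl("demo_ops::identity", "cpu")
    def identity_cpu(x):
        # Reference body; a backend compiler could pattern-match and replace it.
        return x.clone()

    @torch.library.register_fake("demo_ops::identity")
    def identity_fake(x):
        # Shape/dtype-only kernel used during tracing and export.
        return torch.empty_like(x)


register_demo_identity_op()
print(torch.ops.demo_ops.identity(torch.ones(2)))  # tensor([1., 1.])
```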
optimum/rbln/ops/attn.py CHANGED
@@ -25,13 +25,13 @@ else:
 
 
 @lru_cache
-def register_rbln_custom_attention():
+def register_rbln_custom_masked_attention():
    torch.library.define(
-        "rbln_custom_ops::attn_decode",
+        "rbln_custom_ops::masked_attn_decode",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
    )
 
-    @torch.library.impl("rbln_custom_ops::attn_decode", "cpu")
+    @torch.library.impl("rbln_custom_ops::masked_attn_decode", "cpu")
    def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
        """Defines the computation pattern for fused attention with KV cache updates.
 
@@ -66,7 +66,7 @@ def register_rbln_custom_attention():
            torch.empty(*vcache.shape, device=vcache.device),
        )
 
-    @register_fake("rbln_custom_ops::attn_decode")
+    @register_fake("rbln_custom_ops::masked_attn_decode")
    def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
        return (
            q,
@@ -75,11 +75,11 @@ def register_rbln_custom_attention():
        )
 
    torch.library.define(
-        "rbln_custom_ops::attn_prefill",
+        "rbln_custom_ops::masked_attn_prefill",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
    )
 
-    @torch.library.impl("rbln_custom_ops::attn_prefill", "cpu")
+    @torch.library.impl("rbln_custom_ops::masked_attn_prefill", "cpu")
    def attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
        """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -109,11 +109,99 @@ def register_rbln_custom_attention():
        """
        return q, kcache, vcache
 
-    @register_fake("rbln_custom_ops::attn_prefill")
+    @register_fake("rbln_custom_ops::masked_attn_prefill")
    def attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
        return q, kcache, vcache
 
 
+@lru_cache
+def register_rbln_custom_causal_masked_attention():
+    torch.library.define(
+        "rbln_custom_ops::causal_masked_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::causal_masked_attn_decode", "cpu")
+    def attn_decode_cpu(q, k, v, kcache, vcache, seq, scale):
+        """Defines the computation pattern for fused attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Pattern components that compiler fuses into a single op:
+            1. KV cache updates with new key/value states
+            2. Scaled dot-product attention computation
+            3. Causal masked softmax operation
+            4. Final attention output computation
+
+        Expected tensor shapes:
+            - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+            - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+            - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+            - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+            - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+            - seq: [1] - Current sequence position
+            - scale: [] - Attention scale factor
+
+        Returns:
+            Tuple[Tensor, Tensor, Tensor]:
+            - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
+            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
+            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+        """
+        return (
+            q,
+            torch.empty(*kcache.shape, device=kcache.device),
+            torch.empty(*vcache.shape, device=vcache.device),
+        )
+
+    @register_fake("rbln_custom_ops::causal_masked_attn_decode")
+    def attn_decode_abstract(q, k, v, kcache, vcache, seq, partition):
+        return (
+            q,
+            torch.empty(*kcache.shape, device=kcache.device),
+            torch.empty(*vcache.shape, device=vcache.device),
+        )
+
+    torch.library.define(
+        "rbln_custom_ops::causal_masked_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::causal_masked_attn_prefill", "cpu")
+    def attn_prefill_cpu(q, k, v, kcache, vcache, batch, seq, scale):
+        """Defines the computation pattern for prefill phase attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Key differences from decode pattern:
+            - Handles prefill phase with multiple input tokens
+            - Takes explicit batch index for continuous batching
+
+        Expected tensor shapes:
+            - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
+            - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
+            - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
+            - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+            - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+            - batch: [1] - Batch index for cache access
+            - seq: [1] - Starting sequence position
+            - scale: [] - Attention scale factor
+
+        Returns:
+            Tuple[Tensor, Tensor, Tensor]:
+            - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
+            - empty_kcache: Same shape as input kcache - Placeholder for compiler
+            - empty_vcache: Same shape as input vcache - Placeholder for compiler
+        """
+        return q, kcache, vcache
+
+    @register_fake("rbln_custom_ops::causal_masked_attn_prefill")
+    def attn_prefill_abstract(q, k, v, kcache, vcache, batch, seq, partition):
+        return q, kcache, vcache
+
+
 @lru_cache
 def register_rbln_custom_attention_add_softmax():
    torch.library.define(
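Note: combining the shapes documented above, a decode-phase call to the new causal op looks like the following. This only exercises the CPU pattern stub (which echoes `q` and returns placeholder caches); on an RBLN NPU the compiler replaces the pattern with a fused kernel. Dimensions are arbitrary, and the import path assumes the layout of this wheel:

```python
import torch
from optimum.rbln.ops import register_rbln_custom_causal_masked_attention

register_rbln_custom_causal_masked_attention()

n_heads, n_groups, head_dim, max_seq_len = 4, 2, 64, 128
q = torch.randn(1, n_heads, n_groups, 1, head_dim)        # query for a single decode token
k = torch.randn(1, n_heads, 1, 1, head_dim)
v = torch.randn(1, n_heads, 1, 1, head_dim)
kcache = torch.zeros(1, n_heads, 1, max_seq_len, head_dim)
vcache = torch.zeros(1, n_heads, 1, max_seq_len, head_dim)
seq = torch.tensor([0])                                    # current sequence position
scale = torch.tensor(head_dim**-0.5)                       # attention scale factor

attn_out, _, _ = torch.ops.rbln_custom_ops.causal_masked_attn_decode(
    q, k, v, kcache, vcache, seq, scale
)
print(attn_out.shape)  # torch.Size([1, 4, 2, 1, 64]) -- CPU stub just returns q
```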
optimum/rbln/ops/flash_attn.py CHANGED
@@ -25,13 +25,13 @@ else:
 
 
 @lru_cache
-def register_rbln_custom_flash_attention():
+def register_rbln_custom_flash_masked_attention():
    torch.library.define(
-        "rbln_custom_ops::flash_attn_decode",
+        "rbln_custom_ops::flash_masked_attn_decode",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, int e) -> Tensor[]",
    )
 
-    @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
+    @torch.library.impl("rbln_custom_ops::flash_masked_attn_decode", "cpu")
    def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, partition):
        return (
            q,
@@ -39,7 +39,7 @@ def register_rbln_custom_flash_attention():
            torch.empty(*vcache.shape, device=vcache.device),
        )
 
-    @register_fake("rbln_custom_ops::flash_attn_decode")
+    @register_fake("rbln_custom_ops::flash_masked_attn_decode")
    def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, partition):
        return (
            q,
@@ -48,7 +48,7 @@ def register_rbln_custom_flash_attention():
        )
 
    torch.library.define(
-        "rbln_custom_ops::flash_attn_prefill",
+        "rbln_custom_ops::flash_masked_attn_prefill",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
    )
 
@@ -56,6 +56,43 @@ def register_rbln_custom_flash_attention():
    def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale, partition):
        return q, kcache, vcache
 
-    @register_fake("rbln_custom_ops::flash_attn_prefill")
+    @register_fake("rbln_custom_ops::flash_masked_attn_prefill")
    def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, scale, partition):
        return q, kcache, vcache
+
+
+@lru_cache
+def register_rbln_custom_flash_causal_masked_attention():
+    torch.library.define(
+        "rbln_custom_ops::flash_causal_masked_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, int e) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::flash_causal_masked_attn_decode", "cpu")
+    def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, partition):
+        return (
+            q,
+            torch.empty(*kcache.shape, device=kcache.device),
+            torch.empty(*vcache.shape, device=vcache.device),
+        )
+
+    @register_fake("rbln_custom_ops::flash_causal_masked_attn_decode")
+    def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, partition):
+        return (
+            q,
+            torch.empty(*kcache.shape, device=kcache.device),
+            torch.empty(*vcache.shape, device=vcache.device),
+        )
+
+    torch.library.define(
+        "rbln_custom_ops::flash_causal_masked_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::flash_causal_masked_attn_prefill", "cpu")
+    def flash_attn_prefill_cpu(q, k, v, kcache, vcache, batch, seq, scale, partition):
+        return q, kcache, vcache
+
+    @register_fake("rbln_custom_ops::flash_causal_masked_attn_prefill")
+    def flash_attn_prefill_abstract(q, k, v, kcache, vcache, batch, seq, scale, partition):
+        return q, kcache, vcache
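Note: the flash variants registered above keep the same argument layout as their non-flash counterparts plus a trailing `int` (bound to `partition` in the kernels). A compact sketch of a decode call against the CPU stub, with arbitrary dimensions and partition value:

```python
import torch
from optimum.rbln.ops import register_rbln_custom_flash_causal_masked_attention

register_rbln_custom_flash_causal_masked_attention()

heads, dim, max_len = 4, 64, 128
out, _, _ = torch.ops.rbln_custom_ops.flash_causal_masked_attn_decode(
    torch.randn(1, heads, 2, 1, dim),        # q
    torch.randn(1, heads, 1, 1, dim),        # k
    torch.randn(1, heads, 1, 1, dim),        # v
    torch.zeros(1, heads, 1, max_len, dim),  # kcache
    torch.zeros(1, heads, 1, max_len, dim),  # vcache
    torch.tensor([0]),                       # seq
    torch.tensor(dim**-0.5),                 # scale
    128,                                     # partition (the extra int argument)
)
print(out.shape)  # torch.Size([1, 4, 2, 1, 64])
```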
optimum/rbln/transformers/modeling_generic.py CHANGED
@@ -73,7 +73,7 @@ class RBLNModelForQuestionAnswering(RBLNModel):
        if rbln_batch_size is None:
            rbln_batch_size = 1
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
@@ -289,7 +289,7 @@ class RBLNModelForSequenceClassification(RBLNModel):
        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
@@ -362,7 +362,7 @@ class RBLNModelForMaskedLM(RBLNModel):
        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
optimum/rbln/transformers/models/bart/bart_architecture.py CHANGED
@@ -142,7 +142,7 @@ class BartSelfAttention(Seq2SeqSelfAttention):
        self.num_heads = self._original_mod.num_heads
        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
        self.scaling = self.head_dim**-0.5
-        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode
+        self.attn_decode = torch.ops.rbln_custom_ops.masked_attn_decode
 
    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        query_states = self.q_proj(hidden_states) * self.scaling
optimum/rbln/transformers/models/bart/modeling_bart.py CHANGED
@@ -58,7 +58,7 @@ class RBLNBartModel(RBLNModel):
        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
optimum/rbln/transformers/models/bert/modeling_bert.py CHANGED
@@ -56,7 +56,7 @@ class RBLNBertModel(RBLNModel):
        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
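Note: all of the encoder wrappers above derive their compile-time input names the same way, by inspecting the forward signature of the class returned by `get_hf_class()`. A standalone illustration of the idea using plain `transformers` (assumes it is installed):

```python
import inspect

from transformers import BertModel

signature_params = inspect.signature(BertModel.forward).parameters.keys()
print(list(signature_params)[:4])
# e.g. ['self', 'input_ids', 'attention_mask', 'token_type_ids']
```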