gimlet-api 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
gml/hf.py ADDED
@@ -0,0 +1,521 @@
+ # Copyright 2023- Gimlet Labs, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
+ import glob
+ import tempfile
+ from collections.abc import Iterable
+ from pathlib import Path
+ from typing import Any, BinaryIO, Dict, List, Optional, TextIO, Tuple
+
+ import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
+ import torch
+ import transformers
+ from gml.model import GenerationConfig, Model, TorchModel
+ from gml.preprocessing import (
+     ImagePreprocessingStep,
+     ImageToFloatTensor,
+     LetterboxImage,
+     ResizeImage,
+     StandardizeTensor,
+ )
+ from gml.tensor import (
+     AttentionKeyValueCacheTensorSemantics,
+     BatchDimension,
+     BoundingBoxFormat,
+     DetectionNumCandidatesDimension,
+     DetectionOutputDimension,
+     ImageChannelDimension,
+     ImageHeightDimension,
+     ImageWidthDimension,
+     SegmentationMaskChannel,
+     TensorSemantics,
+     TokensDimension,
+     VocabLogitsDimension,
+ )
+ from transformers import (
+     BaseImageProcessor,
+     Pipeline,
+     PreTrainedModel,
+     PreTrainedTokenizer,
+ )
+
+
+ class HuggingFaceTokenizer(Model):
+     def __init__(self, tokenizer: PreTrainedTokenizer, name: Optional[str] = None):
+         if name is None:
+             name = tokenizer.name_or_path + ".tokenizer"
+         super().__init__(
+             name=name,
+             kind=modelexecpb.ModelInfo.MODEL_KIND_HUGGINGFACE_TOKENIZER,
+             storage_format=modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_OPAQUE,
+             input_tensor_semantics=[],
+             output_tensor_semantics=[],
+         )
+         self.tokenizer = tokenizer
+
+     def _collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
+         with tempfile.TemporaryDirectory() as tmpdir:
+             self.tokenizer.save_pretrained(tmpdir)
+             paths = [Path(f) for f in glob.glob(tmpdir + "/*")]
+             yield {p.name: p for p in paths}
+
+
+ class HuggingFaceGenerationConfig(GenerationConfig):
+
+     def __init__(self, model: PreTrainedModel):
+         config = model.generation_config
+         eos_tokens = config.eos_token_id
+         if eos_tokens is None:
+             eos_tokens = []
+         if not isinstance(eos_tokens, list):
+             eos_tokens = [eos_tokens]
+         super().__init__(eos_tokens)
+
+
+ def flatten(items):
+     flattened = []
+     if isinstance(items, torch.Tensor) or not isinstance(items, Iterable):
+         flattened.append(items)
+     else:
+         for x in items:
+             flattened.extend(flatten(x))
+     return flattened
+
+
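For orientation, `flatten` above walks arbitrarily nested pipeline outputs (tuples, lists, key/value caches) into a flat list while leaving individual tensors intact. A minimal sketch of its behavior, with shapes chosen purely for illustration:

    import torch

    logits = torch.zeros(1, 3, 100)
    cache = [[torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)]]
    # One logits tensor plus one key/value pair flattens to three tensors.
    assert len(flatten((logits, cache))) == 3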
+ class WrapWithFunctionalCache(torch.nn.Module):
+
+     def __init__(self, model: transformers.PreTrainedModel):
+         super().__init__()
+         self.model = model
+
+     def forward(self, input_ids, cache):
+         outputs = self.model(
+             input_ids=input_ids,
+             past_key_values=cache,
+             return_dict=True,
+             use_cache=True,
+         )
+
+         return outputs.logits, outputs.past_key_values
+
+
+ class HuggingFaceTextGenerationPipeline:
+     def __init__(
+         self,
+         pipeline: Pipeline,
+         name: Optional[str] = None,
+         tokenizer_name: Optional[str] = None,
+         dynamic_seqlen: bool = False,
+     ):
+         self.pipeline = pipeline
+         self.tokenizer_model = HuggingFaceTokenizer(pipeline.tokenizer, tokenizer_name)
+         self._cache_length_for_tracing = 32
+         if name is None:
+             name = pipeline.model.name_or_path
+
+         self.model = pipeline.model
+         self.model = self.model.to(torch.float16)
+         self.model = WrapWithFunctionalCache(pipeline.model)
+
+         self.language_model = TorchModel(
+             name,
+             torch_module=self.model,
+             **self._guess_model_spec(dynamic_seqlen),
+         )
+
+     def _initialize_key_value_cache(self):
+         cache = []
+         config = self.pipeline.model.config
+         head_dim = (
+             config.head_dim
+             if hasattr(config, "head_dim")
+             else config.hidden_size // config.num_attention_heads
+         )
+         num_key_value_heads = (
+             config.num_attention_heads
+             if config.num_key_value_heads is None
+             else config.num_key_value_heads
+         )
+         cache_shape = (1, num_key_value_heads, self._cache_length_for_tracing, head_dim)
+         for _ in range(config.num_hidden_layers):
+             cache.append(
+                 [
+                     torch.zeros(cache_shape).to(torch.float16),
+                     torch.zeros(cache_shape).to(torch.float16),
+                 ]
+             )
+         return cache
+
+     def _guess_model_spec(self, dynamic_seqlen: bool) -> Dict:
+         input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
+         if "input_ids" not in input_dict:
+             raise ValueError(
+                 'HuggingFaceTextGenerationPipeline expects preprocessed inputs to have an "input_ids" tensor'
+             )
+
+         inputs = []
+         input_tensor_semantics = []
+
+         # This currently assumes that all HF language models have inputs that are [B, NUM_TOKENS].
+         inputs.append(input_dict["input_ids"])
+         input_tensor_semantics.append(
+             TensorSemantics(
+                 dimensions=[
+                     BatchDimension(),
+                     TokensDimension(),
+                 ],
+             )
+         )
+
+         # Assume that the model supports a KeyValue cache.
+         cache_values = self._initialize_key_value_cache()
+         inputs.append(cache_values)
+         for _ in cache_values:
+             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
+             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
+
+         outputs = self.model(*inputs)
+
+         # Determine output semantics.
+         output_tensor_semantics = []
+         seqlen = inputs[0].shape[1]
+         found_logits = False
+         for tensor in flatten(outputs):
+             if not isinstance(tensor, torch.Tensor):
+                 continue
+
+             if (
+                 not found_logits
+                 and len(tensor.shape) == 3
+                 and tensor.shape[0] == 1
+                 and tensor.shape[1] == seqlen
+             ):
+                 # This should be the logits tensor.
+                 output_tensor_semantics.append(
+                     TensorSemantics(
+                         dimensions=[
+                             BatchDimension(),
+                             TokensDimension(),
+                             VocabLogitsDimension(),
+                         ],
+                     )
+                 )
+                 found_logits = True
+             else:
+                 output_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
+
+         dynamic_shapes = None
+         seqlen = torch.export.Dim("seqlen", min=2, max=9223372036854775096)
+
+         cache_length = torch.export.Dim("cache_length", min=2, max=9223372036854775096)
+         dynamic_shapes = [
+             {1: seqlen},
+             [[{2: cache_length}, {2: cache_length}] for _ in cache_values],
+         ]
+
+         return {
+             "example_inputs": inputs,
+             "dynamic_shapes": dynamic_shapes,
+             "input_tensor_semantics": input_tensor_semantics,
+             "output_tensor_semantics": output_tensor_semantics,
+             "generation_config": HuggingFaceGenerationConfig(self.pipeline.model),
+         }
+
+     def models(self) -> List[Model]:
+         return [self.tokenizer_model, self.language_model]
+
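The `dynamic_shapes` assembled in `_guess_model_spec` follows the `torch.export` convention: each dict maps a dimension index to a `torch.export.Dim`, so the exported program keeps the sequence length and cache length symbolic rather than baking in the traced sizes. A self-contained sketch of the same mechanism on a toy module (not this package's model):

    import torch

    class Double(torch.nn.Module):
        def forward(self, x):
            return x * 2

    seqlen = torch.export.Dim("seqlen", min=2, max=4096)
    # Dimension 1 of the single input stays symbolic in the exported graph.
    exported = torch.export.export(
        Double(), (torch.zeros(1, 8),), dynamic_shapes=({1: seqlen},)
    )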
+
+ class HuggingFaceImageProcessor:
+
+     def __init__(
+         self,
+         model: PreTrainedModel,
+         processor: BaseImageProcessor,
+     ):
+         self.model = model
+         self.processor = processor
+
+     def input_spec(self) -> Dict[str, Any]:
+         target_size = None
+         image_preprocessing_steps = []
+         if (
+             hasattr(self.processor, "do_resize")
+             and self.processor.do_resize
+             and hasattr(self.processor, "size")
+         ):
+             target_size, preprocessing_step = self._convert_resize()
+             image_preprocessing_steps.append(preprocessing_step)
+
+         if (
+             hasattr(self.processor, "do_rescale")
+             and self.processor.do_rescale
+             and hasattr(self.processor, "rescale_factor")
+         ):
+             image_preprocessing_steps.append(
+                 ImageToFloatTensor(
+                     scale=True, scale_factor=self.processor.rescale_factor
+                 )
+             )
+         else:
+             image_preprocessing_steps.append(ImageToFloatTensor(scale=False))
+
+         if hasattr(self.processor, "do_normalize") and self.processor.do_normalize:
+             image_preprocessing_steps.append(
+                 StandardizeTensor(self.processor.image_mean, self.processor.image_std)
+             )
+
+         channels_first = True
+         if (
+             hasattr(self.processor, "input_data_format")
+             and self.processor.input_data_format == "channels_last"
+         ):
+             channels_first = False
+
+         # Assume RGB for now.
+         # TODO(james): figure out if this is specified anywhere in the huggingface pipeline.
+         channel_format = "rgb"
+
+         dimensions = [
+             BatchDimension(),
+         ]
+         input_shape = [1]
+         if channels_first:
+             dimensions.append(ImageChannelDimension(channel_format))
+             input_shape.append(3)
+         dimensions.append(ImageHeightDimension())
+         input_shape.append(target_size[0])
+         dimensions.append(ImageWidthDimension())
+         input_shape.append(target_size[1])
+         if not channels_first:
+             dimensions.append(ImageChannelDimension(channel_format))
+             input_shape.append(3)
+
+         example_input = torch.rand(input_shape)
+         input_tensor_semantics = [TensorSemantics(dimensions)]
+         return {
+             "example_inputs": [example_input],
+             "input_tensor_semantics": input_tensor_semantics,
+             "image_preprocessing_steps": image_preprocessing_steps,
+         }
+
+     def output_spec_segmentation(self) -> Dict[str, Any]:
+         if not hasattr(self.processor, "post_process_semantic_segmentation"):
+             raise NotImplementedError(
+                 "only semantic segmentation is currently supported"
+             )
+         dimensions = [
+             BatchDimension(),
+             # TODO(james): verify all semantic segmentation in hugging face output a logits mask.
+             SegmentationMaskChannel("logits_mask"),
+             ImageHeightDimension(),
+             ImageWidthDimension(),
+         ]
+         output_tensor_semantics = [
+             TensorSemantics(dimensions),
+         ]
+         id_to_label = self.model.config.id2label
+         max_id = max(id_to_label)
+         labels = []
+         for i in range(max_id):
+             if i not in id_to_label:
+                 labels.append("")
+                 continue
+             labels.append(id_to_label[i])
+         return {
+             "output_tensor_semantics": output_tensor_semantics,
+             "class_labels": labels,
+         }
+
+     def output_spec_object_detection(self) -> Dict[str, Any]:
+         if not hasattr(self.processor, "post_process_object_detection"):
+             raise NotImplementedError(
+                 "only semantic segmentation is currently supported"
+             )
+
+         id_to_label = self.model.config.id2label
+         max_id = max(id_to_label)
+         labels = []
+         for i in range(max_id):
+             if i not in id_to_label:
+                 labels.append("")
+                 continue
+             labels.append(id_to_label[i])
+         num_classes = max_id + 1
+
+         # TODO(james): verify assumptions made here apply broadly.
+         output_tensor_semantics = []
+         # We assume that ObjectDetectionWrapper is used to ensure that logits are the first tensor and boxes are the second.
+         logits_dimensions = [
+             BatchDimension(),
+             DetectionNumCandidatesDimension(is_nms=False),
+             DetectionOutputDimension(
+                 scores_range=(0, num_classes),
+                 scores_are_logits=True,
+             ),
+         ]
+         output_tensor_semantics.append(TensorSemantics(logits_dimensions))
+
+         box_dimensions = [
+             BatchDimension(),
+             DetectionNumCandidatesDimension(is_nms=False),
+             DetectionOutputDimension(
+                 coordinates_start_index=0,
+                 box_format=BoundingBoxFormat("cxcywh", is_normalized=True),
+             ),
+         ]
+         output_tensor_semantics.append(TensorSemantics(box_dimensions))
+         return {
+             "output_tensor_semantics": output_tensor_semantics,
+             "class_labels": labels,
+         }
+
+     def _convert_resize(self) -> Tuple[Tuple[int, int], ImagePreprocessingStep]:
+         size = self.processor.size
+         target_size = None
+         preprocess_step = None
+         if "height" in size and "width" in size:
+             target_size = [size["height"], size["width"]]
+             preprocess_step = ResizeImage()
+         elif (
+             "shortest_edge" in size
+             or "longest_edge" in size
+             or "max_height" in size
+             or "max_width" in size
+         ):
+             shortest_edge = size.get("shortest_edge")
+             longest_edge = size.get("longest_edge")
+             max_height = size.get("max_height")
+             max_width = size.get("max_width")
+
+             min_size = None
+             for edge_size in [shortest_edge, longest_edge, max_height, max_width]:
+                 if not edge_size:
+                     continue
+                 if not min_size or edge_size < min_size:
+                     min_size = edge_size
+
+             target_size = [min_size, min_size]
+             preprocess_step = LetterboxImage()
+         else:
+             raise ValueError(
+                 "could not determine target size for resize from model config"
+             )
+         return target_size, preprocess_step
+
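`HuggingFaceImageProcessor` is the shared helper that the vision pipelines below lean on: `input_spec` translates the processor's resize/rescale/normalize flags into preprocessing steps plus an example input, and the `output_spec_*` methods add output semantics and class labels. A usage sketch with an illustrative DETR checkpoint (any object-detection pipeline that carries an image processor should work the same way):

    import transformers

    pipe = transformers.pipeline("object-detection", model="facebook/detr-resnet-50")
    processor = HuggingFaceImageProcessor(pipe.model, pipe.image_processor)
    spec = processor.input_spec()  # example input + image preprocessing steps
    spec.update(processor.output_spec_object_detection())  # output semantics + labels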
+
+ class HuggingFaceImageSegmentationPipeline:
+     def __init__(
+         self,
+         pipeline: Pipeline,
+         name: Optional[str] = None,
+     ):
+         self.pipeline = pipeline
+         if name is None:
+             name = pipeline.model.name_or_path
+
+         self.model = TorchModel(
+             name,
+             torch_module=self.pipeline.model,
+             **self._guess_model_spec(),
+         )
+
+     def _guess_model_spec(self) -> Dict:
+         if self.pipeline.image_processor is None:
+             raise ValueError(
+                 "Could not determine image preprocessing for pipeline with image_processor=None"
+             )
+         if self.pipeline.tokenizer is not None:
+             raise NotImplementedError(
+                 "HuggingFaceImageSegmentationPipeline does not yet support token inputs"
+             )
+
+         image_processor = HuggingFaceImageProcessor(
+             self.pipeline.model, self.pipeline.image_processor
+         )
+         spec = image_processor.input_spec()
+         spec.update(image_processor.output_spec_segmentation())
+         return spec
+
+     def models(self) -> List[Model]:
+         return [self.model]
+
+
+ class ObjectDetectionWrapper(torch.nn.Module):
+     def __init__(self, model: PreTrainedModel):
+         super().__init__()
+         self.model = model
+
+     def forward(self, *args, **kwargs):
+         outputs = self.model(*args, **kwargs)
+         return outputs.logits, outputs.pred_boxes
+
+
+ class HuggingFaceObjectDetectionPipeline:
+     def __init__(
+         self,
+         pipeline: Pipeline,
+         name: Optional[str] = None,
+     ):
+         self.pipeline = pipeline
+         if name is None:
+             name = pipeline.model.name_or_path
+
+         self.model = TorchModel(
+             name,
+             torch_module=ObjectDetectionWrapper(self.pipeline.model),
+             **self._guess_model_spec(),
+         )
+
+     def _guess_model_spec(self) -> Dict:
+         if self.pipeline.image_processor is None:
+             raise ValueError(
+                 "Could not determine image preprocessing for pipeline with image_processor=None"
+             )
+         if self.pipeline.tokenizer is not None:
+             raise NotImplementedError(
+                 "HuggingFaceObjectDetectionPipeline does not yet support token inputs"
+             )
+
+         image_processor = HuggingFaceImageProcessor(
+             self.pipeline.model, self.pipeline.image_processor
+         )
+         spec = image_processor.input_spec()
+         spec.update(image_processor.output_spec_object_detection())
+         return spec
+
+     def models(self) -> List[Model]:
+         return [self.model]
+
+
+ def import_huggingface_pipeline(pipeline: Pipeline, **kwargs) -> List[Model]:
+     if pipeline.framework != "pt":
+         raise ValueError(
+             "unimplemented: hugging face pipeline framework: {}".format(
+                 pipeline.framework
+             )
+         )
+
+     if pipeline.task == "text-generation":
+         return HuggingFaceTextGenerationPipeline(pipeline, **kwargs).models()
+     elif pipeline.task == "image-segmentation":
+         return HuggingFaceImageSegmentationPipeline(pipeline, **kwargs).models()
+     elif pipeline.task == "object-detection":
+         return HuggingFaceObjectDetectionPipeline(pipeline, **kwargs).models()
+     raise ValueError(
+         "unimplemented: hugging face pipeline task: {} (supported tasks: [{}])".format(
+             pipeline.task, ["text-generation", "image-segmentation", "object-detection"]
+         )
+     )
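`import_huggingface_pipeline` is the entry point most callers use. A minimal usage sketch, assuming an illustrative causal-LM checkpoint ("gpt2" is not mandated by this package) and a PyTorch build that can run the model in float16, since the text-generation path casts the model to float16 before tracing:

    import transformers
    from gml.hf import import_huggingface_pipeline

    pipe = transformers.pipeline("text-generation", model="gpt2")
    # For text-generation this returns [tokenizer Model, language TorchModel].
    models = import_huggingface_pipeline(pipe)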
gml/model.py CHANGED
@@ -13,20 +13,31 @@
  # limitations under the License.
  #
  # SPDX-License-Identifier: Apache-2.0
+ from __future__ import annotations

  import abc
+ import contextlib
  import io
  from pathlib import Path
- from typing import BinaryIO, Dict, List, Literal, Optional, TextIO, Tuple
+ from typing import BinaryIO, Dict, List, Literal, Optional, Sequence, TextIO, Tuple

  import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
  import torch
- import torch_mlir
- from gml.compile import to_torch_mlir
+ from gml.compile import to_torch_mlir, torch_mlir_output_kind
  from gml.preprocessing import ImagePreprocessingStep
  from gml.tensor import TensorSemantics


+ class GenerationConfig:
+     def __init__(self, eos_token_ids: List[int]):
+         self.eos_token_ids = eos_token_ids
+
+     def to_proto(self) -> modelexecpb.GenerationConfig:
+         return modelexecpb.GenerationConfig(
+             eos_token_ids=self.eos_token_ids,
+         )
+
+
  class Model(abc.ABC):
      def __init__(
          self,
@@ -38,6 +49,7 @@ class Model(abc.ABC):
          class_labels: Optional[List[str]] = None,
          class_labels_file: Optional[Path] = None,
          image_preprocessing_steps: Optional[List[ImagePreprocessingStep]] = None,
+         generation_config: Optional[GenerationConfig] = None,
      ):
          self.name = name
          self.kind = kind
@@ -51,6 +63,7 @@ class Model(abc.ABC):
          self.input_tensor_semantics = input_tensor_semantics
          self.output_tensor_semantics = output_tensor_semantics
          self.image_preprocessing_steps = image_preprocessing_steps
+         self.generation_config = generation_config

      def to_proto(self) -> modelexecpb.ModelInfo:
          image_preprocessing_steps = None
@@ -58,6 +71,9 @@ class Model(abc.ABC):
              image_preprocessing_steps = [
                  step.to_proto() for step in self.image_preprocessing_steps
              ]
+         generation_config = None
+         if self.generation_config:
+             generation_config = self.generation_config.to_proto()
          return modelexecpb.ModelInfo(
              name=self.name,
              kind=self.kind,
@@ -70,45 +86,61 @@ class Model(abc.ABC):
              output_tensor_semantics=[
                  semantics.to_proto() for semantics in self.output_tensor_semantics
              ],
+             generation_config=generation_config,
          )

      @abc.abstractmethod
-     def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
+     def _collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
          pass

+     @contextlib.contextmanager
+     def collect_assets(self):
+         yield from self._collect_assets()
+

  class TorchModel(Model):
      def __init__(
          self,
          name: str,
          torch_module: torch.nn.Module,
-         input_shapes: List[List[int]],
-         input_dtypes: List[torch.dtype],
+         example_inputs: Optional[List[torch.Tensor]] = None,
+         input_shapes: Optional[List[List[int]]] = None,
+         input_dtypes: Optional[List[torch.dtype]] = None,
+         dynamic_shapes: Optional[Sequence[Dict[int, str | "torch.export.Dim"]]] = None,
          **kwargs,
      ):
          super().__init__(
              name,
-             modelexecpb.ModelInfo.MODEL_KIND_TORCHSCRIPT,
+             torch_mlir_output_kind(),
              modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_MLIR_TEXT,
              **kwargs,
          )
          self.torch_module = torch_module
+         self.example_inputs = example_inputs
          self.input_shapes = input_shapes
          self.input_dtypes = input_dtypes
+         self.dynamic_shapes = dynamic_shapes
+         if self.example_inputs is None:
+             if self.input_shapes is None or self.input_dtypes is None:
+                 raise ValueError(
+                     "one of `example_inputs` or (`input_shapes` and `input_dtype`) must be provided to `TorchModel`"
+                 )
+             self.example_inputs = [
+                 torch.rand(shape, dtype=dtype)
+                 for shape, dtype in zip(self.input_shapes, self.input_dtypes)
+             ]

      def _convert_to_torch_mlir(self):
          return to_torch_mlir(
-             self.torch_module.to("cpu"),
-             [
-                 torch_mlir.TensorPlaceholder(shape, dtype)
-                 for shape, dtype in zip(self.input_shapes, self.input_dtypes)
-             ],
+             self.torch_module,
+             self.example_inputs,
+             self.dynamic_shapes,
          )

-     def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
+     def _collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
          compiled = self._convert_to_torch_mlir()
          file = io.BytesIO(str(compiled).encode("utf-8"))
-         return {"": file}
+         yield {"": file}


  def _kind_str_to_kind_format_protos(
@@ -146,5 +178,5 @@ class ModelFromFiles(Model):
          super().__init__(name=name, kind=kind, storage_format=storage_format, **kwargs)
          self.files = files

-     def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
-         return self.files
+     def _collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
+         yield self.files
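With this change `TorchModel` accepts either explicit `example_inputs` or the older `input_shapes`/`input_dtypes` pair, from which example inputs are generated with `torch.rand`. A construction sketch on a toy module; the empty tensor-semantics lists are illustrative placeholders forwarded through `**kwargs` to `Model.__init__`, the same way `HuggingFaceTokenizer` passes them:

    import torch
    from gml.model import TorchModel

    module = torch.nn.Linear(4, 2)

    by_example = TorchModel(
        "linear",
        torch_module=module,
        example_inputs=[torch.rand(1, 4)],
        input_tensor_semantics=[],
        output_tensor_semantics=[],
    )

    by_shape = TorchModel(
        "linear",
        torch_module=module,
        input_shapes=[[1, 4]],
        input_dtypes=[torch.float32],
        input_tensor_semantics=[],
        output_tensor_semantics=[],
    )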
gml/model_utils.py CHANGED
@@ -31,3 +31,5 @@ def prepare_ultralytics_yolo(model):
      for _, m in model.named_modules():
          if hasattr(m, "export"):
              m.export = True
+             # YOLOv8 requires setting `format` when `export = True`
+             m.format = "custom"