gimlet-api 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in the public registry.
gml/hf.py ADDED
@@ -0,0 +1,524 @@
+ # Copyright 2023- Gimlet Labs, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
+ import glob
+ import tempfile
+ from collections.abc import Iterable
+ from pathlib import Path
+ from typing import Any, BinaryIO, Dict, List, Optional, TextIO, Tuple
+
+ import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
+ import torch
+ import transformers
+ from gml.asset_manager import AssetManager
+ from gml.model import GenerationConfig, Model, TorchModel
+ from gml.preprocessing import (
+     ImagePreprocessingStep,
+     ImageToFloatTensor,
+     LetterboxImage,
+     ResizeImage,
+     StandardizeTensor,
+ )
+ from gml.tensor import (
+     AttentionKeyValueCacheTensorSemantics,
+     BatchDimension,
+     BoundingBoxFormat,
+     DetectionNumCandidatesDimension,
+     DetectionOutputDimension,
+     ImageChannelDimension,
+     ImageHeightDimension,
+     ImageWidthDimension,
+     SegmentationMaskChannel,
+     TensorSemantics,
+     TokensDimension,
+     VocabLogitsDimension,
+ )
+ from transformers import (
+     BaseImageProcessor,
+     Pipeline,
+     PreTrainedModel,
+     PreTrainedTokenizer,
+ )
+
+
+ class HuggingFaceTokenizer(Model):
+     def __init__(self, tokenizer: PreTrainedTokenizer, name: Optional[str] = None):
+         if name is None:
+             name = tokenizer.name_or_path + ".tokenizer"
+         super().__init__(
+             name=name,
+             kind=modelexecpb.ModelInfo.MODEL_KIND_HUGGINGFACE_TOKENIZER,
+             storage_format=modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_OPAQUE,
+             input_tensor_semantics=[],
+             output_tensor_semantics=[],
+         )
+         self.tokenizer = tokenizer
+
+     def _collect_assets(
+         self, weight_manager: Optional[AssetManager] = None
+     ) -> Dict[str, TextIO | BinaryIO | Path]:
+         with tempfile.TemporaryDirectory() as tmpdir:
+             self.tokenizer.save_pretrained(tmpdir)
+             paths = [Path(f) for f in glob.glob(tmpdir + "/*")]
+             yield {p.name: p for p in paths}
+
+
+ class HuggingFaceGenerationConfig(GenerationConfig):
+
+     def __init__(self, model: PreTrainedModel):
+         config = model.generation_config
+         eos_tokens = config.eos_token_id
+         if eos_tokens is None:
+             eos_tokens = []
+         if not isinstance(eos_tokens, list):
+             eos_tokens = [eos_tokens]
+         super().__init__(eos_tokens)
+
+
+ def flatten(items):
+     flattened = []
+     if isinstance(items, torch.Tensor) or not isinstance(items, Iterable):
+         flattened.append(items)
+     else:
+         for x in items:
+             flattened.extend(flatten(x))
+     return flattened
+
+
+ class WrapWithFunctionalCache(torch.nn.Module):
+
+     def __init__(self, model: transformers.PreTrainedModel):
+         super().__init__()
+         self.model = model
+
+     def forward(self, input_ids, cache):
+         outputs = self.model(
+             input_ids=input_ids,
+             past_key_values=cache,
+             return_dict=True,
+             use_cache=True,
+         )
+
+         return outputs.logits, outputs.past_key_values
+
+
+ class HuggingFaceTextGenerationPipeline:
+     def __init__(
+         self,
+         pipeline: Pipeline,
+         name: Optional[str] = None,
+         tokenizer_name: Optional[str] = None,
+         dynamic_seqlen: bool = False,
+     ):
+         self.pipeline = pipeline
+         self.tokenizer_model = HuggingFaceTokenizer(pipeline.tokenizer, tokenizer_name)
+         self._cache_length_for_tracing = 32
+         if name is None:
+             name = pipeline.model.name_or_path
+
+         self.model = pipeline.model
+         self.model = self.model.to(torch.float16)
+         self.model = WrapWithFunctionalCache(pipeline.model)
+
+         self.language_model = TorchModel(
+             name,
+             torch_module=self.model,
+             **self._guess_model_spec(dynamic_seqlen),
+         )
+
+     def _initialize_key_value_cache(self):
+         cache = []
+         config = self.pipeline.model.config
+         head_dim = (
+             config.head_dim
+             if hasattr(config, "head_dim")
+             else config.hidden_size // config.num_attention_heads
+         )
+         num_key_value_heads = (
+             config.num_attention_heads
+             if config.num_key_value_heads is None
+             else config.num_key_value_heads
+         )
+         cache_shape = (1, num_key_value_heads, self._cache_length_for_tracing, head_dim)
+         for _ in range(config.num_hidden_layers):
+             cache.append(
+                 [
+                     torch.zeros(cache_shape).to(torch.float16),
+                     torch.zeros(cache_shape).to(torch.float16),
+                 ]
+             )
+         return cache
+
+     def _guess_model_spec(self, dynamic_seqlen: bool) -> Dict:
+         input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
+         if "input_ids" not in input_dict:
+             raise ValueError(
+                 'HuggingFaceTextGenerationPipeline expects preprocessed inputs to have an "input_ids" tensor'
+             )
+
+         inputs = []
+         input_tensor_semantics = []
+
+         # This currently assumes that all HF language models have inputs that are [B, NUM_TOKENS].
+         inputs.append(input_dict["input_ids"])
+         input_tensor_semantics.append(
+             TensorSemantics(
+                 dimensions=[
+                     BatchDimension(),
+                     TokensDimension(),
+                 ],
+             )
+         )
+
+         # Assume that the model supports a KeyValue cache.
+         cache_values = self._initialize_key_value_cache()
+         inputs.append(cache_values)
+         for _ in cache_values:
+             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
+             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
+
+         outputs = self.model(*inputs)
+
+         # Determine output semantics.
+         output_tensor_semantics = []
+         seqlen = inputs[0].shape[1]
+         found_logits = False
+         for tensor in flatten(outputs):
+             if not isinstance(tensor, torch.Tensor):
+                 continue
+
+             if (
+                 not found_logits
+                 and len(tensor.shape) == 3
+                 and tensor.shape[0] == 1
+                 and tensor.shape[1] == seqlen
+             ):
+                 # This should be the logits tensor.
+                 output_tensor_semantics.append(
+                     TensorSemantics(
+                         dimensions=[
+                             BatchDimension(),
+                             TokensDimension(),
+                             VocabLogitsDimension(),
+                         ],
+                     )
+                 )
+                 found_logits = True
+             else:
+                 output_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
+
+         dynamic_shapes = None
+         seqlen = torch.export.Dim("seqlen", min=2, max=9223372036854775096)
+
+         cache_length = torch.export.Dim("cache_length", min=2, max=9223372036854775096)
+         dynamic_shapes = [
+             {1: seqlen},
+             [[{2: cache_length}, {2: cache_length}] for _ in cache_values],
+         ]
+
+         return {
+             "example_inputs": inputs,
+             "dynamic_shapes": dynamic_shapes,
+             "input_tensor_semantics": input_tensor_semantics,
+             "output_tensor_semantics": output_tensor_semantics,
+             "generation_config": HuggingFaceGenerationConfig(self.pipeline.model),
+         }
+
+     def models(self) -> List[Model]:
+         return [self.tokenizer_model, self.language_model]
+
+
+ class HuggingFaceImageProcessor:
+
+     def __init__(
+         self,
+         model: PreTrainedModel,
+         processor: BaseImageProcessor,
+     ):
+         self.model = model
+         self.processor = processor
+
+     def input_spec(self) -> Dict[str, Any]:
+         target_size = None
+         image_preprocessing_steps = []
+         if (
+             hasattr(self.processor, "do_resize")
+             and self.processor.do_resize
+             and hasattr(self.processor, "size")
+         ):
+             target_size, preprocessing_step = self._convert_resize()
+             image_preprocessing_steps.append(preprocessing_step)
+
+         if (
+             hasattr(self.processor, "do_rescale")
+             and self.processor.do_rescale
+             and hasattr(self.processor, "rescale_factor")
+         ):
+             image_preprocessing_steps.append(
+                 ImageToFloatTensor(
+                     scale=True, scale_factor=self.processor.rescale_factor
+                 )
+             )
+         else:
+             image_preprocessing_steps.append(ImageToFloatTensor(scale=False))
+
+         if hasattr(self.processor, "do_normalize") and self.processor.do_normalize:
+             image_preprocessing_steps.append(
+                 StandardizeTensor(self.processor.image_mean, self.processor.image_std)
+             )
+
+         channels_first = True
+         if (
+             hasattr(self.processor, "input_data_format")
+             and self.processor.input_data_format == "channels_last"
+         ):
+             channels_first = False
+
+         # Assume RGB for now.
+         # TODO(james): figure out if this is specified anywhere in the huggingface pipeline.
+         channel_format = "rgb"
+
+         dimensions = [
+             BatchDimension(),
+         ]
+         input_shape = [1]
+         if channels_first:
+             dimensions.append(ImageChannelDimension(channel_format))
+             input_shape.append(3)
+         dimensions.append(ImageHeightDimension())
+         input_shape.append(target_size[0])
+         dimensions.append(ImageWidthDimension())
+         input_shape.append(target_size[1])
+         if not channels_first:
+             dimensions.append(ImageChannelDimension(channel_format))
+             input_shape.append(3)
+
+         example_input = torch.rand(input_shape)
+         input_tensor_semantics = [TensorSemantics(dimensions)]
+         return {
+             "example_inputs": [example_input],
+             "input_tensor_semantics": input_tensor_semantics,
+             "image_preprocessing_steps": image_preprocessing_steps,
+         }
+
+     def output_spec_segmentation(self) -> Dict[str, Any]:
+         if not hasattr(self.processor, "post_process_semantic_segmentation"):
+             raise NotImplementedError(
+                 "only semantic segmentation is currently supported"
+             )
+         dimensions = [
+             BatchDimension(),
+             # TODO(james): verify all semantic segmentation in hugging face output a logits mask.
+             SegmentationMaskChannel("logits_mask"),
+             ImageHeightDimension(),
+             ImageWidthDimension(),
+         ]
+         output_tensor_semantics = [
+             TensorSemantics(dimensions),
+         ]
+         id_to_label = self.model.config.id2label
+         max_id = max(id_to_label)
+         labels = []
+         for i in range(max_id):
+             if i not in id_to_label:
+                 labels.append("")
+                 continue
+             labels.append(id_to_label[i])
+         return {
+             "output_tensor_semantics": output_tensor_semantics,
+             "class_labels": labels,
+         }
+
+     def output_spec_object_detection(self) -> Dict[str, Any]:
+         if not hasattr(self.processor, "post_process_object_detection"):
+             raise NotImplementedError(
+                 "processor must have post_process_object_detection set"
+             )
+
+         id_to_label = self.model.config.id2label
+         max_id = max(id_to_label)
+         labels = []
+         for i in range(max_id):
+             if i not in id_to_label:
+                 labels.append("")
+                 continue
+             labels.append(id_to_label[i])
+         num_classes = max_id + 1
+
+         # TODO(james): verify assumptions made here apply broadly.
+         output_tensor_semantics = []
+         # We assume that ObjectDetectionWrapper is used to ensure that logits are the first tensor and boxes are the second.
+         logits_dimensions = [
+             BatchDimension(),
+             DetectionNumCandidatesDimension(is_nms=False),
+             DetectionOutputDimension(
+                 scores_range=(0, num_classes),
+                 scores_are_logits=True,
+             ),
+         ]
+         output_tensor_semantics.append(TensorSemantics(logits_dimensions))
+
+         box_dimensions = [
+             BatchDimension(),
+             DetectionNumCandidatesDimension(is_nms=False),
+             DetectionOutputDimension(
+                 coordinates_start_index=0,
+                 box_format=BoundingBoxFormat("cxcywh", is_normalized=True),
+             ),
+         ]
+         output_tensor_semantics.append(TensorSemantics(box_dimensions))
+         return {
+             "output_tensor_semantics": output_tensor_semantics,
+             "class_labels": labels,
+         }
+
+     def _convert_resize(self) -> Tuple[Tuple[int, int], ImagePreprocessingStep]:
+         size = self.processor.size
+         target_size = None
+         preprocess_step = None
+         if "height" in size and "width" in size:
+             target_size = [size["height"], size["width"]]
+             preprocess_step = ResizeImage()
+         elif (
+             "shortest_edge" in size
+             or "longest_edge" in size
+             or "max_height" in size
+             or "max_width" in size
+         ):
+             shortest_edge = size.get("shortest_edge")
+             longest_edge = size.get("longest_edge")
+             max_height = size.get("max_height")
+             max_width = size.get("max_width")
+
+             min_size = None
+             for edge_size in [shortest_edge, longest_edge, max_height, max_width]:
+                 if not edge_size:
+                     continue
+                 if not min_size or edge_size < min_size:
+                     min_size = edge_size
+
+             target_size = [min_size, min_size]
+             preprocess_step = LetterboxImage()
+         else:
+             raise ValueError(
+                 "could not determine target size for resize from model config"
+             )
+         return target_size, preprocess_step
+
+
+ class HuggingFaceImageSegmentationPipeline:
+     def __init__(
+         self,
+         pipeline: Pipeline,
+         name: Optional[str] = None,
+     ):
+         self.pipeline = pipeline
+         if name is None:
+             name = pipeline.model.name_or_path
+
+         self.model = TorchModel(
+             name,
+             torch_module=self.pipeline.model,
+             **self._guess_model_spec(),
+         )
+
+     def _guess_model_spec(self) -> Dict:
+         if self.pipeline.image_processor is None:
+             raise ValueError(
+                 "Could not determine image preprocessing for pipeline with image_processor=None"
+             )
+         if self.pipeline.tokenizer is not None:
+             raise NotImplementedError(
+                 "HuggingFaceImageSegmentationPipeline does not yet support token inputs"
+             )
+
+         image_processor = HuggingFaceImageProcessor(
+             self.pipeline.model, self.pipeline.image_processor
+         )
+         spec = image_processor.input_spec()
+         spec.update(image_processor.output_spec_segmentation())
+         return spec
+
+     def models(self) -> List[Model]:
+         return [self.model]
+
+
+ class ObjectDetectionWrapper(torch.nn.Module):
+     def __init__(self, model: PreTrainedModel):
+         super().__init__()
+         self.model = model
+
+     def forward(self, *args, **kwargs):
+         outputs = self.model(*args, **kwargs)
+         return outputs.logits, outputs.pred_boxes
+
+
+ class HuggingFaceObjectDetectionPipeline:
+     def __init__(
+         self,
+         pipeline: Pipeline,
+         name: Optional[str] = None,
+     ):
+         self.pipeline = pipeline
+         if name is None:
+             name = pipeline.model.name_or_path
+
+         self.model = TorchModel(
+             name,
+             torch_module=ObjectDetectionWrapper(self.pipeline.model),
+             **self._guess_model_spec(),
+         )
+
+     def _guess_model_spec(self) -> Dict:
+         if self.pipeline.image_processor is None:
+             raise ValueError(
+                 "Could not determine image preprocessing for pipeline with image_processor=None"
+             )
+         if self.pipeline.tokenizer is not None:
+             raise NotImplementedError(
+                 "HuggingFaceObjectDetectionPipeline does not yet support token inputs"
+             )
+
+         image_processor = HuggingFaceImageProcessor(
+             self.pipeline.model, self.pipeline.image_processor
+         )
+         spec = image_processor.input_spec()
+         spec.update(image_processor.output_spec_object_detection())
+         return spec
+
+     def models(self) -> List[Model]:
+         return [self.model]
+
+
+ def import_huggingface_pipeline(pipeline: Pipeline, **kwargs) -> List[Model]:
+     if pipeline.framework != "pt":
+         raise ValueError(
+             "unimplemented: hugging face pipeline framework: {}".format(
+                 pipeline.framework
+             )
+         )
+
+     if pipeline.task == "text-generation":
+         return HuggingFaceTextGenerationPipeline(pipeline, **kwargs).models()
+     elif pipeline.task == "image-segmentation":
+         return HuggingFaceImageSegmentationPipeline(pipeline, **kwargs).models()
+     elif pipeline.task == "object-detection":
+         return HuggingFaceObjectDetectionPipeline(pipeline, **kwargs).models()
+     raise ValueError(
+         "unimplemented: hugging face pipeline task: {} (supported tasks: [{}])".format(
+             pipeline.task, ["text-generation", "image-segmentation", "object-detection"]
+         )
+     )
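
For orientation, here is a minimal usage sketch of the new `import_huggingface_pipeline` entry point added in this file. The checkpoint name and task below are illustrative assumptions, not taken from the package; any PyTorch (`framework="pt"`) pipeline for one of the supported tasks follows the same flow.

# Hypothetical usage sketch; the checkpoint below is only an example.
import transformers

from gml.hf import import_huggingface_pipeline

pipeline = transformers.pipeline(
    "object-detection", model="facebook/detr-resnet-50", framework="pt"
)

# Returns a list of gml Model objects ready for export (a wrapped TorchModel,
# plus a tokenizer model for text-generation pipelines).
models = import_huggingface_pipeline(pipeline)
for model in models:
    print(model.name)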
gml/model.py CHANGED
@@ -16,18 +16,29 @@
  from __future__ import annotations

  import abc
+ import contextlib
  import io
  from pathlib import Path
- from typing import BinaryIO, Dict, List, Literal, Optional, TextIO, Tuple
+ from typing import BinaryIO, Dict, List, Literal, Optional, Sequence, TextIO, Tuple

  import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
  import torch
- import torch_mlir
+ from gml.asset_manager import AssetManager, TempFileAssetManager
  from gml.compile import to_torch_mlir
  from gml.preprocessing import ImagePreprocessingStep
  from gml.tensor import TensorSemantics


+ class GenerationConfig:
+     def __init__(self, eos_token_ids: List[int]):
+         self.eos_token_ids = eos_token_ids
+
+     def to_proto(self) -> modelexecpb.GenerationConfig:
+         return modelexecpb.GenerationConfig(
+             eos_token_ids=self.eos_token_ids,
+         )
+
+
  class Model(abc.ABC):
      def __init__(
          self,
@@ -39,6 +50,7 @@ class Model(abc.ABC):
          class_labels: Optional[List[str]] = None,
          class_labels_file: Optional[Path] = None,
          image_preprocessing_steps: Optional[List[ImagePreprocessingStep]] = None,
+         generation_config: Optional[GenerationConfig] = None,
      ):
          self.name = name
          self.kind = kind
@@ -52,6 +64,7 @@ class Model(abc.ABC):
          self.input_tensor_semantics = input_tensor_semantics
          self.output_tensor_semantics = output_tensor_semantics
          self.image_preprocessing_steps = image_preprocessing_steps
+         self.generation_config = generation_config

      def to_proto(self) -> modelexecpb.ModelInfo:
          image_preprocessing_steps = None
@@ -59,6 +72,9 @@
              image_preprocessing_steps = [
                  step.to_proto() for step in self.image_preprocessing_steps
              ]
+         generation_config = None
+         if self.generation_config:
+             generation_config = self.generation_config.to_proto()
          return modelexecpb.ModelInfo(
              name=self.name,
              kind=self.kind,
@@ -71,45 +87,73 @@ class Model(abc.ABC):
              output_tensor_semantics=[
                  semantics.to_proto() for semantics in self.output_tensor_semantics
              ],
+             generation_config=generation_config,
          )

      @abc.abstractmethod
-     def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
+     def _collect_assets(
+         self, weight_manager: Optional[AssetManager] = None
+     ) -> Dict[str, TextIO | BinaryIO | Path]:
          pass

+     @contextlib.contextmanager
+     def collect_assets(self, weight_manager: Optional[AssetManager] = None):
+         yield from self._collect_assets(weight_manager)
+

  class TorchModel(Model):
      def __init__(
          self,
          name: str,
          torch_module: torch.nn.Module,
-         input_shapes: List[List[int]],
-         input_dtypes: List[torch.dtype],
+         example_inputs: Optional[List[torch.Tensor]] = None,
+         input_shapes: Optional[List[List[int]]] = None,
+         input_dtypes: Optional[List[torch.dtype]] = None,
+         dynamic_shapes: Optional[Sequence[Dict[int, str | "torch.export.Dim"]]] = None,
          **kwargs,
      ):
          super().__init__(
              name,
-             modelexecpb.ModelInfo.MODEL_KIND_TORCHSCRIPT,
+             modelexecpb.ModelInfo.MODEL_KIND_TORCH,
              modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_MLIR_TEXT,
              **kwargs,
          )
          self.torch_module = torch_module
+         self.example_inputs = example_inputs
          self.input_shapes = input_shapes
          self.input_dtypes = input_dtypes
+         self.dynamic_shapes = dynamic_shapes
+         if self.example_inputs is None:
+             if self.input_shapes is None or self.input_dtypes is None:
+                 raise ValueError(
+                     "one of `example_inputs` or (`input_shapes` and `input_dtype`) must be provided to `TorchModel`"
+                 )
+             self.example_inputs = [
+                 torch.rand(shape, dtype=dtype)
+                 for shape, dtype in zip(self.input_shapes, self.input_dtypes)
+             ]

-     def _convert_to_torch_mlir(self):
+     def _convert_to_torch_mlir(self, weight_manager: Optional[AssetManager] = None):
          return to_torch_mlir(
-             self.torch_module.to("cpu"),
-             [
-                 torch_mlir.TensorPlaceholder(shape, dtype)
-                 for shape, dtype in zip(self.input_shapes, self.input_dtypes)
-             ],
+             self.torch_module,
+             self.example_inputs,
+             self.dynamic_shapes,
+             weight_manager=weight_manager,
          )

-     def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
-         compiled = self._convert_to_torch_mlir()
-         file = io.BytesIO(str(compiled).encode("utf-8"))
-         return {"": file}
+     def _collect_assets(
+         self, weight_manager: Optional[AssetManager] = None
+     ) -> Dict[str, TextIO | BinaryIO | Path]:
+         if weight_manager is None:
+             # If the user does not provide a weight manager, use temp files.
+             weight_manager = TempFileAssetManager()
+
+         with weight_manager as weight_mgr:
+             compiled = self._convert_to_torch_mlir(weight_mgr)
+             file = io.BytesIO(str(compiled).encode("utf-8"))
+             assets = {"": file}
+             assets.update(weight_mgr.assets())
+             yield assets


  def _kind_str_to_kind_format_protos(
@@ -147,5 +191,7 @@ class ModelFromFiles(Model):
          super().__init__(name=name, kind=kind, storage_format=storage_format, **kwargs)
          self.files = files

-     def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
-         return self.files
+     def _collect_assets(
+         self, weight_manager: Optional[AssetManager] = None
+     ) -> Dict[str, TextIO | BinaryIO | Path]:
+         yield self.files
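
For orientation, a minimal sketch of the updated `TorchModel` surface (the toy module, names, and tensor semantics below are illustrative assumptions, not taken from the package): a caller now supplies either `example_inputs` or `input_shapes`/`input_dtypes`, may mark dimensions as dynamic via `dynamic_shapes`, and reads assets through the context-manager form of `collect_assets()`, which yields the MLIR text plus any externalized weights.

# Hypothetical sketch; the toy module and placeholder tensor semantics are
# assumptions for illustration only.
import torch

from gml.model import TorchModel
from gml.tensor import BatchDimension, TensorSemantics, TokensDimension

model = TorchModel(
    "toy-linear",
    torch_module=torch.nn.Linear(8, 4),
    example_inputs=[torch.rand(1, 8)],
    # Mark dim 0 of the first input as a symbolic batch dimension.
    dynamic_shapes=[{0: torch.export.Dim("batch", min=1, max=64)}],
    input_tensor_semantics=[TensorSemantics([BatchDimension(), TokensDimension()])],
    output_tensor_semantics=[TensorSemantics([BatchDimension(), TokensDimension()])],
)

# collect_assets() is now a context manager; "" maps to the MLIR text and any
# additional keys hold weight assets produced during compilation.
with model.collect_assets() as assets:
    print(list(assets.keys()))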
gml/model_utils.py CHANGED
@@ -31,3 +31,5 @@ def prepare_ultralytics_yolo(model):
      for _, m in model.named_modules():
          if hasattr(m, "export"):
              m.export = True
+             # YOLOv8 requires setting `format` when `export = True`
+             m.format = "custom"