gimlet-api 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
gml/hf.py CHANGED
@@ -24,8 +24,10 @@ from typing import Any, BinaryIO, Dict, List, Optional, TextIO, Tuple
 
 import torch
 import transformers
+from rich.progress import Console
 from transformers import (
     BaseImageProcessor,
+    DynamicCache,
     Pipeline,
     PreTrainedModel,
     PreTrainedTokenizer,
@@ -49,6 +51,7 @@ from gml.tensor import (
     DetectionNumCandidatesDimension,
     DetectionOutputDimension,
     DimensionSemantics,
+    EmbeddingDimension,
     ImageChannelDimension,
     ImageHeightDimension,
     ImageWidthDimension,
@@ -60,6 +63,11 @@ from gml.tensor import (
 
 FALLBACK_RESIZE_SIZE = 512
 
+# Set dynamic dimension max size to less than the int64 max, leaving leeway for the size to be ~4x by the model.
+MAX_DYNAMIC_VAL = 2**61
+
+console = Console()
+
 
 class HuggingFaceTokenizer(Model):
     def __init__(self, tokenizer: PreTrainedTokenizer, name: Optional[str] = None):
@@ -105,7 +113,6 @@ def flatten(items):
 
 
 class WrapWithFunctionalCache(torch.nn.Module):
-
     def __init__(self, model: transformers.PreTrainedModel):
         super().__init__()
         self.model = model
@@ -128,6 +135,8 @@ class HuggingFaceTextGenerationPipeline:
         name: Optional[str] = None,
         tokenizer_name: Optional[str] = None,
         dynamic_seqlen: bool = False,
+        dynamic_batch: bool = False,
+        export_predispatch: bool = False,
     ):
         self.pipeline = pipeline
         self.tokenizer_model = HuggingFaceTokenizer(pipeline.tokenizer, tokenizer_name)
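Note: the new `dynamic_batch` and `export_predispatch` keyword arguments are forwarded from `import_huggingface_pipeline(pipeline, **kwargs)` to `HuggingFaceTextGenerationPipeline`. A minimal usage sketch, not part of the diff; the checkpoint name below is only an illustrative example:

    import transformers
    from gml.hf import import_huggingface_pipeline

    # Any PyTorch text-generation pipeline works; this checkpoint name is illustrative.
    pipe = transformers.pipeline("text-generation", model="meta-llama/Llama-3.2-1B")

    # dynamic_batch traces with batch size 2 and marks the batch dimension dynamic;
    # export_predispatch is passed through TorchModel to to_torch_mlir.
    models = import_huggingface_pipeline(pipe, dynamic_batch=True, export_predispatch=True)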
@@ -139,13 +148,20 @@ class HuggingFaceTextGenerationPipeline:
             self.model = self.model.to(torch.float16)
         self.model = WrapWithFunctionalCache(pipeline.model)
 
+        self.dynamic_batch = dynamic_batch
+        self.batch_size = 1
+        if self.dynamic_batch:
+            # dynamic tracing fails for dimensions of size 1.
+            self.batch_size = 2
+
         self.language_model = TorchModel(
             name,
             torch_module=self.model,
+            export_predispatch=export_predispatch,
             **self._guess_model_spec(dynamic_seqlen),
         )
 
-    def _initialize_key_value_cache(self):
+    def _initialize_key_value_cache(self) -> DynamicCache:
         cache = []
         config = self.pipeline.model.config
         head_dim = (
@@ -158,7 +174,12 @@ class HuggingFaceTextGenerationPipeline:
             if config.num_key_value_heads is None
             else config.num_key_value_heads
         )
-        cache_shape = (1, num_key_value_heads, self._cache_length_for_tracing, head_dim)
+        cache_shape = (
+            self.batch_size,
+            num_key_value_heads,
+            self._cache_length_for_tracing,
+            head_dim,
+        )
         for _ in range(config.num_hidden_layers):
             cache.append(
                 [
@@ -166,7 +187,67 @@ class HuggingFaceTextGenerationPipeline:
                     torch.zeros(cache_shape).to(torch.float16),
                 ]
             )
-        return cache
+        return DynamicCache.from_legacy_cache(cache)
+
+    def _parse_transformer_config(
+        self, model: transformers.PreTrainedModel
+    ) -> modelexecpb.TransformerConfig:
+        # Only non-default rope config set the rope_scaling parameter
+        attention_head_size = getattr(
+            model.config,
+            "attention_head_size",
+            model.config.hidden_size // model.config.num_attention_heads,
+        )
+        partial_rotary_factor = getattr(model.config, "partial_rotary_factor", 1.0)
+        rotary_embedding_dim = getattr(
+            model.config,
+            "rotary_dim",
+            int(attention_head_size * partial_rotary_factor),
+        )
+        if (
+            hasattr(model.config, "rope_scaling")
+            and model.config.rope_scaling is not None
+        ):
+            rope_scaling = model.config.rope_scaling
+            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+            if not rope_type == "llama3":
+                raise NotImplementedError(
+                    "rope scaling type {} is not supported".format(rope_type)
+                )
+            # LLAMA 3 example config: https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
+            llama3_config = modelexecpb.Llama3RopeConfig()
+            llama3_config.theta = model.config.rope_theta
+            llama3_config.rotary_embedding_dim = rotary_embedding_dim
+            llama3_config.max_position_embeddings = model.config.max_position_embeddings
+
+            llama3_config.factor = rope_scaling["factor"]
+            llama3_config.high_freq_factor = rope_scaling["high_freq_factor"]
+            llama3_config.low_freq_factor = rope_scaling["low_freq_factor"]
+            llama3_config.original_max_position_embeddings = rope_scaling[
+                "original_max_position_embeddings"
+            ]
+            return modelexecpb.TransformerConfig(
+                position_embedding_config=modelexecpb.PositionEmbeddingConfig(
+                    kind=modelexecpb.PositionEmbeddingKind.POSITION_EMBEDDING_KIND_ROPE_LLAMA3,
+                    llama3_rope_config=llama3_config,
+                ),
+            )
+        # Default rope configs:
+        # 1. Llama-2: https://huggingface.co/NousResearch/Llama-2-7b-hf/blob/main/config.json
+        # 2. Qwen2.5: https://huggingface.co/Qwen/Qwen2.5-14B-Instruct-1M/blob/main/config.json
+        # 3. Mixtral: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
+        default_rope_config = modelexecpb.DefaultRopeConfig()
+        default_rope_config.theta = model.config.rope_theta
+        default_rope_config.max_position_embeddings = (
+            model.config.max_position_embeddings
+        )
+        default_rope_config.rotary_embedding_dim = rotary_embedding_dim
+        return modelexecpb.TransformerConfig(
+            position_embedding_config=modelexecpb.PositionEmbeddingConfig(
+                kind=modelexecpb.PositionEmbeddingKind.POSITION_EMBEDDING_KIND_ROPE_DEFAULT,
+                default_rope_config=default_rope_config,
+            ),
+        )
 
     def _guess_model_spec(self, dynamic_seqlen: bool) -> Dict:
         input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
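The `_parse_transformer_config` helper added above only special-cases Llama-3 style rope scaling; any other `rope_scaling` type raises `NotImplementedError`, and configs without `rope_scaling` fall back to the default rope config. A standalone sketch of the same branching, using a plain dict in place of the HF config object and the proto messages (the helper and label names below are illustrative):

    def pick_rope_kind(config: dict) -> str:
        # Mirrors the branch order in _parse_transformer_config, returning a label only.
        rope_scaling = config.get("rope_scaling")
        if rope_scaling is not None:
            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
            if rope_type != "llama3":
                raise NotImplementedError(f"rope scaling type {rope_type} is not supported")
            return "ROPE_LLAMA3"
        return "ROPE_DEFAULT"

    # Llama-3.2 style configs set rope_scaling; Llama-2 / Qwen2.5 / Mixtral leave it unset.
    assert pick_rope_kind({"rope_scaling": {"rope_type": "llama3"}}) == "ROPE_LLAMA3"
    assert pick_rope_kind({"rope_theta": 10000.0}) == "ROPE_DEFAULT"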
@@ -179,7 +260,7 @@ class HuggingFaceTextGenerationPipeline:
         input_tensor_semantics = []
 
         # This currently assumes that all HF language models have inputs that are [B, NUM_TOKENS].
-        inputs.append(input_dict["input_ids"])
+        inputs.append(torch.tile(input_dict["input_ids"], [self.batch_size, 1]))
         input_tensor_semantics.append(
             TensorSemantics(
                 dimensions=[
@@ -192,7 +273,7 @@ class HuggingFaceTextGenerationPipeline:
         # Assume that the model supports a KeyValue cache.
         cache_values = self._initialize_key_value_cache()
         inputs.append(cache_values)
-        for _ in cache_values:
+        for _ in range(len(cache_values)):
             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
 
@@ -209,7 +290,7 @@ class HuggingFaceTextGenerationPipeline:
             if (
                 not found_logits
                 and len(tensor.shape) == 3
-                and tensor.shape[0] == 1
+                and tensor.shape[0] == self.batch_size
                 and tensor.shape[1] == seqlen
             ):
                 # This should be the logits tensor.
@@ -226,14 +307,38 @@ class HuggingFaceTextGenerationPipeline:
             else:
                 output_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
 
+        if not found_logits:
+            raise ValueError(
+                "could not determine output logits tensor for text generation model"
+            )
+
+        num_experts_per_tok = (
+            1
+            if not hasattr(self.pipeline.model.config, "num_experts_per_tok")
+            else self.pipeline.model.config.num_experts_per_tok
+        )
+
         dynamic_shapes = None
-        seqlen = torch.export.Dim("seqlen", min=2, max=9223372036854775096)
+        # Set range to half of seqlen to account for # of tokens per expert.
+        # pytorch export creates a constraint on the number of possible tokens
+        # sent to each expert. That value is num_experts * seqlen. If we don't divide
+        # by number of experts, the tracing creates an integer value that exceeds the valid int64
+        # range and will throw a hard to decipher error message.
+        seqlen = torch.export.Dim(
+            "seqlen", min=2, max=MAX_DYNAMIC_VAL // num_experts_per_tok
+        )
 
-        cache_length = torch.export.Dim("cache_length", min=2, max=9223372036854775096)
+        cache_length = torch.export.Dim("cache_length", min=2, max=MAX_DYNAMIC_VAL)
         dynamic_shapes = [
             {1: seqlen},
-            [[{2: cache_length}, {2: cache_length}] for _ in cache_values],
+            [[{2: cache_length}, {2: cache_length}] for _ in range(len(cache_values))],
         ]
+        if self.dynamic_batch:
+            batch = torch.export.Dim("batch")
+            dynamic_shapes[0][0] = batch
+            for i in range(len(cache_values)):
+                dynamic_shapes[1][i][0][0] = batch
+                dynamic_shapes[1][i][1][0] = batch
 
         return {
            "example_inputs": inputs,
@@ -241,6 +346,7 @@ class HuggingFaceTextGenerationPipeline:
            "input_tensor_semantics": input_tensor_semantics,
            "output_tensor_semantics": output_tensor_semantics,
            "generation_config": HuggingFaceGenerationConfig(self.pipeline.model),
+           "transformer_config": self._parse_transformer_config(self.pipeline.model),
         }
 
     def models(self) -> List[Model]:
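The comment in the dynamic-shapes hunk above is the reason for the `MAX_DYNAMIC_VAL // num_experts_per_tok` bound: per that comment, torch.export constrains the number of tokens routed to each expert, which scales the traced value by the expert count, so the product has to stay inside int64. A small sketch of the bound arithmetic (the expert count of 8 is an illustrative mixture-of-experts value; dense models fall back to 1):

    import torch

    MAX_DYNAMIC_VAL = 2**61        # from gml/hf.py: leaves headroom below 2**63 - 1
    num_experts_per_tok = 8        # illustrative MoE value; dense models use 1

    seqlen = torch.export.Dim("seqlen", min=2, max=MAX_DYNAMIC_VAL // num_experts_per_tok)
    cache_length = torch.export.Dim("cache_length", min=2, max=MAX_DYNAMIC_VAL)

    # Even after scaling by the expert count, the derived bound stays inside int64.
    assert (MAX_DYNAMIC_VAL // num_experts_per_tok) * num_experts_per_tok < 2**63 - 1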
@@ -334,6 +440,14 @@ class HuggingFaceImageProcessor:
             raise NotImplementedError(
                 "only semantic segmentation is currently supported"
             )
+        # TODO(philkuz): Support panoptic segmentation models. Multiple outputs come from panoptic segmentation models.
+        # We need to decide whether we should invest in converting the panoptic segmentation output to semantic segmentation
+        # format or if we should directly support panoptic segmentation output.
+        if hasattr(self.processor, "post_process_panoptic_segmentation"):
+            raise NotImplementedError(
+                "panoptic segmentation models are not supported yet"
+            )
+
         dimensions = [
             BatchDimension(),
             # TODO(james): verify all semantic segmentation in hugging face output a logits mask.
@@ -687,8 +801,8 @@ class HuggingFaceZeroShotObjectDetectionPipeline:
 
         spec["dynamic_shapes"].extend(
             [
-                {0: "num_labels"},
-                {0: "num_labels"},
+                {0: torch.export.Dim("num_labels", max=MAX_DYNAMIC_VAL)},
+                {0: torch.export.Dim("num_labels", max=MAX_DYNAMIC_VAL)},
             ]
         )
 
@@ -754,33 +868,121 @@ class HuggingFaceDepthEstimationPipeline:
         return [self.model]
 
 
-def import_huggingface_pipeline(pipeline: Pipeline, **kwargs) -> List[Model]:
-    if pipeline.framework != "pt":
-        raise ValueError(
-            "unimplemented: hugging face pipeline framework: {}".format(
-                pipeline.framework
+class HuggingFaceFeatureExtractionPipeline:
+    def __init__(self, pipeline: Pipeline, name: Optional[str] = None):
+        self.pipeline = pipeline
+        if name is None:
+            name = pipeline.model.name_or_path
+
+        self.tokenizer_model = HuggingFaceTokenizer(self.pipeline.tokenizer)
+
+        self.model = TorchModel(
+            name=name,
+            torch_module=self.pipeline.model,
+            **self._guess_model_spec(),
+        )
+
+    def _guess_model_spec(self) -> Dict:
+        spec = {
+            "example_inputs": [],
+            "input_tensor_semantics": [],
+            "output_tensor_semantics": [],
+            "dynamic_shapes": [],
+        }
+
+        input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
+        if "input_ids" not in input_dict:
+            raise ValueError(
+                'HuggingFaceFeatureExtractionPipeline expects preprocessed inputs to have an "input_ids" tensor'
             )
+
+        spec["example_inputs"].append(input_dict["input_ids"])
+        spec["input_tensor_semantics"].extend(
+            [
+                TensorSemantics(
+                    dimensions=[
+                        BatchDimension(),
+                        TokensDimension(),
+                    ]
+                ),
+            ]
         )
 
-    if pipeline.task == "text-generation":
-        return HuggingFaceTextGenerationPipeline(pipeline, **kwargs).models()
-    elif pipeline.task == "image-segmentation":
-        return HuggingFaceImageSegmentationPipeline(pipeline, **kwargs).models()
-    elif pipeline.task == "object-detection":
-        return HuggingFaceObjectDetectionPipeline(pipeline, **kwargs).models()
-    elif pipeline.task == "zero-shot-object-detection":
-        return HuggingFaceZeroShotObjectDetectionPipeline(pipeline, **kwargs).models()
-    elif pipeline.task == "depth-estimation":
-        return HuggingFaceDepthEstimationPipeline(pipeline, **kwargs).models()
-    raise ValueError(
-        "unimplemented: hugging face pipeline task: {} (supported tasks: [{}])".format(
-            pipeline.task,
+        spec["output_tensor_semantics"].extend(
             [
-                "text-generation",
-                "image-segmentation",
-                "object-detection",
-                "zero-shot-object-detection",
-                "depth-estimation",
-            ],
-        )
-    )
+                TensorSemantics(
+                    dimensions=[
+                        BatchDimension(),
+                        TokensDimension(),
+                        EmbeddingDimension(),
+                    ],
+                ),
+                TensorSemantics(
+                    dimensions=[
+                        BatchDimension(),
+                        EmbeddingDimension(),
+                    ],
+                ),
+            ]
+        )
+
+        max_seqlen = (
+            getattr(self.pipeline.model.config, "max_position_embeddings", 500) - 1
+        )
+        spec["dynamic_shapes"].extend(
+            [
+                {
+                    1: torch.export.Dim(
+                        "seqlen",
+                        max=max_seqlen,
+                    )
+                },
+            ]
+        )
+        return spec
+
+    def models(self) -> List[Model]:
+        return [self.model, self.tokenizer_model]
+
+
+def import_huggingface_pipeline(pipeline: Pipeline, **kwargs) -> List[Model]:
+    with console.status(
+        f'Importing HuggingFace pipeline: "{pipeline.model.name_or_path}"'
+    ):
+        if pipeline.framework != "pt":
+            raise ValueError(
+                "unimplemented: hugging face pipeline framework: {}".format(
+                    pipeline.framework
+                )
+            )
+
+        if pipeline.task == "text-generation":
+            result = HuggingFaceTextGenerationPipeline(pipeline, **kwargs).models()
+        elif pipeline.task == "image-segmentation":
+            result = HuggingFaceImageSegmentationPipeline(pipeline, **kwargs).models()
+        elif pipeline.task == "object-detection":
+            result = HuggingFaceObjectDetectionPipeline(pipeline, **kwargs).models()
+        elif pipeline.task == "zero-shot-object-detection":
+            result = HuggingFaceZeroShotObjectDetectionPipeline(
+                pipeline, **kwargs
+            ).models()
+        elif pipeline.task == "depth-estimation":
+            result = HuggingFaceDepthEstimationPipeline(pipeline, **kwargs).models()
+        elif pipeline.task == "feature-extraction":
+            result = HuggingFaceFeatureExtractionPipeline(pipeline, **kwargs).models()
+        else:
+            raise ValueError(
+                "unimplemented: hugging face pipeline task: {} (supported tasks: [{}])".format(
+                    pipeline.task,
+                    [
+                        "text-generation",
+                        "image-segmentation",
+                        "object-detection",
+                        "zero-shot-object-detection",
+                        "depth-estimation",
+                        "feature-extraction",
+                    ],
+                )
+            )
+    console.print(f'Imported HuggingFace pipeline: "{pipeline.model.name_or_path}".')
+    return result
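With the new `feature-extraction` branch, embedding models go through the same importer as the other tasks. A hedged usage sketch, not part of the diff (the checkpoint name is only an example; per `HuggingFaceFeatureExtractionPipeline.models`, the result is the traced model plus its tokenizer):

    import transformers
    from gml.hf import import_huggingface_pipeline

    # Illustrative embedding checkpoint; any PyTorch feature-extraction pipeline applies.
    pipe = transformers.pipeline(
        "feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2"
    )

    models = import_huggingface_pipeline(pipe)  # [TorchModel, HuggingFaceTokenizer]
    for m in models:
        print(m.name)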
gml/model.py CHANGED
@@ -18,6 +18,7 @@ from __future__ import annotations
 import abc
 import contextlib
 import io
+from importlib.metadata import distributions
 from pathlib import Path
 from typing import BinaryIO, Dict, List, Literal, Optional, Sequence, TextIO, Tuple
 
@@ -29,6 +30,33 @@ from gml.compile import to_torch_mlir
 from gml.preprocessing import ImagePreprocessingStep
 from gml.tensor import TensorSemantics
 
+MODELING_PACKAGES = set(
+    {
+        "accelerate",
+        "compressed-tensors",
+        "gimlet-api",
+        "mlir-gml",
+        "safetensors",
+        "safetensors-mlir",
+        "torch",
+        "torch-mlir-gml",
+        "torchvision",
+        "transformers",
+        "tokenizers",
+    }
+)
+
+
+def get_modeling_packages() -> Sequence[modelexecpb.PackageInfo]:
+    packages = []
+    for dist in filter(
+        lambda d: d.metadata["Name"] in MODELING_PACKAGES, distributions()
+    ):
+        packages.append(
+            modelexecpb.PackageInfo(name=dist.metadata["Name"], version=dist.version)
+        )
+    return sorted(packages, key=lambda p: p.name)
+
 
 class GenerationConfig:
     def __init__(self, eos_token_ids: List[int]):
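`get_modeling_packages` relies only on `importlib.metadata`, so the recorded versions are easy to inspect outside the proto plumbing. A quick sketch with a trimmed-down package set (the real constant lists the full set above):

    from importlib.metadata import distributions

    MODELING_PACKAGES = {"torch", "transformers", "tokenizers", "safetensors", "accelerate"}

    # Same filter as get_modeling_packages, printing name==version instead of
    # building modelexecpb.PackageInfo protos.
    found = {}
    for dist in distributions():
        name = dist.metadata["Name"]
        if name in MODELING_PACKAGES:
            found[name] = dist.version
    for name in sorted(found):
        print(f"{name}=={found[name]}")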
@@ -52,6 +80,7 @@ class Model(abc.ABC):
         class_labels_file: Optional[Path] = None,
         image_preprocessing_steps: Optional[List[ImagePreprocessingStep]] = None,
         generation_config: Optional[GenerationConfig] = None,
+        transformer_config: Optional[modelexecpb.TransformerConfig] = None,
     ):
         self.name = name
         self.kind = kind
@@ -66,6 +95,7 @@ class Model(abc.ABC):
         self.output_tensor_semantics = output_tensor_semantics
         self.image_preprocessing_steps = image_preprocessing_steps
         self.generation_config = generation_config
+        self.transformer_config = transformer_config
 
     def to_proto(self) -> modelexecpb.ModelInfo:
         image_preprocessing_steps = None
@@ -89,6 +119,10 @@ class Model(abc.ABC):
                 semantics.to_proto() for semantics in self.output_tensor_semantics
             ],
             generation_config=generation_config,
+            tracing_metadata=modelexecpb.TracingMetadata(
+                package_info=get_modeling_packages(),
+            ),
+            transformer_config=self.transformer_config,
         )
 
     @abc.abstractmethod
@@ -111,6 +145,7 @@ class TorchModel(Model):
         input_shapes: Optional[List[List[int]]] = None,
         input_dtypes: Optional[List[torch.dtype]] = None,
         dynamic_shapes: Optional[Sequence[Dict[int, str | "torch.export.Dim"]]] = None,
+        export_predispatch: bool = False,
         **kwargs,
     ):
         super().__init__(
@@ -133,6 +168,7 @@ class TorchModel(Model):
                 torch.rand(shape, dtype=dtype)
                 for shape, dtype in zip(self.input_shapes, self.input_dtypes)
             ]
+        self.export_predispatch = export_predispatch
 
     def _convert_to_torch_mlir(self, weight_manager: Optional[AssetManager] = None):
         return to_torch_mlir(
@@ -140,6 +176,7 @@ class TorchModel(Model):
             self.example_inputs,
             self.dynamic_shapes,
             weight_manager=weight_manager,
+            export_predispatch=self.export_predispatch,
         )
 
     def _collect_assets(
gml/pipelines.py CHANGED
@@ -173,6 +173,16 @@ nodes:
 
 
 class LiveChatPipeline(Pipeline):
+    def __init__(
+        self,
+        message_template: str = r'''"{% for message in messages %}{% if message['role'] == 'system' %}{{message['content']}}{% else %}{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}{% endif %}"''',
+        preset_system_prompt: str = r"'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>'",
+        add_generation_prompt: bool = True,
+    ):
+        self.add_generation_prompt = add_generation_prompt
+        self.message_template = message_template
+        self.preset_system_prompt = preset_system_prompt
+
 
     def to_yaml(self, models: List[Model], org_name: str) -> str:
         if len(models) != 2:
@@ -195,6 +205,16 @@ nodes:
     kind: TextStreamSource
     outputs:
       - prompt
+  - name: query_template
+    kind: TemplateChatMessage
+    attributes:
+      add_generation_prompt: {self.add_generation_prompt}
+      message_template: {self.message_template}
+      preset_system_prompt: {self.preset_system_prompt}
+    inputs:
+      query: .text_source.prompt
+    outputs:
+      - chat_message
   - name: tokenize
     kind: Tokenize
     attributes:
@@ -203,7 +223,7 @@ nodes:
         name: {tokenizer.name}
         org: {org_name}
     inputs:
-      text: .text_source.prompt
+      text: .query_template.chat_message
     outputs:
       - tokens
   - name: generate
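`LiveChatPipeline` now takes its chat template, preset system prompt, and `add_generation_prompt` flag as constructor arguments rather than hard-coding them in the emitted YAML, so the new `query_template` node can be customized per deployment. A hedged sketch of overriding just the system prompt (the two models handed to `to_yaml` are assumed to come from the HuggingFace importer):

    from gml.pipelines import LiveChatPipeline

    # Defaults target a Llama-3 style chat template; only the preset system prompt changes here.
    pipeline = LiveChatPipeline(
        preset_system_prompt=(
            r"'<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
            r"\n\nAnswer as briefly as possible.<|eot_id|>'"
        ),
        add_generation_prompt=True,
    )

    # to_yaml still expects exactly two models (the tokenizer and the language model).
    # yaml_spec = pipeline.to_yaml(models, org_name="my-org")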
gml/preprocessing.py CHANGED
@@ -28,7 +28,21 @@ class ImagePreprocessingStep(abc.ABC):
         pass
 
 
+class ResizeImage(ImagePreprocessingStep):
+    """ResizeImage resizes the image to the target size without preserving aspect ratio."""
+
+    def to_proto(self) -> modelexecpb.ImagePreprocessingStep:
+        return modelexecpb.ImagePreprocessingStep(
+            kind=modelexecpb.ImagePreprocessingStep.IMAGE_PREPROCESSING_KIND_RESIZE,
+            resize_params=modelexecpb.ImagePreprocessingStep.ImageResizeParams(
+                kind=modelexecpb.ImagePreprocessingStep.ImageResizeParams.IMAGE_RESIZE_KIND_STRETCH,
+            ),
+        )
+
+
 class LetterboxImage(ImagePreprocessingStep):
+    """LetterboxImage resizes the image to the target size while preserving aspect ratio by introducting letterbox padding."""
+
     def to_proto(self) -> modelexecpb.ImagePreprocessingStep:
         return modelexecpb.ImagePreprocessingStep(
             kind=modelexecpb.ImagePreprocessingStep.IMAGE_PREPROCESSING_KIND_RESIZE,
@@ -38,14 +52,14 @@ class LetterboxImage(ImagePreprocessingStep):
         )
 
 
-class ResizeImage(ImagePreprocessingStep):
-    """ResizeImage resizes the image to the target size without preserving aspect ratio."""
+class CenterCropImage(ImagePreprocessingStep):
+    """CenterCropImage resizes the image to the target size while preserving aspect ratio by center cropping along one dimension."""
 
     def to_proto(self) -> modelexecpb.ImagePreprocessingStep:
         return modelexecpb.ImagePreprocessingStep(
             kind=modelexecpb.ImagePreprocessingStep.IMAGE_PREPROCESSING_KIND_RESIZE,
             resize_params=modelexecpb.ImagePreprocessingStep.ImageResizeParams(
-                kind=modelexecpb.ImagePreprocessingStep.ImageResizeParams.IMAGE_RESIZE_KIND_STRETCH,
+                kind=modelexecpb.ImagePreprocessingStep.ImageResizeParams.IMAGE_RESIZE_KIND_CENTERCROP,
             ),
         )
 
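Together with the existing `LetterboxImage`, the preprocessing module now distinguishes three resize behaviours: stretch (`ResizeImage`), pad (`LetterboxImage`), and center crop (`CenterCropImage`). A short sketch of selecting a step and serializing it (only the STRETCH and CENTERCROP resize kinds appear in this diff; the letterbox mapping is unchanged):

    from gml.preprocessing import CenterCropImage, LetterboxImage, ResizeImage

    steps = [
        ResizeImage(),      # stretch to the target size, ignoring aspect ratio
        LetterboxImage(),   # keep aspect ratio, pad the shorter side
        CenterCropImage(),  # keep aspect ratio, crop along one dimension
    ]

    # Each step serializes to a modelexecpb.ImagePreprocessingStep resize step,
    # e.g. IMAGE_RESIZE_KIND_STRETCH for ResizeImage and IMAGE_RESIZE_KIND_CENTERCROP
    # for CenterCropImage (per the to_proto methods above).
    protos = [step.to_proto() for step in steps]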