gimlet-api 0.0.0.dev0__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,6 @@ Requires-Dist: protobuf
11
11
  Requires-Dist: grpcio
12
12
  Requires-Dist: torch
13
13
  Requires-Dist: torch_mlir_gml
14
- Version: 0.0.0.dev0
14
+ Version: 0.0.2
15
15
 
16
16
  UNKNOWN
@@ -35,7 +35,7 @@ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=2yD8ZKS5KM
35
35
  gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=q1PugN3Jm_4v5hVWADJLCIkIEC2_beKEqEH4vb_SpH8,7396
36
36
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
37
37
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
38
- gml/tensor.py,sha256=URYDHDwcbgxdGXhhvKrKelEU_x5UvPdKonWBm8YHEwY,6119
39
- gimlet_api-0.0.0.dev0.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
40
- gimlet_api-0.0.0.dev0.dist-info/METADATA,sha256=n4OP-wBQzUZiPqXhY9UUW7t1WwZZaRTlClgcrweImvE,434
41
- gimlet_api-0.0.0.dev0.dist-info/RECORD,,
38
+ gml/tensor.py,sha256=veEDZGWRCJGa16gAabuCZwSS3jLXDXBk4xTH-v5C-Dw,7170
39
+ gimlet_api-0.0.2.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
40
+ gimlet_api-0.0.2.dist-info/METADATA,sha256=-XADMEHG4d3Gk1up3KH_9NhGK_2qtZG3dbC6NEONG0Q,429
41
+ gimlet_api-0.0.2.dist-info/RECORD,,
gml/tensor.py CHANGED
@@ -14,7 +14,7 @@
14
14
  # SPDX-License-Identifier: Proprietary
15
15
 
16
16
  import abc
17
- from typing import List, Optional, Tuple
17
+ from typing import List, Literal, Optional, Tuple
18
18
 
19
19
  import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
20
20
 
@@ -156,6 +156,33 @@ class DetectionOutputDimension(DimensionSemantics):
156
156
  )
157
157
 
158
158
 
159
+ def _segmentation_mask_kind_to_proto(kind: str):
160
+ match kind.lower():
161
+ case "bool_masks":
162
+ return (
163
+ modelexecpb.DimensionSemantics.SegmentationMaskParams.SEGMENTATION_MASK_KIND_BOOL
164
+ )
165
+ case "int_label_masks":
166
+ return (
167
+ modelexecpb.DimensionSemantics.SegmentationMaskParams.SEGMENTATION_MASK_KIND_CLASS_LABEL
168
+ )
169
+ case _:
170
+ raise ValueError("Invalid segmentation mask kind: {}".format(kind))
171
+
172
+
173
+ class SegmentationMaskChannel(DimensionSemantics):
174
+ def __init__(self, kind: Literal["bool_masks", "int_label_masks"]):
175
+ self.kind = _segmentation_mask_kind_to_proto(kind)
176
+
177
+ def to_proto(self) -> modelexecpb.DimensionSemantics:
178
+ return modelexecpb.DimensionSemantics(
179
+ kind=modelexecpb.DimensionSemantics.DIMENSION_SEMANTICS_KIND_SEGMENTATION_MASK_CHANNEL,
180
+ segmentation_mask_params=modelexecpb.DimensionSemantics.SegmentationMaskParams(
181
+ kind=self.kind,
182
+ ),
183
+ )
184
+
185
+
159
186
  class TensorSemantics:
160
187
  def __init__(self, dimensions: List[DimensionSemantics]):
161
188
  self.dimensions = dimensions