truefoundry 0.5.0rc5__py3-none-any.whl → 0.5.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. Click here for more details.

Files changed (52) hide show
  1. truefoundry/ml/autogen/client/__init__.py +0 -4
  2. truefoundry/ml/autogen/client/api/deprecated_api.py +340 -7
  3. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +0 -322
  4. truefoundry/ml/autogen/client/api_client.py +8 -1
  5. truefoundry/ml/autogen/client/models/__init__.py +0 -4
  6. truefoundry/ml/autogen/client/models/add_features_to_model_version_request_dto.py +3 -17
  7. truefoundry/ml/autogen/client/models/agent.py +1 -1
  8. truefoundry/ml/autogen/client/models/agent_app.py +1 -1
  9. truefoundry/ml/autogen/client/models/agent_open_api_tool.py +1 -1
  10. truefoundry/ml/autogen/client/models/agent_open_api_tool_with_fqn.py +1 -1
  11. truefoundry/ml/autogen/client/models/agent_with_fqn.py +1 -1
  12. truefoundry/ml/autogen/client/models/artifact_version_manifest.py +1 -1
  13. truefoundry/ml/autogen/client/models/assistant_message.py +1 -1
  14. truefoundry/ml/autogen/client/models/blob_storage_reference.py +1 -1
  15. truefoundry/ml/autogen/client/models/chat_prompt.py +1 -1
  16. truefoundry/ml/autogen/client/models/external_artifact_source.py +1 -1
  17. truefoundry/ml/autogen/client/models/fast_ai_framework.py +1 -1
  18. truefoundry/ml/autogen/client/models/gluon_framework.py +1 -1
  19. truefoundry/ml/autogen/client/models/h2_o_framework.py +1 -1
  20. truefoundry/ml/autogen/client/models/image_content_part.py +1 -1
  21. truefoundry/ml/autogen/client/models/keras_framework.py +1 -1
  22. truefoundry/ml/autogen/client/models/light_gbm_framework.py +1 -1
  23. truefoundry/ml/autogen/client/models/model_version_dto.py +7 -8
  24. truefoundry/ml/autogen/client/models/model_version_manifest.py +1 -1
  25. truefoundry/ml/autogen/client/models/onnx_framework.py +1 -1
  26. truefoundry/ml/autogen/client/models/paddle_framework.py +1 -1
  27. truefoundry/ml/autogen/client/models/py_torch_framework.py +1 -1
  28. truefoundry/ml/autogen/client/models/sklearn_framework.py +1 -1
  29. truefoundry/ml/autogen/client/models/spa_cy_framework.py +1 -1
  30. truefoundry/ml/autogen/client/models/stats_models_framework.py +1 -1
  31. truefoundry/ml/autogen/client/models/system_message.py +1 -1
  32. truefoundry/ml/autogen/client/models/tensor_flow_framework.py +1 -1
  33. truefoundry/ml/autogen/client/models/text_content_part.py +1 -1
  34. truefoundry/ml/autogen/client/models/transformers_framework.py +1 -1
  35. truefoundry/ml/autogen/client/models/true_foundry_artifact_source.py +1 -1
  36. truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +1 -13
  37. truefoundry/ml/autogen/client/models/user_message.py +1 -1
  38. truefoundry/ml/autogen/client/models/xg_boost_framework.py +1 -1
  39. truefoundry/ml/autogen/client_README.md +4 -8
  40. truefoundry/ml/autogen/models/__init__.py +4 -0
  41. truefoundry/ml/autogen/models/exceptions.py +30 -0
  42. truefoundry/ml/autogen/models/schema.py +1547 -0
  43. truefoundry/ml/autogen/models/signature.py +139 -0
  44. truefoundry/ml/autogen/models/utils.py +699 -0
  45. {truefoundry-0.5.0rc5.dist-info → truefoundry-0.5.0rc6.dist-info}/METADATA +1 -1
  46. {truefoundry-0.5.0rc5.dist-info → truefoundry-0.5.0rc6.dist-info}/RECORD +48 -47
  47. truefoundry/ml/autogen/client/models/feature_dto.py +0 -68
  48. truefoundry/ml/autogen/client/models/feature_value_type.py +0 -35
  49. truefoundry/ml/autogen/client/models/model_schema_dto.py +0 -85
  50. truefoundry/ml/autogen/client/models/prediction_type.py +0 -34
  51. {truefoundry-0.5.0rc5.dist-info → truefoundry-0.5.0rc6.dist-info}/WHEEL +0 -0
  52. {truefoundry-0.5.0rc5.dist-info → truefoundry-0.5.0rc6.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1547 @@
1
+ import builtins
2
+ import datetime as dt
3
+ import importlib.util
4
+ import json
5
+ import string
6
+ from copy import deepcopy
7
+ from dataclasses import is_dataclass
8
+ from enum import Enum
9
+ from typing import (
10
+ Any,
11
+ Dict,
12
+ List,
13
+ Optional,
14
+ Tuple,
15
+ TypedDict,
16
+ Union,
17
+ get_args,
18
+ get_origin,
19
+ )
20
+
21
+ import numpy as np
22
+
23
+ from .exceptions import MlflowException
24
+
25
# Values of the JSON "type" key that mark the composite (non-scalar) schema types.
# Scalar types are serialized with the DataType member name instead.
ARRAY_TYPE = "array"  # handled by Array
OBJECT_TYPE = "object"  # handled by Object
MAP_TYPE = "map"  # handled by Map
28
+
29
+
30
class DataType(Enum):
    """
    MLflow data types.

    Each member is declared as a tuple
    ``(value, numpy_type, spark_type, pandas_type, python_type)`` which ``__new__``
    unpacks into the enum value plus private attributes read by the ``to_*`` accessors.
    """

    def __new__(cls, value, numpy_type, spark_type, pandas_type=None, python_type=None):
        res = object.__new__(cls)
        res._value_ = value
        res._numpy_type = numpy_type
        # Stored as the *name* of a class in ``pyspark.sql.types`` so that pyspark is
        # only imported when ``to_spark`` is actually called.
        res._spark_type = spark_type
        # The pandas / python representations fall back to the numpy type when omitted.
        res._pandas_type = pandas_type if pandas_type is not None else numpy_type
        res._python_type = python_type if python_type is not None else numpy_type
        return res

    # NB: We only use pandas extension type for strings. There are also pandas extension types for
    # integers and boolean values. We do not use them here for now as most downstream tools are
    # most likely to use / expect native numpy types and would not be compatible with the extension
    # types.
    boolean = (1, np.dtype("bool"), "BooleanType", np.dtype("bool"), bool)
    """Logical data (True, False) ."""
    integer = (2, np.dtype("int32"), "IntegerType", np.dtype("int32"), int)
    """32b signed integer numbers."""
    long = (3, np.dtype("int64"), "LongType", np.dtype("int64"), int)
    """64b signed integer numbers. """
    # ``float`` is rebound to the enum member tuple from here on inside the class body,
    # hence the explicit ``builtins.float`` references.
    float = (4, np.dtype("float32"), "FloatType", np.dtype("float32"), builtins.float)
    """32b floating point numbers. """
    double = (5, np.dtype("float64"), "DoubleType", np.dtype("float64"), builtins.float)
    """64b floating point numbers. """
    string = (6, np.dtype("str"), "StringType", object, str)
    """Text data."""
    binary = (7, np.dtype("bytes"), "BinaryType", object, bytes)
    """Sequence of raw bytes."""
    datetime = (
        8,
        np.dtype("datetime64[ns]"),
        "TimestampType",
        np.dtype("datetime64[ns]"),
        dt.date,
    )
    """64b datetime data."""

    def __repr__(self):
        return self.name

    def to_numpy(self) -> np.dtype:
        """Get equivalent numpy data type."""
        return self._numpy_type

    def to_pandas(self) -> np.dtype:
        """Get equivalent pandas data type."""
        return self._pandas_type

    def to_python(self):
        """Get equivalent python data type."""
        return self._python_type

    def to_spark(self):
        """
        Get equivalent Spark SQL data type.

        pyspark is imported lazily, so this raises ``ImportError`` when pyspark is not
        installed. The type is looked up by name in ``pyspark.sql.types`` and
        instantiated.
        """
        # NOTE(review): returns an instance (e.g. ``IntegerType()``), which is assumed
        # to match the upstream MLflow helper this module mirrors -- confirm.
        import pyspark.sql.types

        return getattr(pyspark.sql.types, self._spark_type)()

    @classmethod
    def is_boolean(cls, value):
        """Return True if ``type(value)`` is one of the boolean representations."""
        return type(value) in DataType.boolean.get_all_types()

    @classmethod
    def is_integer(cls, value):
        """Return True if ``type(value)`` is one of the 32b integer representations."""
        return type(value) in DataType.integer.get_all_types()

    @classmethod
    def is_long(cls, value):
        """Return True if ``type(value)`` is one of the 64b integer representations."""
        return type(value) in DataType.long.get_all_types()

    @classmethod
    def is_float(cls, value):
        """Return True if ``type(value)`` is one of the 32b float representations."""
        return type(value) in DataType.float.get_all_types()

    @classmethod
    def is_double(cls, value):
        """Return True if ``type(value)`` is one of the 64b float representations."""
        return type(value) in DataType.double.get_all_types()

    @classmethod
    def is_string(cls, value):
        """Return True if ``type(value)`` is one of the string representations."""
        return type(value) in DataType.string.get_all_types()

    @classmethod
    def is_binary(cls, value):
        """Return True if ``type(value)`` is one of the binary representations."""
        return type(value) in DataType.binary.get_all_types()

    @classmethod
    def is_datetime(cls, value):
        """Return True if ``type(value)`` is one of the datetime representations."""
        return type(value) in DataType.datetime.get_all_types()

    def get_all_types(self):
        """
        List every representation of this data type: numpy, pandas and python types,
        the Spark type when pyspark is importable, plus a few extra accepted types for
        ``datetime`` and ``binary``.
        """
        types = [self.to_numpy(), self.to_pandas(), self.to_python()]
        if importlib.util.find_spec("pyspark") is not None:
            types.append(self.to_spark())
        if self.name == "datetime":
            types.extend([np.datetime64, dt.datetime])
        if self.name == "binary":
            # This is to support identifying bytearrays as binary data
            # for pandas DataFrame schema inference
            types.extend([bytearray])
        return types

    @classmethod
    def get_spark_types(cls):
        """Return the Spark type of every member (requires pyspark)."""
        # The loop variable is deliberately not called ``dt``: that name is the
        # module-level alias of ``datetime``.
        return [member.to_spark() for member in cls._member_map_.values()]

    @classmethod
    def from_numpy_type(cls, np_type):
        """Return the first member whose numpy type equals ``np_type``, or None."""
        return next(
            (v for v in cls._member_map_.values() if v.to_numpy() == np_type), None
        )
139
+
140
+
141
class Property:
    """
    Specification used to represent a json-convertible object property.
    """

    def __init__(
        self,
        name: str,
        dtype: Union[DataType, "Array", "Object", "Map", str],
        required: bool = True,
    ) -> None:
        """
        Args:
            name: The name of the property
            dtype: The data type of the property. A string is looked up as a
                DataType member name.
            required: Whether this property is required

        Raises:
            MlflowException: if ``name`` is not a string, ``dtype`` is a string that is
                not a DataType member name, or ``dtype`` is not a DataType, Array,
                Object or Map.
        """
        if not isinstance(name, str):
            raise MlflowException.invalid_parameter_value(
                f"Expected name to be a string, got type {type(name).__name__}"
            )
        self._name = name
        try:
            self._dtype = DataType[dtype] if isinstance(dtype, str) else dtype
        except KeyError:
            raise MlflowException(
                f"Unsupported type '{dtype}', expected instance of DataType, Array, Object, Map or "
                f"one of {[t.name for t in DataType]}"
            ) from None
        if not isinstance(self.dtype, (DataType, Array, Object, Map)):
            raise MlflowException(
                "Expected mlflow.types.schema.Datatype, mlflow.types.schema.Array, "
                "mlflow.types.schema.Object, mlflow.types.schema.Map or str for the 'dtype' "
                f"argument, but got {self.dtype.__class__}"
            ) from None
        self._required = required

    @property
    def name(self) -> str:
        """The property name."""
        return self._name

    @property
    def dtype(self) -> Union[DataType, "Array", "Object", "Map"]:
        """The property data type."""
        return self._dtype

    @property
    def required(self) -> bool:
        """Whether this property is required"""
        return self._required

    @required.setter
    def required(self, value: bool) -> None:
        self._required = value

    def __eq__(self, other) -> bool:
        if isinstance(other, Property):
            return (
                self.name == other.name
                and self.dtype == other.dtype
                and self.required == other.required
            )
        return False

    def __lt__(self, other) -> bool:
        # Ordering by name only; Object relies on it to keep its properties sorted.
        return self.name < other.name

    def __repr__(self) -> str:
        required = "required" if self.required else "optional"
        return f"{self.name}: {self.dtype!r} ({required})"

    def to_dict(self):
        """Serialize to ``{name: {"type": ..., "required": ...}}``."""
        d = (
            {"type": self.dtype.name}
            if isinstance(self.dtype, DataType)
            else self.dtype.to_dict()
        )
        d["required"] = self.required
        return {self.name: d}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain only one key as `name`, and
        the value should be a dictionary containing `type` and
        optional `required` keys.
        Example: {"property_name": {"type": "string", "required": True}}

        The input dictionaries are left unmodified.

        Raises:
            MlflowException: if there is not exactly one key, the value is not a
                dictionary, or the `type` key is missing.
        """
        if len(kwargs) != 1:
            raise MlflowException(
                f"Expected Property JSON to contain a single key as name, got {len(kwargs)} keys."
            )
        name, dic = kwargs.popitem()
        if not isinstance(dic, dict):
            raise MlflowException(
                f"Expected Property `{name}` to be a dictionary, got type {type(dic).__name__}"
            )
        if "type" not in dic:
            raise MlflowException(
                f"Missing keys in Property `{name}`. Expected to find key `type`"
            )
        # ``dic`` is the caller's own nested dictionary. Popping "required" from it
        # directly would drop the flag from the caller's data, so a second
        # deserialization of the same JSON would silently default to required=True.
        spec = dict(dic)
        required = spec.pop("required", True)
        dtype = spec["type"]
        if dtype == ARRAY_TYPE:
            return cls(name=name, dtype=Array.from_json_dict(**spec), required=required)
        if dtype == OBJECT_TYPE:
            return cls(name=name, dtype=Object.from_json_dict(**spec), required=required)
        if dtype == MAP_TYPE:
            return cls(name=name, dtype=Map.from_json_dict(**spec), required=required)
        return cls(name=name, dtype=dtype, required=required)

    def _merge(self, prop: "Property") -> "Property":
        """
        Check if current property is compatible with another property and return
        the updated property.
        When two properties have the same name, we need to check if their dtypes
        are compatible or not. The merged property is required only if both inputs
        are required.
        An example of two compatible properties:

        .. code-block:: python

            prop1 = Property(
                name="a",
                dtype=Object(
                    properties=[Property(name="a", dtype=DataType.string, required=False)]
                ),
            )
            prop2 = Property(
                name="a",
                dtype=Object(
                    properties=[
                        Property(name="a", dtype=DataType.string),
                        Property(name="b", dtype=DataType.double),
                    ]
                ),
            )
            merged_prop = prop1._merge(prop2)
            assert merged_prop == Property(
                name="a",
                dtype=Object(
                    properties=[
                        Property(name="a", dtype=DataType.string, required=False),
                        Property(name="b", dtype=DataType.double, required=False),
                    ]
                ),
            )

        Raises:
            MlflowException: if ``prop`` is not a Property, the names differ, or the
                dtypes cannot be merged.
        """
        if not isinstance(prop, Property):
            raise MlflowException(
                f"Can't merge property with non-property type: {type(prop).__name__}"
            )
        if self.name != prop.name:
            raise MlflowException("Can't merge properties with different names")
        required = self.required and prop.required
        if isinstance(self.dtype, DataType) and isinstance(prop.dtype, DataType):
            # Scalar types only merge with themselves.
            if self.dtype == prop.dtype:
                return Property(name=self.name, dtype=self.dtype, required=required)
            raise MlflowException(
                f"Properties are incompatible for {self.dtype} and {prop.dtype}"
            )

        if (
            isinstance(self.dtype, (Array, Object, Map))
            and self.dtype.__class__ is prop.dtype.__class__
        ):
            # Composite types of the same kind delegate to their own _merge.
            obj = self.dtype._merge(prop.dtype)
            return Property(name=self.name, dtype=obj, required=required)

        raise MlflowException("Properties are incompatible")
309
+
310
+
311
class Object:
    """
    Specification used to represent a json-convertible object.

    An object is a non-empty list of uniquely named Property instances, kept sorted
    by property name.
    """

    def __init__(self, properties: List[Property]) -> None:
        # The setter validates the list and stores it sorted by name so that the
        # order is stable.
        self.properties = properties

    def _check_properties(self, properties):
        """Raise MlflowException unless ``properties`` is a valid property list."""
        if not isinstance(properties, list):
            raise MlflowException.invalid_parameter_value(
                f"Expected properties to be a list, got type {type(properties).__name__}"
            )
        if not properties:
            raise MlflowException.invalid_parameter_value(
                "Creating Object with empty properties is not allowed."
            )
        if not all(isinstance(item, Property) for item in properties):
            raise MlflowException.invalid_parameter_value(
                "Expected values to be instance of Property"
            )
        # Names must be unique within one object.
        names = [item.name for item in properties]
        duplicates = set()
        for candidate in names:
            if names.count(candidate) > 1:
                duplicates.add(candidate)
        if duplicates:
            raise MlflowException.invalid_parameter_value(
                f"Found duplicated property names: {duplicates}"
            )

    @property
    def properties(self) -> List[Property]:
        """The list of object properties"""
        return self._properties

    @properties.setter
    def properties(self, value: List[Property]) -> None:
        self._check_properties(value)
        self._properties = sorted(value)

    def __eq__(self, other) -> bool:
        if not isinstance(other, Object):
            return False
        return self.properties == other.properties

    def __repr__(self) -> str:
        rendered = [repr(item) for item in self.properties]
        return "{" + ", ".join(rendered) + "}"

    def to_dict(self):
        """Serialize to ``{"type": "object", "properties": {name: spec, ...}}``."""
        serialized = {}
        for item in self.properties:
            # Each property serializes to a single-entry {name: spec} dictionary.
            serialized.update(item.to_dict())
        return {"type": OBJECT_TYPE, "properties": serialized}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain `type` and
        `properties` keys.
        Example: {"type": "object", "properties": {"property_name": {"type": "string"}}}
        """
        if "properties" not in kwargs or "type" not in kwargs:
            raise MlflowException(
                "Missing keys in Object JSON. Expected to find keys `properties` and `type`"
            )
        if kwargs["type"] != OBJECT_TYPE:
            raise MlflowException("Type mismatch, Object expects `object` as the type")
        raw_properties = kwargs["properties"]
        if not isinstance(raw_properties, dict) or not all(
            isinstance(spec, dict) for spec in raw_properties.values()
        ):
            raise MlflowException(
                "Expected properties to be a dictionary of Property JSON"
            )
        parsed = []
        for prop_name, spec in raw_properties.items():
            parsed.append(Property.from_json_dict(**{prop_name: spec}))
        return cls(parsed)

    def _merge(self, obj: "Object") -> "Object":
        """
        Check if the current object is compatible with another object and return
        the updated object.
        When we infer the signature from a list of objects, it is possible
        that one object has more properties than the other. In this case,
        we should mark those optional properties as required=False.
        For properties with the same name, we should check the compatibility
        of two properties and update.
        An example of two compatible objects:

        .. code-block:: python

            obj1 = Object(
                properties=[
                    Property(name="a", dtype=DataType.string),
                    Property(name="b", dtype=DataType.double),
                ]
            )
            obj2 = Object(
                properties=[
                    Property(name="a", dtype=DataType.string),
                    Property(name="c", dtype=DataType.boolean),
                ]
            )
            updated_obj = obj1._merge(obj2)
            assert updated_obj == Object(
                properties=[
                    Property(name="a", dtype=DataType.string),
                    Property(name="b", dtype=DataType.double, required=False),
                    Property(name="c", dtype=DataType.boolean, required=False),
                ]
            )

        """
        if not isinstance(obj, Object):
            raise MlflowException(
                f"Can't merge object with non-object type: {type(obj).__name__}"
            )
        if self == obj:
            return deepcopy(self)
        mine = {item.name: item for item in self.properties}
        theirs = {item.name: item for item in obj.properties}
        merged = []
        # Properties present only on this side become optional.
        merged.extend(
            Property(name=key, dtype=mine[key].dtype, required=False)
            for key in mine.keys() - theirs.keys()
        )
        # Properties present on both sides must be compatible with each other.
        merged.extend(mine[key]._merge(theirs[key]) for key in mine.keys() & theirs.keys())
        # Properties present only on the other side become optional as well.
        merged.extend(
            Property(name=key, dtype=theirs[key].dtype, required=False)
            for key in theirs.keys() - mine.keys()
        )
        return Object(properties=merged)
459
+
460
+
461
class Array:
    """
    Specification used to represent a json-convertible array.

    All elements share a single element type, available as ``dtype``.
    """

    def __init__(
        self,
        dtype: Union["Array", "Map", DataType, Object, str],
    ) -> None:
        # A string is interpreted as the name of a DataType member.
        if isinstance(dtype, str):
            try:
                resolved = DataType[dtype]
            except KeyError:
                raise MlflowException(
                    f"Unsupported type '{dtype}', expected instance of DataType, Array, Object, Map or "
                    f"one of {[t.name for t in DataType]}"
                ) from None
        else:
            resolved = dtype
        self._dtype = resolved
        if not isinstance(self.dtype, (Array, DataType, Object, Map)):
            raise MlflowException(
                "Expected mlflow.types.schema.Array, mlflow.types.schema.Datatype, "
                "mlflow.types.schema.Object, mlflow.types.schema.Map or str for the "
                f"'dtype' argument, but got '{self.dtype.__class__}'"
            ) from None

    @property
    def dtype(self) -> Union["Array", "Map", DataType, Object]:
        """The array data type."""
        return self._dtype

    def __eq__(self, other) -> bool:
        if not isinstance(other, Array):
            return False
        return self.dtype == other.dtype

    def to_dict(self):
        """Serialize to ``{"type": "array", "items": <element type spec>}``."""
        if isinstance(self.dtype, DataType):
            element_spec = {"type": self.dtype.name}
        else:
            element_spec = self.dtype.to_dict()
        return {"type": ARRAY_TYPE, "items": element_spec}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain `type` and
        `items` keys, where `items` is itself a dictionary with a `type` key.
        Example: {"type": "array", "items": {"type": "string"}}
        """
        if "items" not in kwargs or "type" not in kwargs:
            raise MlflowException(
                "Missing keys in Array JSON. Expected to find keys `items` and `type`"
            )
        if kwargs["type"] != ARRAY_TYPE:
            raise MlflowException("Type mismatch, Array expects `array` as the type")
        element_json = kwargs["items"]
        if not isinstance(element_json, dict):
            raise MlflowException("Expected items to be a dictionary of Object JSON")
        if "type" not in element_json:
            raise MlflowException(
                "Missing keys in Array's items JSON. Expected to find key `type`"
            )

        element_kind = element_json["type"]
        if element_kind == OBJECT_TYPE:
            return cls(dtype=Object.from_json_dict(**element_json))
        if element_kind == ARRAY_TYPE:
            return cls(dtype=Array.from_json_dict(**element_json))
        if element_kind == MAP_TYPE:
            return cls(dtype=Map.from_json_dict(**element_json))
        # Anything else is handed to the constructor, which resolves scalar names.
        return cls(dtype=element_kind)

    def __repr__(self) -> str:
        return f"Array({self.dtype!r})"

    def _merge(self, arr: "Array") -> "Array":
        """
        Return an Array compatible with both this array and ``arr``.

        Scalar element types must be equal; composite element types of the same
        kind are merged through their own ``_merge``.

        Raises:
            MlflowException: if ``arr`` is not an Array or the element types are
                incompatible.
        """
        if not isinstance(arr, Array):
            raise MlflowException(
                f"Can't merge array with non-array type: {type(arr).__name__}"
            )
        if self == arr:
            return deepcopy(self)
        mine, theirs = self.dtype, arr.dtype
        if isinstance(mine, DataType):
            if mine == theirs:
                return Array(dtype=mine)
            raise MlflowException(
                f"Array types are incompatible for {self} with dtype={self.dtype} and "
                f"{arr} with dtype={arr.dtype}"
            )

        same_composite_kind = (
            isinstance(mine, (Array, Object, Map)) and mine.__class__ is theirs.__class__
        )
        if same_composite_kind:
            return Array(dtype=mine._merge(theirs))

        raise MlflowException(f"Array type {self!r} and {arr!r} are incompatible")
559
+
560
+
561
class Map:
    """
    Specification used to represent a json-convertible map with string type keys.
    """

    def __init__(self, value_type: Union["Array", "Map", DataType, Object, str]):
        """
        Args:
            value_type: The type shared by all map values. A string is looked up as a
                DataType member name.

        Raises:
            MlflowException: if ``value_type`` is a string that is not a DataType
                member name, or is not a DataType, Array, Object or Map.
        """
        try:
            self._value_type = (
                DataType[value_type] if isinstance(value_type, str) else value_type
            )
        except KeyError:
            raise MlflowException(
                f"Unsupported value type '{value_type}', expected instance of DataType, Array, "
                f"Object, Map or one of {[t.name for t in DataType]}"
            ) from None
        if not isinstance(self._value_type, (Array, Map, DataType, Object)):
            raise MlflowException(
                "Expected mlflow.types.schema.Array, mlflow.types.schema.Datatype, "
                "mlflow.types.schema.Object, mlflow.types.schema.Map or str for "
                f"the 'value_type' argument, but got '{self._value_type}'"
            ) from None

    @property
    def value_type(self):
        """The type of the map values."""
        return self._value_type

    def __repr__(self) -> str:
        return f"Map(str -> {self._value_type})"

    def __eq__(self, other) -> bool:
        if isinstance(other, Map):
            return self.value_type == other.value_type
        return False

    def to_dict(self):
        """Serialize to ``{"type": "map", "values": <value type spec>}``."""
        values = (
            {"type": self.value_type.name}
            if isinstance(self.value_type, DataType)
            else self.value_type.to_dict()
        )
        return {"type": MAP_TYPE, "values": values}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain `type` and
        `values` keys, where `values` is itself a dictionary with a `type` key.
        Example: {"type": "map", "values": {"type": "string"}}

        Raises:
            MlflowException: if a key is missing, `type` is not `map`, or `values`
                is not a dictionary.
        """
        if not {"values", "type"} <= set(kwargs.keys()):
            # The message names the Map keys; it previously described Array's
            # `items` key, which pointed users at the wrong field.
            raise MlflowException(
                "Missing keys in Map JSON. Expected to find keys `values` and `type`"
            )
        if kwargs["type"] != MAP_TYPE:
            raise MlflowException("Type mismatch, Map expects `map` as the type")
        if not isinstance(kwargs["values"], dict):
            raise MlflowException("Expected values to be a dictionary of Object JSON")
        if not {"type"} <= set(kwargs["values"].keys()):
            raise MlflowException(
                "Missing keys in Map's values JSON. Expected to find key `type`"
            )
        if kwargs["values"]["type"] == OBJECT_TYPE:
            return cls(value_type=Object.from_json_dict(**kwargs["values"]))
        if kwargs["values"]["type"] == ARRAY_TYPE:
            return cls(value_type=Array.from_json_dict(**kwargs["values"]))
        if kwargs["values"]["type"] == MAP_TYPE:
            return cls(value_type=Map.from_json_dict(**kwargs["values"]))
        # Anything else is handed to the constructor, which resolves scalar names.
        return cls(value_type=kwargs["values"]["type"])

    def _merge(self, map_type: "Map") -> "Map":
        """
        Return a Map compatible with both this map and ``map_type``.

        Scalar value types must be equal; composite value types of the same kind
        are merged through their own ``_merge``.

        Raises:
            MlflowException: if ``map_type`` is not a Map or the value types are
                incompatible.
        """
        if not isinstance(map_type, Map):
            raise MlflowException(
                f"Can't merge map with non-map type: {type(map_type).__name__}"
            )
        if self == map_type:
            return deepcopy(self)
        if isinstance(self.value_type, DataType):
            if self.value_type == map_type.value_type:
                return Map(value_type=self.value_type)
            raise MlflowException(
                f"Map types are incompatible for {self} with value_type={self.value_type} and "
                f"{map_type} with value_type={map_type.value_type}"
            )

        if (
            isinstance(self.value_type, (Array, Object, Map))
            and self.value_type.__class__ is map_type.value_type.__class__
        ):
            return Map(value_type=self.value_type._merge(map_type.value_type))

        raise MlflowException(f"Map type {self!r} and {map_type!r} are incompatible")
653
+
654
+
655
class ColSpec:
    """
    Specification of name and type of a single column in a dataset.
    """

    def __init__(
        self,
        type: Union[DataType, Array, Object, Map, str],
        name: Optional[str] = None,
        required: bool = True,
    ):
        """
        Args:
            type: The column data type. A string is looked up as a DataType member
                name.
            name: The column name, or None for an unnamed column.
            required: Whether this column is required.

        Raises:
            MlflowException: if ``type`` is a string that is not a DataType member
                name.
            TypeError: if ``type`` is not a DataType, Array, Object or Map.
        """
        self._name = name

        self._required = required
        try:
            self._type = DataType[type] if isinstance(type, str) else type
        except KeyError:
            raise MlflowException(
                f"Unsupported type '{type}', expected instance of DataType or "
                f"one of {[t.name for t in DataType]}"
            ) from None
        if not isinstance(self.type, (DataType, Array, Object, Map)):
            raise TypeError(
                "Expected mlflow.types.schema.Datatype, mlflow.types.schema.Array, "
                "mlflow.types.schema.Object, mlflow.types.schema.Map or str for the 'type' "
                f"argument, but got {self.type.__class__}"
            ) from None

    @property
    def type(self) -> Union[DataType, Array, Object, Map]:
        """The column data type."""
        return self._type

    @property
    def name(self) -> Optional[str]:
        """The column name or None if the columns is unnamed."""
        return self._name

    @name.setter
    def name(self, value: Optional[str]) -> None:
        self._name = value

    @property
    def required(self) -> bool:
        """Whether this column is required."""
        return self._required

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dictionary with `type`, `required` and, if set, `name`."""
        d = (
            {"type": self.type.name}
            if isinstance(self.type, DataType)
            else self.type.to_dict()
        )
        if self.name is not None:
            d["name"] = self.name
        d["required"] = self.required
        return d

    def __eq__(self, other) -> bool:
        if isinstance(other, ColSpec):
            return (
                self.name == other.name
                and self.type == other.type
                and self.required == other.required
            )
        return False

    def __repr__(self) -> str:
        required = "required" if self.required else "optional"
        if self.name is None:
            return f"{self.type!r} ({required})"
        return f"{self.name!r}: {self.type!r} ({required})"

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain `type` and
        optional `name` and `required` keys. A missing `required` key means the
        column is required, for scalar and composite types alike.

        Raises:
            MlflowException: if the `type` key is missing.
        """
        if not {"type"} <= set(kwargs.keys()):
            raise MlflowException(
                "Missing keys in ColSpec JSON. Expected to find key `type`"
            )
        if kwargs["type"] not in [
            ARRAY_TYPE,
            OBJECT_TYPE,
            MAP_TYPE,
        ]:
            # Scalar column: the remaining keys map directly onto the constructor.
            return cls(**kwargs)
        name = kwargs.pop("name", None)
        # Default to True like the constructor does. Defaulting to None stored None
        # as ``required``, so the column compared unequal to an identical required
        # column, was rendered as "optional", and serialized "required": None.
        required = kwargs.pop("required", True)
        if kwargs["type"] == ARRAY_TYPE:
            return cls(
                name=name, type=Array.from_json_dict(**kwargs), required=required
            )
        if kwargs["type"] == OBJECT_TYPE:
            return cls(
                name=name,
                type=Object.from_json_dict(**kwargs),
                required=required,
            )
        return cls(name=name, type=Map.from_json_dict(**kwargs), required=required)
760
+
761
+
762
class TensorInfo:
    """
    Representation of the shape and type of a Tensor.
    """

    def __init__(self, dtype: np.dtype, shape: Union[tuple, list]):
        """
        Args:
            dtype: The numpy dtype of the tensor elements.
            shape: The tensor shape; stored as a tuple.

        Raises:
            TypeError: if ``dtype`` is not an ``np.dtype`` or ``shape`` is neither a
                tuple nor a list.
            MlflowException: if ``dtype`` is a unicode/bytes dtype whose name is not
                purely alphabetic, i.e. carries size information.
        """
        if not isinstance(dtype, np.dtype):
            raise TypeError(
                f"Expected `dtype` to be instance of `{np.dtype}`, received `{ dtype.__class__}`"
            )
        # Throw if size information exists flexible numpy data types
        if dtype.char in ["U", "S"]:
            if not dtype.name.isalpha():
                raise MlflowException(
                    "MLflow does not support size information in flexible numpy data types. Use"
                    f' np.dtype("{dtype.name.rstrip(string.digits)}") instead'
                )

        if not isinstance(shape, (tuple, list)):
            message = "Expected `shape` to be instance of `{}` or `{}`, received `{}`".format(
                tuple, list, shape.__class__
            )
            raise TypeError(message)
        self._shape = tuple(shape)
        self._dtype = dtype

    @property
    def dtype(self) -> np.dtype:
        """
        The numpy dtype of the tensor.
        See https://numpy.org/devdocs/reference/generated/numpy.dtype.html#numpy.dtype for details.
        """
        return self._dtype

    @property
    def shape(self) -> tuple:
        """The tensor shape"""
        return self._shape

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to ``{"dtype": <dtype name>, "shape": <shape tuple>}``."""
        serialized = {}
        serialized["dtype"] = self._dtype.name
        serialized["shape"] = self._shape
        return serialized

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain `dtype` and `shape` keys.
        """
        if "dtype" not in kwargs or "shape" not in kwargs:
            raise MlflowException(
                "Missing keys in TensorSpec JSON. Expected to find keys `dtype` and `shape`"
            )
        return cls(np.dtype(kwargs["dtype"]), tuple(kwargs["shape"]))

    def __repr__(self) -> str:
        type_name = self.dtype.name
        return f"Tensor({type_name!r}, {self.shape!r})"
820
+
821
+
822
class TensorSpec:
    """
    Specification of a single tensor-valued input: a numpy dtype, a shape and an optional name.

    The dtype and shape are held by a ``TensorInfo`` instance.
    """

    def __init__(
        self,
        type: np.dtype,
        shape: Union[tuple, list],
        name: Optional[str] = None,
    ):
        # The dtype/shape pair is handed to TensorInfo, which stores it.
        self._tensorInfo = TensorInfo(type, shape)
        self._name = name

    @property
    def type(self) -> np.dtype:
        """
        The numpy dtype of the tensor, as stored by the underlying TensorInfo.
        See https://numpy.org/devdocs/reference/generated/numpy.dtype.html#numpy.dtype for details.
        """
        return self._tensorInfo.dtype

    @property
    def name(self) -> Optional[str]:
        """Name of the tensor; None when the tensor is unnamed."""
        return self._name

    @property
    def shape(self) -> tuple:
        """Shape of the tensor, as stored by the underlying TensorInfo."""
        return self._tensorInfo.shape

    @property
    def required(self) -> bool:
        """Tensors are always required, so this is constantly True."""
        return True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize into a dictionary; the `name` key is only emitted for named tensors."""
        payload = {"type": "tensor", "tensor-spec": self._tensorInfo.to_dict()}
        if self.name is None:
            return payload
        # Keep `name` as the first key, followed by `type` and `tensor-spec`.
        return {"name": self.name, **payload}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Build a TensorSpec from a dictionary loaded from JSON.
        The dictionary must hold the `type` and `tensor-spec` keys; `name` is optional.
        """
        if not {"tensor-spec", "type"}.issubset(kwargs):
            raise MlflowException(
                "Missing keys in TensorSpec JSON. Expected to find keys `tensor-spec` and `type`"
            )
        if kwargs["type"] != "tensor":
            raise MlflowException(
                "Type mismatch, TensorSpec expects `tensor` as the type"
            )
        info = TensorInfo.from_json_dict(**kwargs["tensor-spec"])
        return cls(info.dtype, info.shape, kwargs.get("name"))

    def __eq__(self, other) -> bool:
        """Two specs are equal when name, dtype and shape all match."""
        if not isinstance(other, TensorSpec):
            return False
        return (
            self.name == other.name
            and self.type == other.type
            and self.shape == other.shape
        )

    def __repr__(self) -> str:
        """Render the tensor info, prefixed by the quoted name when one is set."""
        info_text = repr(self._tensorInfo)
        return info_text if self.name is None else f"{self.name!r}: {info_text}"
903
+
904
+
905
class Schema:
    """
    Specification of a dataset.

    Schema is represented as a list of :py:class:`ColSpec` or :py:class:`TensorSpec`. A combination
    of `ColSpec` and `TensorSpec` is not allowed.

    The dataset represented by a schema can be named, with unique non empty names for every input.
    In the case of :py:class:`ColSpec`, the dataset columns can be unnamed with implicit integer
    index defined by their list indices.
    Combination of named and unnamed data inputs are not allowed.
    """

    def __init__(self, inputs: List[Union[ColSpec, TensorSpec]]):
        """
        Validate and store the list of input specs.

        Args:
            inputs: non-empty list holding only ColSpec or only TensorSpec entries.

        Raises:
            MlflowException: if ``inputs`` is not a list, is empty, mixes named and unnamed
                entries, mixes ColSpec with TensorSpec, holds more than one unnamed TensorSpec,
                or holds unnamed entries that are not required.
        """
        if not isinstance(inputs, list):
            raise MlflowException.invalid_parameter_value(
                f"Inputs of Schema must be a list, got type {type(inputs).__name__}"
            )
        if not inputs:
            raise MlflowException.invalid_parameter_value(
                "Creating Schema with empty inputs is not allowed."
            )

        all_unnamed = all(x.name is None for x in inputs)
        all_named = all(x.name is not None for x in inputs)
        if not (all_unnamed or all_named):
            raise MlflowException(
                "Creating Schema with a combination of named and unnamed inputs "
                f"is not allowed. Got input names {[x.name for x in inputs]}"
            )

        all_tensors = all(isinstance(x, TensorSpec) for x in inputs)
        all_columns = all(isinstance(x, ColSpec) for x in inputs)
        if not (all_tensors or all_columns):
            # Use the class names here: `SomeClass.__class__` is the metaclass and would
            # render as "<class 'type'>" for both placeholders.
            raise MlflowException(
                "Creating Schema with a combination of {0} and {1} is not supported. "
                "Please choose one of {0} or {1}".format(
                    ColSpec.__name__, TensorSpec.__name__
                )
            )
        # Names are all-or-nothing at this point, so "any unnamed" means "all unnamed".
        if all_tensors and len(inputs) > 1 and all_unnamed:
            raise MlflowException(
                "Creating Schema with multiple unnamed TensorSpecs is not supported. "
                "Please provide names for each TensorSpec."
            )
        if all_unnamed and any(x.required is False for x in inputs):
            raise MlflowException(
                "Creating Schema with unnamed optional inputs is not supported. "
                "Please name all inputs or make all inputs required."
            )
        self._inputs = inputs

    def __len__(self):
        """Number of input specs in the schema."""
        return len(self._inputs)

    def __iter__(self):
        """Iterate over the input specs in order."""
        return iter(self._inputs)

    @property
    def inputs(self) -> List[Union[ColSpec, TensorSpec]]:
        """Representation of a dataset that defines this schema."""
        return self._inputs

    def is_tensor_spec(self) -> bool:
        """Return true iff this schema is specified using TensorSpec"""
        # bool() keeps the annotated return type even if the inputs list were empty.
        return bool(self.inputs) and isinstance(self.inputs[0], TensorSpec)

    def input_names(self) -> List[Union[str, int]]:
        """Get list of data names or range of indices if the schema has no names."""
        return [x.name or i for i, x in enumerate(self.inputs)]

    def required_input_names(self) -> List[Union[str, int]]:
        """Get list of required data names or range of indices if schema has no names."""
        return [x.name or i for i, x in enumerate(self.inputs) if x.required]

    def optional_input_names(self) -> List[Union[str, int]]:
        """Get list of optional data names or range of indices if schema has no names."""
        return [x.name or i for i, x in enumerate(self.inputs) if not x.required]

    def has_input_names(self) -> bool:
        """Return true iff this schema declares names, false otherwise."""
        # bool() keeps the annotated return type even if the inputs list were empty.
        return bool(self.inputs) and self.inputs[0].name is not None

    def input_types(self) -> List[Union[DataType, np.dtype, Array, Object]]:
        """Get types for each column in the schema."""
        return [x.type for x in self.inputs]

    def input_types_dict(self) -> Dict[str, Union[DataType, np.dtype, Array, Object]]:
        """
        Maps column names to types, iff this schema declares names.

        Raises:
            MlflowException: if the schema has no names.
        """
        if not self.has_input_names():
            raise MlflowException(
                "Cannot get input types as a dict for schema without names."
            )
        return {x.name: x.type for x in self.inputs}

    def input_dict(self) -> Dict[str, Union[ColSpec, TensorSpec]]:
        """
        Maps column names to inputs, iff this schema declares names.

        Raises:
            MlflowException: if the schema has no names.
        """
        if not self.has_input_names():
            raise MlflowException("Cannot get input dict for schema without names.")
        return {x.name: x for x in self.inputs}

    def numpy_types(self) -> List[np.dtype]:
        """
        Convenience shortcut to get the datatypes as numpy types.

        Raises:
            MlflowException: if the schema is column based and some input type is not a DataType.
        """
        if self.is_tensor_spec():
            return [x.type for x in self.inputs]
        if all(isinstance(x.type, DataType) for x in self.inputs):
            return [x.type.to_numpy() for x in self.inputs]
        raise MlflowException(
            "Failed to get numpy types as some of the inputs types are not DataType."
        )

    def pandas_types(self) -> List[np.dtype]:
        """
        Convenience shortcut to get the datatypes as pandas types. Unsupported by TensorSpec.

        Raises:
            MlflowException: if the schema is tensor based, or some input type is not a DataType.
        """
        if self.is_tensor_spec():
            raise MlflowException(
                "TensorSpec only supports numpy types, use numpy_types() instead"
            )
        if all(isinstance(x.type, DataType) for x in self.inputs):
            return [x.type.to_pandas() for x in self.inputs]
        raise MlflowException(
            "Failed to get pandas types as some of the inputs types are not DataType."
        )

    def as_spark_schema(self):
        """Convert to Spark schema. If this schema is a single unnamed column, it is converted
        directly the corresponding spark data type, otherwise it's returned as a struct (missing
        column names are filled with an integer sequence).
        Unsupported by TensorSpec.
        """
        if self.is_tensor_spec():
            raise MlflowException("TensorSpec cannot be converted to spark dataframe")
        if len(self.inputs) == 1 and self.inputs[0].name is None:
            return self.inputs[0].type.to_spark()
        # pyspark is imported here, inside the method, rather than at module level.
        from pyspark.sql.types import StructField, StructType

        return StructType(
            [
                StructField(
                    name=col.name or str(i),
                    dataType=col.type.to_spark(),
                    nullable=not col.required,
                )
                for i, col in enumerate(self.inputs)
            ]
        )

    def to_json(self) -> str:
        """Serialize into json string."""
        return json.dumps(self.to_dict())

    def to_dict(self) -> List[Dict[str, Any]]:
        """Serialize into a jsonable dictionary."""
        return [x.to_dict() for x in self.inputs]

    @classmethod
    def from_json(cls, json_str: str):
        """Deserialize from a json string."""

        def read_input(x: dict):
            # Entries typed "tensor" are TensorSpecs; everything else is a ColSpec.
            return (
                TensorSpec.from_json_dict(**x)
                if x["type"] == "tensor"
                else ColSpec.from_json_dict(**x)
            )

        return cls([read_input(x) for x in json.loads(json_str)])

    def __eq__(self, other) -> bool:
        """Schemas are equal when their input lists are equal."""
        if isinstance(other, Schema):
            return self.inputs == other.inputs
        return False

    def __repr__(self) -> str:
        """Render as the repr of the input list."""
        return repr(self.inputs)
1087
+
1088
+
1089
class ParamSpec:
    """
    Specification used to represent parameters for the model.

    A parameter has a name, a :py:class:`DataType` (binary is rejected), a default value and an
    optional shape: ``None`` for a scalar or ``(-1,)`` for a 1D list of values.
    """

    def __init__(
        self,
        name: str,
        dtype: Union[DataType, str],
        default: Union[DataType, List[DataType], None],
        shape: Optional[Tuple[int, ...]] = None,
    ):
        """
        Args:
            name: parameter name; converted with ``str()``.
            dtype: a DataType member or the name of one.
            default: default value; validated (and converted) against ``dtype`` and ``shape``.
            shape: ``None`` for scalar parameters, ``(-1,)`` for 1D list parameters.

        Raises:
            MlflowException: for an unknown dtype name, a binary dtype, an unsupported shape or
                a default value that does not match the dtype/shape.
            TypeError: if ``dtype`` is neither a DataType nor a string.
        """
        self._name = str(name)
        self._shape = tuple(shape) if shape is not None else None

        try:
            self._dtype = DataType[dtype] if isinstance(dtype, str) else dtype
        except KeyError:
            supported_types = [t.name for t in DataType if t.name != "binary"]
            raise MlflowException.invalid_parameter_value(
                f"Unsupported type '{dtype}', expected instance of DataType or "
                f"one of {supported_types}",
            ) from None
        if not isinstance(self.dtype, DataType):
            raise TypeError(
                "Expected mlflow.models.signature.Datatype or str for the 'dtype' "
                f"argument, but got {self.dtype.__class__}"
            )
        if self.dtype == DataType.binary:
            raise MlflowException.invalid_parameter_value(
                f"Binary type is not supported for parameters, ParamSpec '{self.name}' "
                "has dtype 'binary'",
            )

        # This line makes sure repr(self) works fine
        self._default = default
        self._default = self.validate_type_and_shape(
            repr(self), default, self.dtype, self.shape
        )

    @classmethod
    def validate_param_spec(
        cls, value: Union[DataType, List[DataType], None], param_spec: "ParamSpec"
    ):
        """Validate ``value`` against the dtype and shape of ``param_spec``; return the result."""
        return cls.validate_type_and_shape(
            repr(param_spec), value, param_spec.dtype, param_spec.shape
        )

    @classmethod
    def enforce_param_datatype(cls, name, value, dtype: DataType):
        """
        Enforce the value matches the data type.

        The following type conversions are allowed:

        1. int -> long, float, double
        2. long -> float, double
        3. float -> double
        4. any -> datetime (try conversion)

        Any other type mismatch will raise error.

        Args:
            name: parameter name
            value: parameter value
            dtype: expected data type

        Returns:
            The converted value, or None when ``value`` is None.

        Raises:
            MlflowException: if the value is not a scalar or cannot be converted to ``dtype``.
        """
        if value is None:
            return None

        if dtype == DataType.datetime:
            try:
                datetime_value = np.datetime64(value).item()
                # An int result is not a date/datetime object, so it is rejected.
                if isinstance(datetime_value, int):
                    raise MlflowException.invalid_parameter_value(
                        f"Invalid value for param {name}, it should "
                        f"be convertible to datetime.date/datetime, got {value}"
                    )
                return datetime_value
            except ValueError as e:
                raise MlflowException.invalid_parameter_value(
                    f"Failed to convert value {value} from type {type(value).__name__} "
                    f"to {dtype} for param {name}"
                ) from e

        # Note that np.isscalar(datetime.date(...)) is False
        if not np.isscalar(value):
            raise MlflowException.invalid_parameter_value(
                f"Value should be a scalar for param {name}, got {value}"
            )

        # Always convert to python native type for params
        if getattr(DataType, f"is_{dtype.name}")(value):
            return DataType[dtype.name].to_python()(value)

        if (
            (
                DataType.is_integer(value)
                and dtype in (DataType.long, DataType.float, DataType.double)
            )
            or (DataType.is_long(value) and dtype in (DataType.float, DataType.double))
            or (DataType.is_float(value) and dtype == DataType.double)
        ):
            try:
                return DataType[dtype.name].to_python()(value)
            except ValueError as e:
                raise MlflowException.invalid_parameter_value(
                    f"Failed to convert value {value} from type {type(value).__name__} "
                    f"to {dtype} for param {name}"
                ) from e

        raise MlflowException.invalid_parameter_value(
            f"Incompatible types for param {name}. Can not safely convert {type(value).__name__} "
            f"to {dtype}.",
        )

    @classmethod
    def validate_type_and_shape(
        cls,
        spec: str,
        value: Union[DataType, List[DataType], None],
        value_type: DataType,
        shape: Optional[Tuple[int, ...]],
    ):
        """
        Validate that the value has the expected type and shape.

        Args:
            spec: textual description of the parameter, used in error messages.
            value: value to validate.
            value_type: expected data type of the value (or of each element).
            shape: ``None`` for a scalar value, ``(-1,)`` for a 1D list/array value.

        Returns:
            The converted scalar, or a list of converted elements for shape ``(-1,)``.

        Raises:
            MlflowException: if the shape is unsupported or the value does not match it.
        """

        def _is_1d_array(value):
            return isinstance(value, (list, np.ndarray)) and np.array(value).ndim == 1

        if shape is None:
            return cls.enforce_param_datatype(
                f"{spec} with shape None", value, value_type
            )
        elif shape == (-1,):
            if not _is_1d_array(value):
                raise MlflowException.invalid_parameter_value(
                    f"Value must be a 1D array with shape (-1,) for param {spec}, "
                    f"received {type(value).__name__} with ndim {np.array(value).ndim}",
                )
            return [
                cls.enforce_param_datatype(f"{spec} internal values", v, value_type)
                for v in value
            ]
        else:
            raise MlflowException.invalid_parameter_value(
                "Shape must be None for scalar value or (-1,) for 1D array value "
                f"for ParamSpec {spec}, received {shape}",
            )

    @property
    def name(self) -> str:
        """The name of the parameter."""
        return self._name

    @property
    def dtype(self) -> DataType:
        """The parameter data type."""
        return self._dtype

    @property
    def default(self) -> Union[DataType, List[DataType], None]:
        """Default value of the parameter."""
        return self._default

    @property
    def shape(self) -> Optional[tuple]:
        """
        The parameter shape.
        If shape is None, the parameter is a scalar.
        """
        return self._shape

    class ParamSpecTypedDict(TypedDict):
        # Layout of the dictionary produced by ParamSpec.to_dict().
        name: str
        type: str
        default: Union[DataType, List[DataType], None]
        shape: Optional[Tuple[int, ...]]

    def to_dict(self) -> ParamSpecTypedDict:
        """
        Serialize into a jsonable dictionary with `name`, `type`, `default` and `shape` keys.

        Datetime defaults are rendered with ``isoformat()``; a ``None`` default (or ``None``
        list element) is passed through unchanged.
        """
        is_datetime = self.dtype.name == "datetime"

        def _jsonable(value):
            # None has no isoformat(); validation lets None through, so it must be kept as-is.
            if is_datetime and value is not None:
                return value.isoformat()
            return value

        if self.shape is None:
            default_value = _jsonable(self.default)
        elif is_datetime:
            default_value = [_jsonable(v) for v in self.default]
        else:
            default_value = self.default
        return {
            "name": self.name,
            "type": self.dtype.name,
            "default": default_value,
            "shape": self.shape,
        }

    def __eq__(self, other) -> bool:
        """Specs are equal when name, dtype, default and shape all match."""
        if isinstance(other, ParamSpec):
            return (
                self.name == other.name
                and self.dtype == other.dtype
                and self.default == other.default
                and self.shape == other.shape
            )
        return False

    def __repr__(self) -> str:
        """Render as ``'name': dtype (default: value)`` plus the shape when one is set."""
        shape = f" (shape: {self.shape})" if self.shape is not None else ""
        return f"{self.name!r}: {self.dtype!r} (default: {self.default}){shape}"

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain `name`, `type` and `default` keys.

        Raises:
            MlflowException: if the required keys are missing.
        """
        # For backward compatibility, we accept both `type` and `dtype` keys
        required_keys1 = {"name", "dtype", "default"}
        required_keys2 = {"name", "type", "default"}

        if not (required_keys1.issubset(kwargs) or required_keys2.issubset(kwargs)):
            raise MlflowException.invalid_parameter_value(
                "Missing keys in ParamSpec JSON. Expected to find "
                "keys `name`, `type`(or `dtype`) and `default`. "
                f"Received keys: {kwargs.keys()}"
            )
        dtype = kwargs.get("type") or kwargs.get("dtype")
        return cls(
            name=str(kwargs["name"]),
            dtype=DataType[dtype],
            default=kwargs["default"],
            shape=kwargs.get("shape"),
        )
1326
+
1327
+
1328
class ParamSchema:
    """
    Specification of parameters applicable to the model.
    ParamSchema is represented as a list of :py:class:`ParamSpec`.
    """

    def __init__(self, params: List[ParamSpec]):
        """
        Args:
            params: ParamSpec entries with unique names.

        Raises:
            MlflowException: if an entry is not a ParamSpec or a name occurs more than once.
        """
        if not all(isinstance(x, ParamSpec) for x in params):
            # Name the accepted class; `ParamSchema.__class__` is the metaclass and
            # would render as "<class 'type'>".
            raise MlflowException.invalid_parameter_value(
                f"ParamSchema inputs only accept {ParamSpec}"
            )
        if duplicates := self._find_duplicates(params):
            raise MlflowException.invalid_parameter_value(
                f"Duplicated parameters found in schema: {duplicates}"
            )
        self._params = params

    @staticmethod
    def _find_duplicates(params: List[ParamSpec]) -> List[str]:
        """Return every repeated parameter name, once per repeated occurrence, in order."""
        seen = set()
        duplicates = []
        for param_spec in params:
            name = param_spec.name
            if name in seen:
                duplicates.append(name)
            else:
                seen.add(name)
        return duplicates

    def __len__(self):
        """Number of parameter specs in the schema."""
        return len(self._params)

    def __iter__(self):
        """Iterate over the parameter specs in order."""
        return iter(self._params)

    @property
    def params(self) -> List[ParamSpec]:
        """Representation of ParamSchema as a list of ParamSpec."""
        return self._params

    def to_json(self) -> str:
        """Serialize into json string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str):
        """Deserialize from a json string."""
        return cls([ParamSpec.from_json_dict(**x) for x in json.loads(json_str)])

    def to_dict(self) -> List[Dict[str, Any]]:
        """Serialize into a jsonable dictionary."""
        return [x.to_dict() for x in self.params]

    def __eq__(self, other) -> bool:
        """Schemas are equal when their parameter lists are equal."""
        if isinstance(other, ParamSchema):
            return self.params == other.params
        return False

    def __repr__(self) -> str:
        """Render as the repr of the parameter list."""
        return repr(self.params)
1388
+
1389
+
1390
def _map_field_type(field):
    """
    Look up the schema type name for a basic Python type.

    Returns the type name string, or None when ``field`` is not one of the mapped types.
    """
    type_names = {
        bool: "boolean",
        int: "long",  # int is mapped to long to support 64-bit integers
        builtins.float: "float",
        str: "string",
        bytes: "binary",
        dt.date: "datetime",
    }
    try:
        return type_names[field]
    except KeyError:
        return None
1400
+
1401
+
1402
def _get_dataclass_annotations(cls) -> Dict[str, Any]:
    """
    Given a dataclass or an instance of one, collect annotations from it and all its parent
    dataclasses.

    Args:
        cls: a dataclass type or an instance of a dataclass.

    Returns:
        Mapping of field name to annotation; annotations of subclasses override those of
        their bases.

    Raises:
        TypeError: if ``cls`` is neither a dataclass nor an instance of one.
    """
    if not is_dataclass(cls):
        # Instances have no __name__; fall back to the type name so that the intended
        # TypeError is raised instead of an AttributeError.
        label = getattr(cls, "__name__", type(cls).__name__)
        raise TypeError(f"{label} is not a dataclass.")

    annotations = {}
    effective_class = cls if isinstance(cls, type) else type(cls)

    # Reverse MRO so subclass overrides are captured last
    for base in reversed(effective_class.__mro__):
        # Only capture supers that are dataclasses
        if is_dataclass(base) and hasattr(base, "__annotations__"):
            annotations.update(base.__annotations__)
    return annotations
1419
+
1420
+
1421
def convert_dataclass_to_schema(dataclass):
    """
    Converts a given dataclass into a Schema object. The dataclass must include type hints
    for all its fields. Fields can be of basic types, other dataclasses, or Lists/Optional of
    these types. Union types are not supported. Only the top-level fields are directly converted
    to ColSpecs, while nested fields are converted into nested Object types.

    Args:
        dataclass: a dataclass type, or an instance of one.

    Returns:
        A Schema with one ColSpec per field; ``Optional[...]`` fields are marked not required.

    Raises:
        MlflowException: for a Union other than ``Optional[...]``, an unparameterized list,
            or a field/list element type that cannot be mapped.
    """
    # Instances have no __name__, so resolve a printable name once for the error messages.
    owner_name = getattr(dataclass, "__name__", type(dataclass).__name__)

    inputs = []

    for field_name, field_type in _get_dataclass_annotations(dataclass).items():
        # Determine the type and handle Optional and List correctly
        is_optional = False
        effective_type = field_type

        if get_origin(field_type) == Union:
            union_members = get_args(field_type)
            if type(None) in union_members and len(union_members) == 2:
                # This is an Optional type; determine the effective type excluding None
                is_optional = True
                effective_type = next(t for t in union_members if t is not type(None))
            else:
                raise MlflowException(
                    "Only Optional[...] is supported as a Union type in dataclass fields"
                )

        if get_origin(effective_type) == list:
            # It's a list, check the type within the list
            element_args = get_args(effective_type)
            if not element_args:
                # A bare `List` carries no element type to map.
                raise MlflowException(
                    f"List field type {effective_type} is not supported in dataclass"
                    f" {owner_name}"
                )
            list_type = element_args[0]
            if is_dataclass(list_type):
                # Convert to nested Object
                element_dtype = _convert_dataclass_to_nested_object(list_type)
            else:
                element_dtype = _map_field_type(list_type)
                if not element_dtype:
                    raise MlflowException(
                        f"List field type {list_type} is not supported in dataclass"
                        f" {owner_name}"
                    )
            column_type = Array(dtype=element_dtype)
        elif is_dataclass(effective_type):
            # It's a nested dataclass; convert to nested Object
            column_type = _convert_dataclass_to_nested_object(effective_type)
        elif basic_type := _map_field_type(effective_type):
            # It's a basic type
            column_type = basic_type
        else:
            raise MlflowException(
                f"Unsupported field type {effective_type} in dataclass {owner_name}"
            )

        inputs.append(
            ColSpec(
                type=column_type,
                name=field_name,
                required=not is_optional,
            )
        )

    return Schema(inputs=inputs)
1504
+
1505
+
1506
def _convert_dataclass_to_nested_object(dataclass):
    """
    Convert a nested dataclass to an Object type used within a ColSpec.

    Annotations are collected through ``_get_dataclass_annotations`` so that fields inherited
    from parent dataclasses are included, matching how top-level dataclasses are handled.
    """
    properties = [
        _convert_field_to_property(field_name, field_type)
        for field_name, field_type in _get_dataclass_annotations(dataclass).items()
    ]
    return Object(properties=properties)
1514
+
1515
+
1516
def _convert_field_to_property(field_name, field_type):
    """
    Helper function to convert a single field to a Property object suitable for inclusion in an
    Object.

    Args:
        field_name: name of the dataclass field.
        field_type: annotation of the field; a basic type, a dataclass, or a List/Optional
            of those.

    Returns:
        A Property named ``field_name``; ``Optional[...]`` fields are marked not required.

    Raises:
        MlflowException: for a Union containing None that is not a plain ``Optional[...]``,
            or for an unparameterized list.
    """

    is_optional = False
    effective_type = field_type

    if get_origin(field_type) == Union and type(None) in get_args(field_type):
        non_none_members = [t for t in get_args(field_type) if t is not type(None)]
        if len(non_none_members) != 1:
            # Picking the first member of e.g. Union[int, str, None] would silently drop
            # the others; reject it like the top-level converter does.
            raise MlflowException(
                "Only Optional[...] is supported as a Union type in dataclass fields"
            )
        is_optional = True
        effective_type = non_none_members[0]

    if get_origin(effective_type) == list:
        element_args = get_args(effective_type)
        if not element_args:
            # A bare `List` carries no element type to map.
            raise MlflowException(
                f"List field type {effective_type} is not supported for field {field_name}"
            )
        list_type = element_args[0]
        if is_dataclass(list_type):
            # Lists of dataclasses become arrays of nested Objects.
            element_dtype = _convert_dataclass_to_nested_object(list_type)
        else:
            element_dtype = _map_field_type(list_type)
        dtype = Array(dtype=element_dtype)
    elif is_dataclass(effective_type):
        dtype = _convert_dataclass_to_nested_object(effective_type)
    else:
        dtype = _map_field_type(effective_type)

    return Property(
        name=field_name,
        dtype=dtype,
        required=not is_optional,
    )