arize 8.0.0a13__py3-none-any.whl → 8.0.0a15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arize/models/client.py CHANGED
@@ -31,13 +31,17 @@ from arize.exceptions.parameters import (
31
31
  from arize.exceptions.spaces import MissingSpaceIDError
32
32
  from arize.logging import get_truncation_warning_message
33
33
  from arize.models.bounded_executor import BoundedExecutor
34
+ from arize.models.casting import cast_dictionary, cast_typed_columns
34
35
  from arize.models.stream_validation import (
35
36
  validate_and_convert_prediction_id,
36
37
  validate_label,
37
38
  )
38
39
  from arize.types import (
40
+ CATEGORICAL_MODEL_TYPES,
41
+ NUMERIC_MODEL_TYPES,
39
42
  ActualLabelTypes,
40
43
  BaseSchema,
44
+ CorpusSchema,
41
45
  Embedding,
42
46
  Environments,
43
47
  LLMRunMetadata,
@@ -51,7 +55,6 @@ from arize.types import (
51
55
  convert_element,
52
56
  is_list_of,
53
57
  )
54
- from arize.utils.casting import cast_dictionary, cast_typed_columns
55
58
 
56
59
  if TYPE_CHECKING:
57
60
  import concurrent.futures as cf
@@ -63,6 +66,11 @@ if TYPE_CHECKING:
63
66
 
64
67
  from arize._generated.protocol.rec import public_pb2 as pb2
65
68
  from arize.config import SDKConfiguration
69
+ from arize.types import (
70
+ EmbeddingColumnNames,
71
+ Schema,
72
+ )
73
+
66
74
 
67
75
  logger = logging.getLogger(__name__)
68
76
 
@@ -122,7 +130,7 @@ class MLModelsClient:
122
130
  ) -> cf.Future:
123
131
  require(_STREAM_EXTRA, _STREAM_DEPS)
124
132
  from arize._generated.protocol.rec import public_pb2 as pb2
125
- from arize.utils.proto import (
133
+ from arize.models.proto import (
126
134
  get_pb_dictionary,
127
135
  get_pb_label,
128
136
  get_pb_timestamp,
@@ -469,7 +477,6 @@ class MLModelsClient:
469
477
  from arize.models.batch_validation.validator import Validator
470
478
  from arize.utils.arrow import post_arrow_table
471
479
  from arize.utils.dataframe import remove_extraneous_columns
472
- from arize.utils.proto import get_pb_schema, get_pb_schema_corpus
473
480
 
474
481
  # This method requires a space_id and project_name
475
482
  if not space_id:
@@ -620,12 +627,12 @@ class MLModelsClient:
620
627
  )
621
628
 
622
629
  if environment == Environments.CORPUS:
623
- proto_schema = get_pb_schema_corpus(
630
+ proto_schema = _get_pb_schema_corpus(
624
631
  schema=schema,
625
632
  model_id=model_name,
626
633
  )
627
634
  else:
628
- proto_schema = get_pb_schema(
635
+ proto_schema = _get_pb_schema(
629
636
  schema=schema,
630
637
  model_id=model_name,
631
638
  model_version=model_version,
@@ -803,3 +810,321 @@ def _is_timestamp_in_range(now: int, ts: int):
803
810
  max_time = now + (MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
804
811
  min_time = now - (MAX_PAST_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
805
812
  return min_time <= ts <= max_time
813
+
814
+
815
def _pb_set_object_detection_columns(target, source) -> None:
    """Copy object-detection (bounding-box) column names onto a proto message.

    Args:
        target: A ``*_object_detection_label_column_names`` sub-message of
            ``pb2.Schema.arrow_schema``.
        source: The schema's object-detection column-names object.
    """
    target.bboxes_coordinates_column_name = (
        source.bounding_boxes_coordinates_column_name
    )
    target.bboxes_categories_column_name = source.categories_column_name
    if source.scores_column_name is not None:
        target.bboxes_scores_column_name = source.scores_column_name


def _pb_set_semantic_segmentation_columns(target, source) -> None:
    """Copy semantic-segmentation (polygon) column names onto a proto message.

    Args:
        target: A ``*_semantic_segmentation_label_column_names`` sub-message
            of ``pb2.Schema.arrow_schema``.
        source: The schema's semantic-segmentation column-names object.
    """
    target.polygons_coordinates_column_name = (
        source.polygon_coordinates_column_name
    )
    target.polygons_categories_column_name = source.categories_column_name


def _pb_set_instance_segmentation_columns(
    target, source, *, include_scores: bool
) -> None:
    """Copy instance-segmentation column names onto a proto message.

    Args:
        target: A ``*_instance_segmentation_label_column_names`` sub-message
            of ``pb2.Schema.arrow_schema``.
        source: The schema's instance-segmentation column-names object.
        include_scores: Whether to copy ``scores_column_name``. Predictions
            pass True; actuals pass False and never read a scores column.
    """
    target.polygons_coordinates_column_name = (
        source.polygon_coordinates_column_name
    )
    target.polygons_categories_column_name = source.categories_column_name
    # Short-circuit so the actual-side source object is never asked for a
    # scores attribute.
    if include_scores and source.scores_column_name is not None:
        target.polygons_scores_column_name = source.scores_column_name
    if source.bounding_boxes_coordinates_column_name is not None:
        target.bboxes_coordinates_column_name = (
            source.bounding_boxes_coordinates_column_name
        )


def _get_pb_schema(
    schema: Schema,
    model_id: str,
    model_version: str | None,
    model_type: ModelTypes,
    environment: Environments,
    batch_id: str | None,
) -> pb2.Schema:
    """Build the protobuf ``Schema`` describing a non-corpus batch upload.

    Args:
        schema: User-supplied column mapping.
        model_id: Written to ``constants.model_id``.
        model_version: Written to ``constants.model_version`` when not None.
        model_type: User-facing model type. It is mapped to the internal proto
            ``ModelType`` and also selects which type-specific column groups
            (CV, LLM, ranking, multi-class) are copied.
        environment: One of PRODUCTION, VALIDATION or TRAINING.
        batch_id: Written to ``constants.batch_id`` when not None.

    Returns:
        The populated ``pb2.Schema`` message.

    Raises:
        ValueError: If ``environment`` is not PRODUCTION, VALIDATION or
            TRAINING.
    """
    # At module level these names are imported only under TYPE_CHECKING, so
    # they must be imported here to exist at runtime (``pb2.Schema()`` and the
    # ``isinstance`` checks below need the real objects).
    from arize._generated.protocol.rec import public_pb2 as pb2
    from arize.types import EmbeddingColumnNames

    s = pb2.Schema()
    constants = s.constants
    arrow = s.arrow_schema

    # --- Constants -------------------------------------------------------
    constants.model_id = model_id
    if model_version is not None:
        constants.model_version = model_version

    if environment == Environments.PRODUCTION:
        constants.environment = pb2.Schema.Environment.PRODUCTION
    elif environment == Environments.VALIDATION:
        constants.environment = pb2.Schema.Environment.VALIDATION
    elif environment == Environments.TRAINING:
        constants.environment = pb2.Schema.Environment.TRAINING
    else:
        raise ValueError(f"unexpected environment: {environment}")

    # Map user-friendly external model types -> internal model types when
    # sending to Arize. A type matching none of these branches leaves
    # ``model_type`` unset.
    if model_type in NUMERIC_MODEL_TYPES:
        constants.model_type = pb2.Schema.ModelType.NUMERIC
    elif model_type in CATEGORICAL_MODEL_TYPES:
        constants.model_type = pb2.Schema.ModelType.SCORE_CATEGORICAL
    elif model_type == ModelTypes.RANKING:
        constants.model_type = pb2.Schema.ModelType.RANKING
    elif model_type == ModelTypes.OBJECT_DETECTION:
        constants.model_type = pb2.Schema.ModelType.OBJECT_DETECTION
    elif model_type == ModelTypes.GENERATIVE_LLM:
        constants.model_type = pb2.Schema.ModelType.GENERATIVE_LLM
    elif model_type == ModelTypes.MULTI_CLASS:
        constants.model_type = pb2.Schema.ModelType.MULTI_CLASS

    if batch_id is not None:
        constants.batch_id = batch_id

    # --- Prediction columns ----------------------------------------------
    if schema.prediction_id_column_name is not None:
        arrow.prediction_id_column_name = schema.prediction_id_column_name
    if schema.timestamp_column_name is not None:
        arrow.timestamp_column_name = schema.timestamp_column_name
    if schema.prediction_label_column_name is not None:
        arrow.prediction_label_column_name = (
            schema.prediction_label_column_name
        )

    if model_type == ModelTypes.OBJECT_DETECTION:
        od_pred = schema.object_detection_prediction_column_names
        if od_pred is not None:
            _pb_set_object_detection_columns(
                arrow.prediction_object_detection_label_column_names, od_pred
            )
        sem_pred = schema.semantic_segmentation_prediction_column_names
        if sem_pred is not None:
            _pb_set_semantic_segmentation_columns(
                arrow.prediction_semantic_segmentation_label_column_names,
                sem_pred,
            )
        inst_pred = schema.instance_segmentation_prediction_column_names
        if inst_pred is not None:
            _pb_set_instance_segmentation_columns(
                arrow.prediction_instance_segmentation_label_column_names,
                inst_pred,
                include_scores=True,
            )

    if schema.prediction_score_column_name is not None:
        if model_type in NUMERIC_MODEL_TYPES:
            # Numeric predictions may be sent as either prediction_label
            # (legacy) or prediction_score. Both go to the label column, and
            # the score column overrides the label set above.
            arrow.prediction_label_column_name = (
                schema.prediction_score_column_name
            )
        else:
            arrow.prediction_score_column_name = (
                schema.prediction_score_column_name
            )

    # --- Features, embeddings and tags -----------------------------------
    if schema.feature_column_names is not None:
        arrow.feature_column_names.extend(schema.feature_column_names)

    emb_map = arrow.embedding_feature_column_names_map
    if schema.embedding_feature_column_names is not None:
        for emb_name, emb_cols in schema.embedding_feature_column_names.items():
            # emb_name is how the embedding will show in the UI.
            entry = emb_map[emb_name]
            entry.vector_column_name = emb_cols.vector_column_name
            if emb_cols.data_column_name:
                entry.data_column_name = emb_cols.data_column_name
            if emb_cols.link_to_data_column_name:
                entry.link_to_data_column_name = (
                    emb_cols.link_to_data_column_name
                )

    # Prompt/response columns are stored as embedding features under the
    # fixed keys "prompt" and "response". A plain string is a data column
    # only; EmbeddingColumnNames also carries a vector column.
    for llm_key, llm_cols in (
        ("prompt", schema.prompt_column_names),
        ("response", schema.response_column_names),
    ):
        if llm_cols is None:
            continue
        if isinstance(llm_cols, str):
            emb_map[llm_key].data_column_name = llm_cols
        elif isinstance(llm_cols, EmbeddingColumnNames):
            entry = emb_map[llm_key]
            entry.vector_column_name = llm_cols.vector_column_name
            if llm_cols.data_column_name:
                entry.data_column_name = llm_cols.data_column_name

    if schema.tag_column_names is not None:
        arrow.tag_column_names.extend(schema.tag_column_names)

    # --- Actual columns --------------------------------------------------
    # Ranking models send relevance labels (or, failing that, attributions)
    # as the actual label, and relevance scores as the actual score.
    is_ranking = model_type == ModelTypes.RANKING
    if is_ranking and schema.relevance_labels_column_name is not None:
        arrow.actual_label_column_name = schema.relevance_labels_column_name
    elif is_ranking and schema.attributions_column_name is not None:
        arrow.actual_label_column_name = schema.attributions_column_name
    elif schema.actual_label_column_name is not None:
        arrow.actual_label_column_name = schema.actual_label_column_name

    if is_ranking and schema.relevance_score_column_name is not None:
        arrow.actual_score_column_name = schema.relevance_score_column_name
    elif schema.actual_score_column_name is not None:
        if model_type in NUMERIC_MODEL_TYPES:
            # Same legacy rule as predictions: a numeric actual score goes
            # to (and overrides) the actual label column.
            arrow.actual_label_column_name = schema.actual_score_column_name
        else:
            arrow.actual_score_column_name = schema.actual_score_column_name

    if schema.shap_values_column_names is not None:
        arrow.shap_values_column_names.update(schema.shap_values_column_names)

    if schema.prediction_group_id_column_name is not None:
        arrow.prediction_group_id_column_name = (
            schema.prediction_group_id_column_name
        )

    if schema.rank_column_name is not None:
        arrow.rank_column_name = schema.rank_column_name

    if model_type == ModelTypes.OBJECT_DETECTION:
        od_act = schema.object_detection_actual_column_names
        if od_act is not None:
            _pb_set_object_detection_columns(
                arrow.actual_object_detection_label_column_names, od_act
            )
        sem_act = schema.semantic_segmentation_actual_column_names
        if sem_act is not None:
            _pb_set_semantic_segmentation_columns(
                arrow.actual_semantic_segmentation_label_column_names, sem_act
            )
        inst_act = schema.instance_segmentation_actual_column_names
        if inst_act is not None:
            _pb_set_instance_segmentation_columns(
                arrow.actual_instance_segmentation_label_column_names,
                inst_act,
                include_scores=False,
            )

    # --- Model-type specific extras --------------------------------------
    if model_type == ModelTypes.GENERATIVE_LLM:
        template_cols = schema.prompt_template_column_names
        if template_cols is not None:
            pb_template = arrow.prompt_template_column_names
            pb_template.template_column_name = (
                template_cols.template_column_name
            )
            pb_template.template_version_column_name = (
                template_cols.template_version_column_name
            )
        config_cols = schema.llm_config_column_names
        if config_cols is not None:
            pb_config = arrow.llm_config_column_names
            pb_config.model_column_name = config_cols.model_column_name
            pb_config.params_map_column_name = config_cols.params_column_name
        if schema.retrieved_document_ids_column_name is not None:
            arrow.retrieved_document_ids_column_name = (
                schema.retrieved_document_ids_column_name
            )

    if model_type == ModelTypes.MULTI_CLASS:
        if schema.prediction_score_column_name is not None:
            arrow.prediction_score_column_name = (
                schema.prediction_score_column_name
            )
        if schema.multi_class_threshold_scores_column_name is not None:
            arrow.multi_class_threshold_scores_column_name = (
                schema.multi_class_threshold_scores_column_name
            )
        if schema.actual_score_column_name is not None:
            arrow.actual_score_column_name = schema.actual_score_column_name

    return s
1104
+
1105
+
1106
def _get_pb_schema_corpus(
    schema: CorpusSchema,
    model_id: str,
) -> pb2.Schema:
    """Build the protobuf ``Schema`` for a corpus (document) upload.

    The message is always tagged with the ``CORPUS`` environment and the
    ``GENERATIVE_LLM`` model type. Only document column names that are set
    (not None) in ``schema`` are copied.

    Args:
        schema: User-supplied corpus column mapping.
        model_id: Written to ``constants.model_id``.

    Returns:
        The populated ``pb2.Schema`` message.
    """
    # At module level ``pb2`` is imported only under TYPE_CHECKING, so import
    # it here so that ``pb2.Schema()`` works at runtime.
    from arize._generated.protocol.rec import public_pb2 as pb2

    s = pb2.Schema()
    s.constants.model_id = model_id
    s.constants.environment = pb2.Schema.Environment.CORPUS
    s.constants.model_type = pb2.Schema.ModelType.GENERATIVE_LLM

    doc_cols = s.arrow_schema.document_column_names
    if schema.document_id_column_name is not None:
        doc_cols.id_column_name = schema.document_id_column_name
    if schema.document_version_column_name is not None:
        doc_cols.version_column_name = schema.document_version_column_name

    text_cols = schema.document_text_embedding_column_names
    if text_cols is not None:
        pb_text = doc_cols.text_column_name
        pb_text.vector_column_name = text_cols.vector_column_name
        # Guard optional columns. Assigning None to a proto string field
        # raises TypeError, and _get_pb_schema guards its embedding data
        # column the same way.
        if text_cols.data_column_name is not None:
            pb_text.data_column_name = text_cols.data_column_name
        if text_cols.link_to_data_column_name is not None:
            pb_text.link_to_data_column_name = (
                text_cols.link_to_data_column_name
            )
    return s