arize 8.0.0a8__tar.gz → 8.0.0a10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. {arize-8.0.0a8 → arize-8.0.0a10}/PKG-INFO +5 -1
  2. {arize-8.0.0a8 → arize-8.0.0a10}/pyproject.toml +9 -4
  3. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/client.py +1 -0
  4. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/models/client.py +11 -13
  5. arize-8.0.0a10/src/arize/models/surrogate_explainer/mimic.py +164 -0
  6. arize-8.0.0a10/src/arize/utils/__init__.py +0 -0
  7. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/utils/arrow.py +16 -27
  8. arize-8.0.0a10/src/arize/version.py +1 -0
  9. arize-8.0.0a8/src/arize/version.py +0 -1
  10. {arize-8.0.0a8 → arize-8.0.0a10}/.gitignore +0 -0
  11. {arize-8.0.0a8 → arize-8.0.0a10}/LICENSE.md +0 -0
  12. {arize-8.0.0a8 → arize-8.0.0a10}/README.md +0 -0
  13. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/__init__.py +0 -0
  14. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_exporter/__init__.py +0 -0
  15. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_exporter/client.py +0 -0
  16. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_exporter/parsers/__init__.py +0 -0
  17. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_exporter/parsers/tracing_data_parser.py +0 -0
  18. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_exporter/validation.py +0 -0
  19. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_flight/__init__.py +0 -0
  20. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_flight/client.py +0 -0
  21. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_flight/types.py +0 -0
  22. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/__init__.py +0 -0
  23. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/__init__.py +0 -0
  24. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/api/__init__.py +0 -0
  25. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/api/datasets_api.py +0 -0
  26. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/api/experiments_api.py +0 -0
  27. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/api_client.py +0 -0
  28. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/api_response.py +0 -0
  29. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/configuration.py +0 -0
  30. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/exceptions.py +0 -0
  31. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/__init__.py +0 -0
  32. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/dataset.py +0 -0
  33. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/dataset_version.py +0 -0
  34. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/datasets_create201_response.py +0 -0
  35. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/datasets_create_request.py +0 -0
  36. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/datasets_list200_response.py +0 -0
  37. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/datasets_list_examples200_response.py +0 -0
  38. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/error.py +0 -0
  39. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/experiment.py +0 -0
  40. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/models/experiments_list200_response.py +0 -0
  41. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/rest.py +0 -0
  42. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/__init__.py +0 -0
  43. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_dataset.py +0 -0
  44. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_dataset_version.py +0 -0
  45. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_datasets_api.py +0 -0
  46. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_datasets_create201_response.py +0 -0
  47. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_datasets_create_request.py +0 -0
  48. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_datasets_list200_response.py +0 -0
  49. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_datasets_list_examples200_response.py +0 -0
  50. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_error.py +0 -0
  51. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_experiment.py +0 -0
  52. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_experiments_api.py +0 -0
  53. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client/test/test_experiments_list200_response.py +0 -0
  54. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/api_client_README.md +0 -0
  55. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/protocol/__init__.py +0 -0
  56. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/protocol/flight/__init__.py +0 -0
  57. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/protocol/flight/export_pb2.py +0 -0
  58. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/protocol/flight/ingest_pb2.py +0 -0
  59. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/protocol/rec/__init__.py +0 -0
  60. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_generated/protocol/rec/public_pb2.py +0 -0
  61. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/_lazy.py +0 -0
  62. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/config.py +0 -0
  63. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/constants/__init__.py +0 -0
  64. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/constants/config.py +0 -0
  65. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/constants/ml.py +0 -0
  66. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/constants/model_mapping.json +0 -0
  67. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/constants/spans.py +0 -0
  68. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/datasets/__init__.py +0 -0
  69. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/datasets/client.py +0 -0
  70. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/exceptions/__init__.py +0 -0
  71. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/exceptions/auth.py +0 -0
  72. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/exceptions/base.py +0 -0
  73. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/exceptions/models.py +0 -0
  74. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/exceptions/parameters.py +0 -0
  75. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/exceptions/spaces.py +0 -0
  76. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/exceptions/types.py +0 -0
  77. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/exceptions/values.py +0 -0
  78. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/experiments/__init__.py +0 -0
  79. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/experiments/client.py +0 -0
  80. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/logging.py +0 -0
  81. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/models/__init__.py +0 -0
  82. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/models/batch_validation/__init__.py +0 -0
  83. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/models/batch_validation/errors.py +0 -0
  84. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/models/batch_validation/validator.py +0 -0
  85. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/models/bounded_executor.py +0 -0
  86. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/models/stream_validation.py +0 -0
  87. {arize-8.0.0a8/src/arize/spans → arize-8.0.0a10/src/arize/models/surrogate_explainer}/__init__.py +0 -0
  88. {arize-8.0.0a8/src/arize/spans/validation → arize-8.0.0a10/src/arize/spans}/__init__.py +0 -0
  89. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/client.py +0 -0
  90. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/columns.py +0 -0
  91. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/conversion.py +0 -0
  92. {arize-8.0.0a8/src/arize/spans/validation/annotations → arize-8.0.0a10/src/arize/spans/validation}/__init__.py +0 -0
  93. {arize-8.0.0a8/src/arize/spans/validation/common → arize-8.0.0a10/src/arize/spans/validation/annotations}/__init__.py +0 -0
  94. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/annotations/annotations_validation.py +0 -0
  95. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/annotations/dataframe_form_validation.py +0 -0
  96. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/annotations/value_validation.py +0 -0
  97. {arize-8.0.0a8/src/arize/spans/validation/evals → arize-8.0.0a10/src/arize/spans/validation/common}/__init__.py +0 -0
  98. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/common/argument_validation.py +0 -0
  99. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/common/dataframe_form_validation.py +0 -0
  100. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/common/errors.py +0 -0
  101. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/common/value_validation.py +0 -0
  102. {arize-8.0.0a8/src/arize/spans/validation/spans → arize-8.0.0a10/src/arize/spans/validation/evals}/__init__.py +0 -0
  103. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/evals/dataframe_form_validation.py +0 -0
  104. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/evals/evals_validation.py +0 -0
  105. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/evals/value_validation.py +0 -0
  106. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/metadata/__init__.py +0 -0
  107. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/metadata/argument_validation.py +0 -0
  108. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/metadata/dataframe_form_validation.py +0 -0
  109. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/metadata/value_validation.py +0 -0
  110. {arize-8.0.0a8/src/arize/utils → arize-8.0.0a10/src/arize/spans/validation/spans}/__init__.py +0 -0
  111. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/spans/dataframe_form_validation.py +0 -0
  112. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/spans/spans_validation.py +0 -0
  113. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/spans/validation/spans/value_validation.py +0 -0
  114. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/types.py +0 -0
  115. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/utils/casting.py +0 -0
  116. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/utils/dataframe.py +0 -0
  117. {arize-8.0.0a8 → arize-8.0.0a10}/src/arize/utils/proto.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arize
3
- Version: 8.0.0a8
3
+ Version: 8.0.0a10
4
4
  Summary: A helper library to interact with Arize AI APIs
5
5
  Project-URL: Homepage, https://arize.com
6
6
  Project-URL: Documentation, https://docs.arize.com/arize
@@ -30,10 +30,13 @@ Requires-Dist: numpy>=2.0.0
30
30
  Provides-Extra: dev
31
31
  Requires-Dist: pytest==8.4.2; extra == 'dev'
32
32
  Requires-Dist: ruff==0.13.2; extra == 'dev'
33
+ Provides-Extra: mimic-explainer
34
+ Requires-Dist: interpret-community[mimic]<1,>=0.22.0; extra == 'mimic-explainer'
33
35
  Provides-Extra: ml-batch
34
36
  Requires-Dist: pandas<3,>=1.0.0; extra == 'ml-batch'
35
37
  Requires-Dist: protobuf<6,>=4.21.0; extra == 'ml-batch'
36
38
  Requires-Dist: pyarrow>=0.15.0; extra == 'ml-batch'
39
+ Requires-Dist: requests<3,>=2.0.0; extra == 'ml-batch'
37
40
  Requires-Dist: tqdm; extra == 'ml-batch'
38
41
  Provides-Extra: ml-stream
39
42
  Requires-Dist: protobuf<6,>=4.21.0; extra == 'ml-stream'
@@ -44,6 +47,7 @@ Requires-Dist: opentelemetry-semantic-conventions<1,>=0.43b0; extra == 'spans'
44
47
  Requires-Dist: pandas<3,>=1.0.0; extra == 'spans'
45
48
  Requires-Dist: protobuf<6,>=4.21.0; extra == 'spans'
46
49
  Requires-Dist: pyarrow>=0.15.0; extra == 'spans'
50
+ Requires-Dist: requests<3,>=2.0.0; extra == 'spans'
47
51
  Requires-Dist: tqdm; extra == 'spans'
48
52
  Description-Content-Type: text/markdown
49
53
 
@@ -34,8 +34,8 @@ classifiers = [
34
34
  "Topic :: System :: Monitoring",
35
35
  ]
36
36
  dependencies = [
37
+ "numpy>=2.0.0", # For vector embeddings
37
38
  "lazy-imports",
38
- "numpy>=2.0.0",
39
39
  # "requests_futures==1.0.0",
40
40
  # "googleapis_common_protos>=1.51.0,<2",
41
41
  # "protobuf>=4.21.0,<6",
@@ -57,7 +57,8 @@ spans = [
57
57
  "pandas>=1.0.0,<3",
58
58
  "protobuf>=4.21.0,<6",
59
59
  "pyarrow>=0.15.0",
60
- "tqdm",
60
+ "requests>=2.0.0, <3", # For posting pyarrow files
61
+ "tqdm", # For export progress bars
61
62
  ]
62
63
  ml-stream = [
63
64
  "requests_futures>=1.0.0, <2",
@@ -65,13 +66,17 @@ ml-stream = [
65
66
  ]
66
67
  ml-batch = [
67
68
  "pandas>=1.0.0,<3",
68
- "pyarrow>=0.15.0",
69
69
  "protobuf>=4.21.0,<6",
70
- "tqdm",
70
+ "pyarrow>=0.15.0",
71
+ "requests>=2.0.0, <3", # For posting pyarrow files
72
+ "tqdm", # For export progress bars
71
73
  ]
72
74
  # datasets-experiments = [
73
75
  # "pydantic",
74
76
  # ]
77
+ mimic-explainer = [
78
+ "interpret-community[mimic]>=0.22.0,<1",
79
+ ]
75
80
 
76
81
  [project.urls]
77
82
  Homepage = "https://arize.com"
@@ -65,6 +65,7 @@ class ArizeClient(LazySubclientsMixin):
65
65
  "opentelemetry",
66
66
  "pandas",
67
67
  "pyarrow",
68
+ "requests",
68
69
  "tqdm",
69
70
  ),
70
71
  ),
@@ -52,7 +52,6 @@ from arize.types import (
52
52
  is_list_of,
53
53
  )
54
54
  from arize.utils.casting import cast_dictionary, cast_typed_columns
55
- from arize.utils.dataframe import remove_extraneous_columns
56
55
 
57
56
  if TYPE_CHECKING:
58
57
  import concurrent.futures as cf
@@ -75,11 +74,17 @@ _STREAM_EXTRA = "ml-stream"
75
74
 
76
75
  _BATCH_DEPS = (
77
76
  "pandas",
78
- "pyarrow",
79
77
  "google.protobuf",
78
+ "pyarrow",
79
+ "requests",
80
80
  "tqdm",
81
81
  )
82
82
  _BATCH_EXTRA = "ml-batch"
83
+ _MIMIC_DEPS = (
84
+ "interpret_community.mimic",
85
+ "sklearn.preprocessing",
86
+ )
87
+ _MIMIC_EXTRA = "mimic-explainer"
83
88
 
84
89
 
85
90
  class MLModelsClient:
@@ -116,7 +121,6 @@ class MLModelsClient:
116
121
  timeout: float | None = None,
117
122
  ) -> cf.Future:
118
123
  require(_STREAM_EXTRA, _STREAM_DEPS)
119
-
120
124
  from arize._generated.protocol.rec import public_pb2 as pb2
121
125
  from arize.utils.proto import (
122
126
  get_pb_dictionary,
@@ -464,6 +468,7 @@ class MLModelsClient:
464
468
 
465
469
  from arize.models.batch_validation.validator import Validator
466
470
  from arize.utils.arrow import post_arrow_table
471
+ from arize.utils.dataframe import remove_extraneous_columns
467
472
  from arize.utils.proto import get_pb_schema, get_pb_schema_corpus
468
473
 
469
474
  # This method requires a space_id and project_name
@@ -544,17 +549,10 @@ class MLModelsClient:
544
549
  dataframe = dataframe.astype(cat_str_map)
545
550
 
546
551
  if surrogate_explainability:
547
- logger.debug("Running surrogate_explainability.")
552
+ require(_MIMIC_EXTRA, _MIMIC_DEPS)
553
+ from arize.models.surrogate_explainer.mimic import Mimic
548
554
 
549
- try:
550
- # WARNING: MIMIC EXPLAINER IS NOT DONE
551
- from arize.pandas.surrogate_explainer.mimic import Mimic
552
- except ImportError:
553
- raise ImportError(
554
- "To enable surrogate explainability, "
555
- "the arize module must be installed with the MimicExplainer option: pip "
556
- "install 'arize[MimicExplainer]'."
557
- ) from None
555
+ logger.debug("Running surrogate_explainability.")
558
556
  if schema.shap_values_column_names:
559
557
  logger.info(
560
558
  "surrogate_explainability=True has no effect "
@@ -0,0 +1,164 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ import string
5
+ from dataclasses import replace
6
+ from typing import TYPE_CHECKING, Callable, Tuple
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from interpret_community.mimic.mimic_explainer import (
11
+ LGBMExplainableModel,
12
+ MimicExplainer,
13
+ )
14
+ from sklearn.preprocessing import LabelEncoder
15
+
16
+ from arize.types import (
17
+ CATEGORICAL_MODEL_TYPES,
18
+ NUMERIC_MODEL_TYPES,
19
+ ModelTypes,
20
+ )
21
+
22
+ if TYPE_CHECKING:
23
+ from arize.types import Schema
24
+
25
+
26
+ class Mimic:
27
+ _testing = False
28
+
29
+ def __init__(self, X: pd.DataFrame, model_func: Callable):
30
+ self.explainer = MimicExplainer(
31
+ model_func,
32
+ X,
33
+ LGBMExplainableModel,
34
+ augment_data=False,
35
+ is_function=True,
36
+ )
37
+
38
+ def explain(self, X: pd.DataFrame) -> pd.DataFrame:
39
+ return pd.DataFrame(
40
+ self.explainer.explain_local(X).local_importance_values,
41
+ columns=X.columns,
42
+ index=X.index,
43
+ )
44
+
45
+ @staticmethod
46
+ def augment(
47
+ df: pd.DataFrame, schema: Schema, model_type: ModelTypes
48
+ ) -> Tuple[pd.DataFrame, Schema]:
49
+ features = schema.feature_column_names
50
+ X = df[features]
51
+
52
+ if X.shape[1] == 0:
53
+ return df, schema
54
+
55
+ if model_type in CATEGORICAL_MODEL_TYPES:
56
+ if not schema.prediction_score_column_name:
57
+ raise ValueError(
58
+ "To calculate surrogate explainability, "
59
+ f"prediction_score_column_name must be specified in schema for {model_type}."
60
+ )
61
+
62
+ y_col_name = schema.prediction_score_column_name
63
+ y = df[y_col_name].to_numpy()
64
+
65
+ _min, _max = np.min(y), np.max(y)
66
+ if not 0 <= _min <= 1 or not 0 <= _max <= 1:
67
+ raise ValueError(
68
+ f"To calculate surrogate explainability for {model_type}, "
69
+ f"prediction scores must be between 0 and 1, but current "
70
+ f"prediction scores range from {_min} to {_max}."
71
+ )
72
+
73
+ # model func requires 1 positional argument
74
+ def model_func(_): # type: ignore
75
+ return np.column_stack((1 - y, y))
76
+
77
+ elif model_type in NUMERIC_MODEL_TYPES:
78
+ y_col_name = schema.prediction_label_column_name
79
+ if schema.prediction_score_column_name is not None:
80
+ y_col_name = schema.prediction_score_column_name
81
+ y = df[y_col_name].to_numpy()
82
+
83
+ _finite_count = np.isfinite(y).sum()
84
+ if len(y) - _finite_count:
85
+ raise ValueError(
86
+ f"To calculate surrogate explainability for {model_type}, "
87
+ f"predictions must not contain NaN or infinite values, but "
88
+ f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name}."
89
+ )
90
+
91
+ # model func requires 1 positional argument
92
+ def model_func(_): # type: ignore
93
+ return y
94
+
95
+ else:
96
+ raise ValueError(
97
+ "Surrogate explainability is not supported for the specified "
98
+ f"model type {model_type}."
99
+ )
100
+
101
+ # Column name mapping between features and feature importance values.
102
+ # This is used to augment the schema.
103
+ col_map = {
104
+ ft: f"{''.join(random.choices(string.ascii_letters, k=8))}"
105
+ for ft in features
106
+ }
107
+ aug_schema = replace(schema, shap_values_column_names=col_map)
108
+
109
+ # Limit the total number of "cells" to 20M, unless it results in too few or
110
+ # too many rows. This is done to keep the runtime low. Records not sampled
111
+ # have feature importance values set to 0.
112
+ samp_size = min(
113
+ len(X), min(100_000, max(1_000, 20_000_000 // X.shape[1]))
114
+ )
115
+
116
+ if samp_size < len(X):
117
+ _mask = np.zeros(len(X), dtype=int)
118
+ _mask[:samp_size] = 1
119
+ np.random.shuffle(_mask)
120
+ _mask = _mask.astype(bool)
121
+ X = X[_mask]
122
+ y = y[_mask]
123
+
124
+ # Replace all pd.NA values with np.nan values
125
+ for col in X.columns:
126
+ if X[col].isna().any():
127
+ X[col] = X[col].astype(object).where(~X[col].isna(), np.nan)
128
+
129
+ # Apply integer encoding to non-numeric columns.
130
+ # Currently training and explaining datasets are the same, but
131
+ # this can be changed in the future. The student model can be
132
+ # fitted on a much larger dataset since it takes a lot less time.
133
+ X = pd.concat(
134
+ [
135
+ X.select_dtypes(exclude=[object, "string"]),
136
+ pd.DataFrame(
137
+ {
138
+ name: LabelEncoder().fit_transform(data)
139
+ for name, data in X.select_dtypes(
140
+ include=[object, "string"]
141
+ ).items()
142
+ },
143
+ index=X.index,
144
+ ),
145
+ ],
146
+ axis=1,
147
+ )
148
+
149
+ aug_df = pd.concat(
150
+ [
151
+ df,
152
+ Mimic(X, model_func).explain(X).rename(col_map, axis=1),
153
+ ],
154
+ axis=1,
155
+ )
156
+
157
+ # Fill null with zero so they're not counted as missing records by server
158
+ if not Mimic._testing:
159
+ aug_df.fillna({c: 0 for c in col_map.values()}, inplace=True)
160
+
161
+ return (
162
+ aug_df,
163
+ aug_schema,
164
+ )
File without changes
@@ -8,11 +8,12 @@ import tempfile
8
8
  from typing import TYPE_CHECKING, Any, Dict
9
9
 
10
10
  import pyarrow as pa
11
- import requests
12
11
 
13
12
  from arize.logging import get_arize_project_url, log_a_list
14
13
 
15
14
  if TYPE_CHECKING:
15
+ import requests
16
+
16
17
  from arize._generated.protocol.rec import public_pb2 as pb2
17
18
 
18
19
  logger = logging.getLogger(__name__)
@@ -27,6 +28,9 @@ def post_arrow_table(
27
28
  verify: bool,
28
29
  tmp_dir: str = "",
29
30
  ) -> requests.Response:
31
+ # We import here to avoid depending on requests for all arrow utils
32
+ import requests
33
+
30
34
  logger.debug("Preparing to log Arrow table via file upload")
31
35
  logger.debug(
32
36
  "Preparing to log Arrow table via file upload",
@@ -71,15 +75,17 @@ def post_arrow_table(
71
75
  "Uploading file to Arize",
72
76
  extra={"path": outfile, "size_bytes": _filesize(outfile)},
73
77
  )
74
- resp = _post_file(
75
- files_url=files_url,
76
- path=outfile,
77
- headers=headers,
78
- timeout=timeout,
79
- verify=verify,
80
- )
81
- _maybe_log_project_url(resp)
82
- return resp
78
+ # Post file
79
+ with open(outfile, "rb") as f:
80
+ resp = requests.post(
81
+ files_url,
82
+ timeout=timeout,
83
+ data=f,
84
+ headers=headers,
85
+ verify=verify,
86
+ )
87
+ _maybe_log_project_url(resp)
88
+ return resp
83
89
  finally:
84
90
  if tdir is not None:
85
91
  try:
@@ -98,23 +104,6 @@ def post_arrow_table(
98
104
  )
99
105
 
100
106
 
101
- def _post_file(
102
- files_url: str,
103
- path: str,
104
- headers: Dict[str, str],
105
- timeout: float | None,
106
- verify: bool,
107
- ) -> requests.Response:
108
- with open(path, "rb") as f:
109
- return requests.post(
110
- files_url,
111
- timeout=timeout,
112
- data=f,
113
- headers=headers,
114
- verify=verify,
115
- )
116
-
117
-
118
107
  def append_to_pyarrow_metadata(
119
108
  pa_schema: pa.Schema, new_metadata: Dict[str, Any]
120
109
  ):
@@ -0,0 +1 @@
1
+ __version__ = "8.0.0a10"
@@ -1 +0,0 @@
1
- __version__ = "8.0.0a8"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes