arize 8.0.0a13__tar.gz → 8.0.0a15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {arize-8.0.0a13 → arize-8.0.0a15}/PKG-INFO +11 -3
  2. {arize-8.0.0a13 → arize-8.0.0a15}/README.md +10 -2
  3. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_exporter/client.py +18 -3
  4. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_flight/client.py +6 -2
  5. arize-8.0.0a15/src/arize/datasets/client.py +142 -0
  6. {arize-8.0.0a13/src/arize/utils → arize-8.0.0a15/src/arize/models}/casting.py +12 -12
  7. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/models/client.py +330 -5
  8. {arize-8.0.0a13/src/arize/utils → arize-8.0.0a15/src/arize/models}/proto.py +1 -369
  9. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/client.py +30 -6
  10. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/utils/arrow.py +4 -4
  11. arize-8.0.0a15/src/arize/version.py +1 -0
  12. arize-8.0.0a13/src/arize/datasets/client.py +0 -137
  13. arize-8.0.0a13/src/arize/version.py +0 -1
  14. {arize-8.0.0a13 → arize-8.0.0a15}/.gitignore +0 -0
  15. {arize-8.0.0a13 → arize-8.0.0a15}/LICENSE.md +0 -0
  16. {arize-8.0.0a13 → arize-8.0.0a15}/pyproject.toml +0 -0
  17. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/__init__.py +0 -0
  18. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_exporter/__init__.py +0 -0
  19. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_exporter/parsers/__init__.py +0 -0
  20. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_exporter/parsers/tracing_data_parser.py +0 -0
  21. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_exporter/validation.py +0 -0
  22. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_flight/__init__.py +0 -0
  23. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_flight/types.py +0 -0
  24. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/__init__.py +0 -0
  25. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/__init__.py +0 -0
  26. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/api/__init__.py +0 -0
  27. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/api/datasets_api.py +0 -0
  28. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/api/experiments_api.py +0 -0
  29. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/api_client.py +0 -0
  30. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/api_response.py +0 -0
  31. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/configuration.py +0 -0
  32. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/exceptions.py +0 -0
  33. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/models/__init__.py +0 -0
  34. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/models/dataset.py +0 -0
  35. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/models/dataset_version.py +0 -0
  36. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/models/datasets_create_request.py +0 -0
  37. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/models/datasets_list200_response.py +0 -0
  38. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/models/datasets_list_examples200_response.py +0 -0
  39. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/models/error.py +0 -0
  40. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/models/experiment.py +0 -0
  41. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/models/experiments_list200_response.py +0 -0
  42. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/rest.py +0 -0
  43. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/__init__.py +0 -0
  44. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_dataset.py +0 -0
  45. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_dataset_version.py +0 -0
  46. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_datasets_api.py +0 -0
  47. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_datasets_create_request.py +0 -0
  48. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_datasets_list200_response.py +0 -0
  49. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_datasets_list_examples200_response.py +0 -0
  50. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_error.py +0 -0
  51. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_experiment.py +0 -0
  52. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_experiments_api.py +0 -0
  53. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client/test/test_experiments_list200_response.py +0 -0
  54. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/api_client_README.md +0 -0
  55. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/protocol/__init__.py +0 -0
  56. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/protocol/flight/__init__.py +0 -0
  57. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/protocol/flight/export_pb2.py +0 -0
  58. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/protocol/flight/ingest_pb2.py +0 -0
  59. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/protocol/rec/__init__.py +0 -0
  60. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_generated/protocol/rec/public_pb2.py +0 -0
  61. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/_lazy.py +0 -0
  62. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/client.py +0 -0
  63. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/config.py +0 -0
  64. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/constants/__init__.py +0 -0
  65. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/constants/config.py +0 -0
  66. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/constants/ml.py +0 -0
  67. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/constants/model_mapping.json +0 -0
  68. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/constants/spans.py +0 -0
  69. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/datasets/__init__.py +0 -0
  70. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/embeddings/__init__.py +0 -0
  71. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/embeddings/auto_generator.py +0 -0
  72. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/embeddings/base_generators.py +0 -0
  73. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/embeddings/constants.py +0 -0
  74. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/embeddings/cv_generators.py +0 -0
  75. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/embeddings/errors.py +0 -0
  76. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/embeddings/nlp_generators.py +0 -0
  77. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/embeddings/tabular_generators.py +0 -0
  78. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/embeddings/usecases.py +0 -0
  79. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/exceptions/__init__.py +0 -0
  80. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/exceptions/auth.py +0 -0
  81. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/exceptions/base.py +0 -0
  82. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/exceptions/models.py +0 -0
  83. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/exceptions/parameters.py +0 -0
  84. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/exceptions/spaces.py +0 -0
  85. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/exceptions/types.py +0 -0
  86. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/exceptions/values.py +0 -0
  87. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/experiments/__init__.py +0 -0
  88. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/experiments/client.py +0 -0
  89. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/logging.py +0 -0
  90. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/models/__init__.py +0 -0
  91. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/models/batch_validation/__init__.py +0 -0
  92. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/models/batch_validation/errors.py +0 -0
  93. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/models/batch_validation/validator.py +0 -0
  94. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/models/bounded_executor.py +0 -0
  95. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/models/stream_validation.py +0 -0
  96. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/models/surrogate_explainer/__init__.py +0 -0
  97. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/models/surrogate_explainer/mimic.py +0 -0
  98. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/__init__.py +0 -0
  99. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/columns.py +0 -0
  100. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/conversion.py +0 -0
  101. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/__init__.py +0 -0
  102. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/annotations/__init__.py +0 -0
  103. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/annotations/annotations_validation.py +0 -0
  104. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/annotations/dataframe_form_validation.py +0 -0
  105. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/annotations/value_validation.py +0 -0
  106. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/common/__init__.py +0 -0
  107. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/common/argument_validation.py +0 -0
  108. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/common/dataframe_form_validation.py +0 -0
  109. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/common/errors.py +0 -0
  110. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/common/value_validation.py +0 -0
  111. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/evals/__init__.py +0 -0
  112. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/evals/dataframe_form_validation.py +0 -0
  113. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/evals/evals_validation.py +0 -0
  114. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/evals/value_validation.py +0 -0
  115. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/metadata/__init__.py +0 -0
  116. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/metadata/argument_validation.py +0 -0
  117. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/metadata/dataframe_form_validation.py +0 -0
  118. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/metadata/value_validation.py +0 -0
  119. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/spans/__init__.py +0 -0
  120. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/spans/dataframe_form_validation.py +0 -0
  121. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/spans/spans_validation.py +0 -0
  122. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/spans/validation/spans/value_validation.py +0 -0
  123. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/types.py +0 -0
  124. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/utils/__init__.py +0 -0
  125. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/utils/dataframe.py +0 -0
  126. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/utils/online_tasks/__init__.py +0 -0
  127. {arize-8.0.0a13 → arize-8.0.0a15}/src/arize/utils/online_tasks/dataframe_preprocessor.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: arize
- Version: 8.0.0a13
+ Version: 8.0.0a15
  Summary: A helper library to interact with Arize AI APIs
  Project-URL: Homepage, https://arize.com
  Project-URL: Documentation, https://docs.arize.com/arize
@@ -424,7 +424,7 @@ examples = [
  ]
  ```

- If the number of examples (rows in dataframe, items in list) is too large, the client SDK will try to send the data via Arrow Flight via gRPC for better performance. If you want to force the data transfer to HTTP you can use the `force_http` flag. The response is always a `Dataset` object.
+ If the number of examples (rows in dataframe, items in list) is too large, the client SDK will try to send the data via Arrow Flight via gRPC for better performance. If you want to force the data transfer to HTTP you can use the `force_http` flag. The response is a `Dataset` object.

  ```python
  created_dataset = client.datasets.create(
@@ -434,11 +434,19 @@ created_dataset = client.datasets.create(
  )
  ```

+ The `Dataset` object also counts with convenience method similar to `List***` objects:
+
+ ```python
+ # Get the response as a dictionary
+ dataset_dict = create_dataset.to_dict()
+ # Get the response in JSON format
+ dataset_dict = create_dataset.to_json()
+ ```


  ### Get Dataset by ID

- To get a dataset by its ID use `client.datasets.get()`, you can optionally also pass the version ID of a particular version of interest of the dataset.
+ To get a dataset by its ID use `client.datasets.get()`, you can optionally also pass the version ID of a particular version of interest of the dataset. The returned type is `Dataset`.

  ```python
  dataset = client.datasets.get(
@@ -362,7 +362,7 @@ examples = [
  ]
  ```

- If the number of examples (rows in dataframe, items in list) is too large, the client SDK will try to send the data via Arrow Flight via gRPC for better performance. If you want to force the data transfer to HTTP you can use the `force_http` flag. The response is always a `Dataset` object.
+ If the number of examples (rows in dataframe, items in list) is too large, the client SDK will try to send the data via Arrow Flight via gRPC for better performance. If you want to force the data transfer to HTTP you can use the `force_http` flag. The response is a `Dataset` object.

  ```python
  created_dataset = client.datasets.create(
@@ -372,11 +372,19 @@ created_dataset = client.datasets.create(
  )
  ```

+ The `Dataset` object also counts with convenience method similar to `List***` objects:
+
+ ```python
+ # Get the response as a dictionary
+ dataset_dict = create_dataset.to_dict()
+ # Get the response in JSON format
+ dataset_dict = create_dataset.to_json()
+ ```


  ### Get Dataset by ID

- To get a dataset by its ID use `client.datasets.get()`, you can optionally also pass the version ID of a particular version of interest of the dataset.
+ To get a dataset by its ID use `client.datasets.get()`, you can optionally also pass the version ID of a particular version of interest of the dataset. The returned type is `Dataset`.

  ```python
  dataset = client.datasets.get(
@@ -20,7 +20,6 @@ from arize._generated.protocol.flight import export_pb2
  from arize.logging import CtxAdapter
  from arize.types import Environments, SimilaritySearchParams
  from arize.utils.dataframe import reset_dataframe_index
- from arize.utils.proto import get_pb_flight_doput_request

  logger = logging.getLogger(__name__)

@@ -131,7 +130,7 @@ class ArizeExportClient:
  reset_dataframe_index(df)
  return df

- def export_model_to_parquet(
+ def export_to_parquet(
  self,
  path: str,
  space_id: str,
@@ -285,7 +284,7 @@ class ArizeExportClient:
  end_time=Timestamp(seconds=int(end_time.timestamp())),
  filter_expression=where,
  similarity_search_params=(
- get_pb_flight_doput_request(similarity_search_params)
+ _get_pb_similarity_search_params(similarity_search_params)
  if similarity_search_params
  else None
  ),
@@ -326,3 +325,19 @@ class ArizeExportClient:
  colour="#008000",
  unit=" row",
  )
+
+
+ def _get_pb_similarity_search_params(
+ similarity_params: SimilaritySearchParams,
+ ) -> export_pb2.SimilaritySearchParams:
+ proto_params = export_pb2.SimilaritySearchParams()
+ proto_params.search_column_name = similarity_params.search_column_name
+ proto_params.threshold = similarity_params.threshold
+ for ref in similarity_params.references:
+ new_ref = proto_params.references.add()
+ new_ref.prediction_id = ref.prediction_id
+ new_ref.reference_column_name = ref.reference_column_name
+ if ref.prediction_timestamp:
+ new_ref.prediction_timestamp.FromDatetime(ref.prediction_timestamp)
+
+ return proto_params
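The new module-level `_get_pb_similarity_search_params` helper replaces the removed `arize.utils.proto.get_pb_flight_doput_request` import and leans on two standard protobuf idioms that also appear in the export hunk above: constructing a `Timestamp` from epoch seconds, and filling one from a Python datetime via `FromDatetime()`. A minimal standalone sketch of those idioms (the datetime value is illustrative, not part of the SDK):

```python
from datetime import datetime, timezone

from google.protobuf.timestamp_pb2 import Timestamp

# Construction from epoch seconds, as in Timestamp(seconds=int(end_time.timestamp())):
end_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
ts = Timestamp(seconds=int(end_time.timestamp()))

# Conversion from a datetime, as in
# new_ref.prediction_timestamp.FromDatetime(ref.prediction_timestamp):
ts2 = Timestamp()
ts2.FromDatetime(end_time)

assert ts == ts2
print(ts.ToJsonString())  # 2024-01-01T00:00:00Z
```

Repeated message fields such as `references` are grown in place with `.add()`, which returns a new sub-message whose scalar fields can then be assigned directly, as the helper above does.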
@@ -179,7 +179,9 @@ class ArizeFlightClient:
  return res
  except Exception as e:
  logger.exception(f"Error logging arrow table to Arize: {e}")
- raise RuntimeError(f"Error logging arrow table to Arize: {e}") from e
+ raise RuntimeError(
+ f"Error logging arrow table to Arize: {e}"
+ ) from e

  # ---------- dataset methods ----------

@@ -221,7 +223,9 @@ class ArizeFlightClient:
  return res
  except Exception as e:
  logger.exception(f"Error logging arrow table to Arize: {e}")
- raise RuntimeError(f"Error logging arrow table to Arize: {e}") from e
+ raise RuntimeError(
+ f"Error logging arrow table to Arize: {e}"
+ ) from e


  def append_to_pyarrow_metadata(
@@ -0,0 +1,142 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import Any, Dict, List
+
+ import pandas as pd
+ import pyarrow as pa
+
+ from arize._flight.client import ArizeFlightClient
+ from arize.config import SDKConfiguration
+ from arize.exceptions.base import INVALID_ARROW_CONVERSION_MSG
+
+ logger = logging.getLogger(__name__)
+
+ REST_LIMIT_DATASET_EXAMPLES = 3
+
+
+ class DatasetsClient:
+ def __init__(self, sdk_config: SDKConfiguration):
+ self._sdk_config = sdk_config
+
+ # Import at runtime so it’s still lazy and extras-gated by the parent
+ from arize._generated import api_client as gen
+
+ # Use the shared generated client from the config
+ self._api = gen.DatasetsApi(self._sdk_config.get_generated_client())
+
+ # Forward methods to preserve exact runtime signatures/docs
+ self.list = self._api.datasets_list
+ self.get = self._api.datasets_get
+ self.delete = self._api.datasets_delete
+ self.list_examples = self._api.datasets_list_examples
+
+ # Custom methods
+ self.create = self._create_dataset
+
+ def _create_dataset(
+ self,
+ name: str,
+ space_id: str,
+ examples: List[Dict[str, Any]] | pd.DataFrame,
+ force_http: bool = False,
+ ):
+ if not isinstance(examples, (list, pd.DataFrame)):
+ raise TypeError(
+ "Examples must be a list of dicts or a pandas DataFrame"
+ )
+ if len(examples) <= REST_LIMIT_DATASET_EXAMPLES or force_http:
+ from arize._generated import api_client as gen
+
+ data = (
+ examples.to_dict(orient="records")
+ if isinstance(examples, pd.DataFrame)
+ else examples
+ )
+
+ body = gen.DatasetsCreateRequest(
+ name=name,
+ spaceId=space_id,
+ examples=data,
+ )
+ return self._api.datasets_create(datasets_create_request=body)
+
+ # If we have too many examples, try to convert to a dataframe
+ # and log via gRPC + flight
+ logger.info(
+ f"Uploading {len(examples)} examples via REST may be slow. "
+ "Trying to convert to DataFrame for more efficient upload via "
+ "gRPC + Flight."
+ )
+ data = (
+ pd.DataFrame(examples) if isinstance(examples, list) else examples
+ )
+ return self._create_dataset_via_flight(
+ name=name,
+ space_id=space_id,
+ examples=data,
+ )
+
+ def _create_dataset_via_flight(
+ self,
+ name: str,
+ space_id: str,
+ examples: pd.DataFrame,
+ ):
+ # Convert datetime columns to int64 (ms since epoch)
+ # TODO(Kiko): Missing validation block
+ # data = _convert_datetime_columns_to_int(data)
+ # df = self._set_default_columns_for_dataset(data)
+ # if convert_dict_to_json:
+ # df = _convert_default_columns_to_json_str(df)
+ # df = _convert_boolean_columns_to_str(df)
+ # validation_errors = Validator.validate(df)
+ # validation_errors.extend(
+ # Validator.validate_max_chunk_size(max_chunk_size)
+ # )
+ # if validation_errors:
+ # raise RuntimeError(
+ # [e.error_message() for e in validation_errors]
+ # )
+
+ # Convert to Arrow table
+ try:
+ logger.debug("Converting data to Arrow format")
+ pa_table = pa.Table.from_pandas(examples)
+ except pa.ArrowInvalid as e:
+ logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
+ raise pa.ArrowInvalid(
+ f"Error converting to Arrow format: {str(e)}"
+ ) from e
+ except Exception as e:
+ logger.error(f"Unexpected error creating Arrow table: {str(e)}")
+ raise
+
+ response = None
+ with ArizeFlightClient(
+ api_key=self._sdk_config.api_key,
+ host=self._sdk_config.flight_server_host,
+ port=self._sdk_config.flight_server_port,
+ scheme=self._sdk_config.flight_scheme,
+ request_verify=self._sdk_config.request_verify,
+ ) as flight_client:
+ try:
+ response = flight_client.create_dataset(
+ space_id=space_id,
+ dataset_name=name,
+ pa_table=pa_table,
+ )
+ except Exception as e:
+ msg = f"Error during update request: {str(e)}"
+ logger.error(msg)
+ raise RuntimeError(msg) from e
+ if response is None:
+ # This should not happen with proper Flight client implementation,
+ # but we handle it defensively
+ msg = "No response received from flight server during update"
+ logger.error(msg)
+ raise RuntimeError(msg)
+ # The response from flightserver is the dataset ID. To return the dataset
+ # object we make a GET query
+ dataset = self.get(dataset_id=response)
+ return dataset
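Taken together with the README hunks earlier in this diff, the rewritten `DatasetsClient.create` picks its transport from the payload size. A minimal usage sketch, assuming `client` is an already-configured top-level Arize client that exposes this class as `client.datasets` (as in the README snippets), and that the returned `Dataset` exposes the `to_dict()` helper those hunks describe; names and example rows are illustrative:

```python
import pandas as pd

# Example rows to upload; any list of dicts or DataFrame works.
examples = pd.DataFrame(
    [
        {"input": "What is Arize?", "output": "An ML observability platform."},
        {"input": "What is a dataset?", "output": "A versioned collection of examples."},
    ]
)

# Small payloads (or force_http=True) go through the generated REST API;
# more than REST_LIMIT_DATASET_EXAMPLES rows are converted to a DataFrame
# and uploaded as an Arrow table over the Flight/gRPC path shown above.
created_dataset = client.datasets.create(
    name="my-dataset",
    space_id="my-space-id",
    examples=examples,
    force_http=False,
)
print(created_dataset.to_dict())
```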
@@ -131,28 +131,28 @@ def cast_typed_columns(
  f = getattr(schema, field_name)
  if f:
  try:
- validate_typed_columns(field_name, f)
+ _validate_typed_columns(field_name, f)
  except InvalidTypedColumnsError:
  raise
- dataframe = cast_columns(dataframe, f)
+ dataframe = _cast_columns(dataframe, f)

  # Now that the dataframe values have been cast to the specified types:
  # for downstream validation to work as expected,
  # feature & tag schema field types should be List[string] of column names.
  # Since Schema is a frozen class, we must construct a new instance.
- return dataframe, convert_schema_field_types(schema)
+ return dataframe, _convert_schema_field_types(schema)


  def cast_dictionary(d: dict) -> dict:
  cast_dict = {}
  for k, v in d.items():
  if isinstance(v, TypedValue):
- v = cast_value(v)
+ v = _cast_value(v)
  cast_dict[k] = v
  return cast_dict


- def cast_value(
+ def _cast_value(
  typed_value: TypedValue,
  ) -> Union[str, int, float, List[str], None]:
  """
@@ -224,7 +224,7 @@ def _cast_to_str(typed_value: TypedValue) -> Union[str, None]:
  raise CastingError(str(e), typed_value) from e


- def validate_typed_columns(
+ def _validate_typed_columns(
  field_name: str, typed_columns: TypedColumns
  ) -> None:
  """
@@ -253,7 +253,7 @@
  )


- def cast_columns(
+ def _cast_columns(
  dataframe: pd.DataFrame, columns: TypedColumns
  ) -> pd.DataFrame:
  """
@@ -288,7 +288,7 @@ def cast_columns(
  # uses pd.NA for missing values (when storage arg is not configured)
  # In the future, try out pd.convert_dtypes (new in pandas 2.0):
  # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.convert_dtypes.html
- dataframe = cast_df(dataframe, columns.to_str, "string")
+ dataframe = _cast_df(dataframe, columns.to_str, "string")
  except Exception as e:
  raise ColumnCastingError(
  error_msg=str(e),
@@ -300,7 +300,7 @@ def cast_columns(
  # see https://pandas.pydata.org/docs/reference/api/pandas.Int64Dtype.html
  # uses pd.NA for missing values
  try:
- dataframe = cast_df(dataframe, columns.to_int, "Int64")
+ dataframe = _cast_df(dataframe, columns.to_int, "Int64")
  except Exception as e:
  raise ColumnCastingError(
  error_msg=str(e),
@@ -312,7 +312,7 @@ def cast_columns(
  # see https://pandas.pydata.org/docs/reference/api/pandas.Float64Dtype.html
  # uses pd.NA for missing values
  try:
- dataframe = cast_df(dataframe, columns.to_float, "Float64")
+ dataframe = _cast_df(dataframe, columns.to_float, "Float64")
  except Exception as e:
  raise ColumnCastingError(
  error_msg=str(e),
@@ -323,7 +323,7 @@ def cast_columns(
  return dataframe


- def cast_df(
+ def _cast_df(
  df: pd.DataFrame, cols: List[str], target_type_str: str
  ) -> pd.DataFrame:
  """
@@ -354,7 +354,7 @@ def cast_df(
  return df.astype({col: target_type_str for col in cols})


- def convert_schema_field_types(
+ def _convert_schema_field_types(
  schema: Schema,
  ) -> Schema:
  """