arize 8.0.0a10__tar.gz → 8.0.0a12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128) hide show
  1. {arize-8.0.0a10 → arize-8.0.0a12}/PKG-INFO +75 -1
  2. {arize-8.0.0a10 → arize-8.0.0a12}/README.md +67 -0
  3. {arize-8.0.0a10 → arize-8.0.0a12}/pyproject.toml +8 -1
  4. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/client.py +3 -1
  5. arize-8.0.0a12/src/arize/embeddings/__init__.py +4 -0
  6. arize-8.0.0a12/src/arize/embeddings/auto_generator.py +108 -0
  7. arize-8.0.0a12/src/arize/embeddings/base_generators.py +255 -0
  8. arize-8.0.0a12/src/arize/embeddings/constants.py +34 -0
  9. arize-8.0.0a12/src/arize/embeddings/cv_generators.py +28 -0
  10. arize-8.0.0a12/src/arize/embeddings/errors.py +41 -0
  11. arize-8.0.0a12/src/arize/embeddings/nlp_generators.py +111 -0
  12. arize-8.0.0a12/src/arize/embeddings/tabular_generators.py +161 -0
  13. arize-8.0.0a12/src/arize/embeddings/usecases.py +26 -0
  14. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/logging.py +0 -5
  15. arize-8.0.0a12/src/arize/utils/online_tasks/__init__.py +5 -0
  16. arize-8.0.0a12/src/arize/utils/online_tasks/dataframe_preprocessor.py +235 -0
  17. arize-8.0.0a12/src/arize/version.py +1 -0
  18. arize-8.0.0a10/src/arize/version.py +0 -1
  19. {arize-8.0.0a10 → arize-8.0.0a12}/.gitignore +0 -0
  20. {arize-8.0.0a10 → arize-8.0.0a12}/LICENSE.md +0 -0
  21. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/__init__.py +0 -0
  22. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_exporter/__init__.py +0 -0
  23. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_exporter/client.py +0 -0
  24. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_exporter/parsers/__init__.py +0 -0
  25. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_exporter/parsers/tracing_data_parser.py +0 -0
  26. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_exporter/validation.py +0 -0
  27. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_flight/__init__.py +0 -0
  28. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_flight/client.py +0 -0
  29. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_flight/types.py +0 -0
  30. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/__init__.py +0 -0
  31. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/__init__.py +0 -0
  32. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/api/__init__.py +0 -0
  33. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/api/datasets_api.py +0 -0
  34. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/api/experiments_api.py +0 -0
  35. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/api_client.py +0 -0
  36. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/api_response.py +0 -0
  37. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/configuration.py +0 -0
  38. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/exceptions.py +0 -0
  39. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/__init__.py +0 -0
  40. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/dataset.py +0 -0
  41. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/dataset_version.py +0 -0
  42. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/datasets_create201_response.py +0 -0
  43. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/datasets_create_request.py +0 -0
  44. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/datasets_list200_response.py +0 -0
  45. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/datasets_list_examples200_response.py +0 -0
  46. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/error.py +0 -0
  47. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/experiment.py +0 -0
  48. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/models/experiments_list200_response.py +0 -0
  49. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/rest.py +0 -0
  50. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/__init__.py +0 -0
  51. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_dataset.py +0 -0
  52. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_dataset_version.py +0 -0
  53. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_datasets_api.py +0 -0
  54. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_datasets_create201_response.py +0 -0
  55. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_datasets_create_request.py +0 -0
  56. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_datasets_list200_response.py +0 -0
  57. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_datasets_list_examples200_response.py +0 -0
  58. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_error.py +0 -0
  59. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_experiment.py +0 -0
  60. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_experiments_api.py +0 -0
  61. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client/test/test_experiments_list200_response.py +0 -0
  62. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/api_client_README.md +0 -0
  63. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/protocol/__init__.py +0 -0
  64. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/protocol/flight/__init__.py +0 -0
  65. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/protocol/flight/export_pb2.py +0 -0
  66. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/protocol/flight/ingest_pb2.py +0 -0
  67. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/protocol/rec/__init__.py +0 -0
  68. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_generated/protocol/rec/public_pb2.py +0 -0
  69. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/_lazy.py +0 -0
  70. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/config.py +0 -0
  71. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/constants/__init__.py +0 -0
  72. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/constants/config.py +0 -0
  73. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/constants/ml.py +0 -0
  74. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/constants/model_mapping.json +0 -0
  75. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/constants/spans.py +0 -0
  76. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/datasets/__init__.py +0 -0
  77. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/datasets/client.py +0 -0
  78. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/exceptions/__init__.py +0 -0
  79. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/exceptions/auth.py +0 -0
  80. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/exceptions/base.py +0 -0
  81. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/exceptions/models.py +0 -0
  82. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/exceptions/parameters.py +0 -0
  83. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/exceptions/spaces.py +0 -0
  84. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/exceptions/types.py +0 -0
  85. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/exceptions/values.py +0 -0
  86. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/experiments/__init__.py +0 -0
  87. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/experiments/client.py +0 -0
  88. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/models/__init__.py +0 -0
  89. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/models/batch_validation/__init__.py +0 -0
  90. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/models/batch_validation/errors.py +0 -0
  91. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/models/batch_validation/validator.py +0 -0
  92. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/models/bounded_executor.py +0 -0
  93. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/models/client.py +0 -0
  94. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/models/stream_validation.py +0 -0
  95. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/models/surrogate_explainer/__init__.py +0 -0
  96. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/models/surrogate_explainer/mimic.py +0 -0
  97. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/__init__.py +0 -0
  98. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/client.py +0 -0
  99. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/columns.py +0 -0
  100. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/conversion.py +0 -0
  101. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/__init__.py +0 -0
  102. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/annotations/__init__.py +0 -0
  103. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/annotations/annotations_validation.py +0 -0
  104. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/annotations/dataframe_form_validation.py +0 -0
  105. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/annotations/value_validation.py +0 -0
  106. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/common/__init__.py +0 -0
  107. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/common/argument_validation.py +0 -0
  108. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/common/dataframe_form_validation.py +0 -0
  109. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/common/errors.py +0 -0
  110. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/common/value_validation.py +0 -0
  111. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/evals/__init__.py +0 -0
  112. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/evals/dataframe_form_validation.py +0 -0
  113. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/evals/evals_validation.py +0 -0
  114. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/evals/value_validation.py +0 -0
  115. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/metadata/__init__.py +0 -0
  116. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/metadata/argument_validation.py +0 -0
  117. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/metadata/dataframe_form_validation.py +0 -0
  118. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/metadata/value_validation.py +0 -0
  119. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/spans/__init__.py +0 -0
  120. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/spans/dataframe_form_validation.py +0 -0
  121. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/spans/spans_validation.py +0 -0
  122. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/spans/validation/spans/value_validation.py +0 -0
  123. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/types.py +0 -0
  124. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/utils/__init__.py +0 -0
  125. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/utils/arrow.py +0 -0
  126. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/utils/casting.py +0 -0
  127. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/utils/dataframe.py +0 -0
  128. {arize-8.0.0a10 → arize-8.0.0a12}/src/arize/utils/proto.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arize
3
- Version: 8.0.0a10
3
+ Version: 8.0.0a12
4
4
  Summary: A helper library to interact with Arize AI APIs
5
5
  Project-URL: Homepage, https://arize.com
6
6
  Project-URL: Documentation, https://docs.arize.com/arize
@@ -27,6 +27,13 @@ Classifier: Topic :: System :: Monitoring
27
27
  Requires-Python: >=3.10
28
28
  Requires-Dist: lazy-imports
29
29
  Requires-Dist: numpy>=2.0.0
30
+ Provides-Extra: auto-embeddings
31
+ Requires-Dist: datasets!=2.14.*,<3,>=2.8; extra == 'auto-embeddings'
32
+ Requires-Dist: pandas<3,>=1.0.0; extra == 'auto-embeddings'
33
+ Requires-Dist: pillow<11,>=8.4.0; extra == 'auto-embeddings'
34
+ Requires-Dist: tokenizers<1,>=0.13; extra == 'auto-embeddings'
35
+ Requires-Dist: torch<3,>=1.13; extra == 'auto-embeddings'
36
+ Requires-Dist: transformers<5,>=4.25; extra == 'auto-embeddings'
30
37
  Provides-Extra: dev
31
38
  Requires-Dist: pytest==8.4.2; extra == 'dev'
32
39
  Requires-Dist: ruff==0.13.2; extra == 'dev'
@@ -84,6 +91,10 @@ Description-Content-Type: text/markdown
84
91
  - [Stream log ML Data for a Classification use-case](#stream-log-ml-data-for-a-classification-use-case)
85
92
  - [Log a batch of ML Data for a Object Detection use-case](#log-a-batch-of-ml-data-for-a-object-detection-use-case)
86
93
  - [Exporting ML Data](#exporting-ml-data)
94
+ - [Generate embeddings for your data](#generate-embeddings-for-your-data)
95
+ - [Configure Logging](#configure-logging)
96
+ - [In Code](#in-code)
97
+ - [Via Environment Variables](#via-environment-variables)
87
98
  - [Community](#community)
88
99
 
89
100
  # Overview
@@ -326,6 +337,69 @@ df = client.models.export_to_df(
326
337
  )
327
338
  ```
328
339
 
340
+ ## Generate embeddings for your data
341
+
342
+ ```python
343
+ import pandas as pd
344
+ from arize.embeddings import EmbeddingGenerator, UseCases
345
+
346
+ # You can check available models
347
+ print(EmbeddingGenerator.list_pretrained_models())
348
+
349
+ # Example dataframe
350
+ df = pd.DataFrame(
351
+ {
352
+ "text": [
353
+ "Hello world.",
354
+ "Artificial Intelligence is the future.",
355
+ "Spain won the FIFA World Cup on 2010.",
356
+ ],
357
+ }
358
+ )
359
+ # Instantiate the generator for your usecase, selecting the base model
360
+ generator = EmbeddingGenerator.from_use_case(
361
+ use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
362
+ model_name="distilbert-base-uncased",
363
+ tokenizer_max_length=512,
364
+ batch_size=100,
365
+ )
366
+
367
+ # Generate embeddings
368
+ df["text_vector"] = generator.generate_embeddings(text_col=df["text"])
369
+ ```
370
+
371
+ # Configure Logging
372
+
373
+ ## In Code
374
+
375
+ You can use `configure_logging` to tailor the logging behavior of the Arize package to your needs.
376
+
377
+ ```python
378
+ from arize.logging import configure_logging
379
+
380
+ configure_logging(
381
+ level=..., # Defaults to logging.INFO
382
+ structured=..., # if True, emit JSON logs. Defaults to False
383
+ )
384
+ ```
385
+
386
+ ## Via Environment Variables
387
+
388
+ Configure the same options as the section above, via:
389
+
390
+ ```python
391
+ import os
392
+
393
+ # You can enable or disable logging altogether ("true" enables, "false" disables)
394
+ os.environ["ARIZE_LOG_ENABLE"] = "true"
395
+ # Set up the logging level
396
+ os.environ["ARIZE_LOG_LEVEL"] = "debug"
397
+ # Whether or not you want structured JSON logs
398
+ os.environ["ARIZE_LOG_STRUCTURED"] = "false"
399
+ ```
400
+
401
+ The default behavior of Arize's logs is: enabled, `INFO` level, and not structured.
402
+
329
403
  # Community
330
404
 
331
405
  Join our community to connect with thousands of AI builders.
@@ -31,6 +31,10 @@
31
31
  - [Stream log ML Data for a Classification use-case](#stream-log-ml-data-for-a-classification-use-case)
32
32
  - [Log a batch of ML Data for a Object Detection use-case](#log-a-batch-of-ml-data-for-a-object-detection-use-case)
33
33
  - [Exporting ML Data](#exporting-ml-data)
34
+ - [Generate embeddings for your data](#generate-embeddings-for-your-data)
35
+ - [Configure Logging](#configure-logging)
36
+ - [In Code](#in-code)
37
+ - [Via Environment Variables](#via-environment-variables)
34
38
  - [Community](#community)
35
39
 
36
40
  # Overview
@@ -273,6 +277,69 @@ df = client.models.export_to_df(
273
277
  )
274
278
  ```
275
279
 
280
+ ## Generate embeddings for your data
281
+
282
+ ```python
283
+ import pandas as pd
284
+ from arize.embeddings import EmbeddingGenerator, UseCases
285
+
286
+ # You can check available models
287
+ print(EmbeddingGenerator.list_pretrained_models())
288
+
289
+ # Example dataframe
290
+ df = pd.DataFrame(
291
+ {
292
+ "text": [
293
+ "Hello world.",
294
+ "Artificial Intelligence is the future.",
295
+ "Spain won the FIFA World Cup on 2010.",
296
+ ],
297
+ }
298
+ )
299
+ # Instantiate the generator for your usecase, selecting the base model
300
+ generator = EmbeddingGenerator.from_use_case(
301
+ use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
302
+ model_name="distilbert-base-uncased",
303
+ tokenizer_max_length=512,
304
+ batch_size=100,
305
+ )
306
+
307
+ # Generate embeddings
308
+ df["text_vector"] = generator.generate_embeddings(text_col=df["text"])
309
+ ```
310
+
311
+ # Configure Logging
312
+
313
+ ## In Code
314
+
315
+ You can use `configure_logging` to tailor the logging behavior of the Arize package to your needs.
316
+
317
+ ```python
318
+ from arize.logging import configure_logging
319
+
320
+ configure_logging(
321
+ level=..., # Defaults to logging.INFO
322
+ structured=..., # if True, emit JSON logs. Defaults to False
323
+ )
324
+ ```
325
+
326
+ ## Via Environment Variables
327
+
328
+ Configure the same options as the section above, via:
329
+
330
+ ```python
331
+ import os
332
+
333
+ # You can enable or disable logging altogether ("true" enables, "false" disables)
334
+ os.environ["ARIZE_LOG_ENABLE"] = "true"
335
+ # Set up the logging level
336
+ os.environ["ARIZE_LOG_LEVEL"] = "debug"
337
+ # Whether or not you want structured JSON logs
338
+ os.environ["ARIZE_LOG_STRUCTURED"] = "false"
339
+ ```
340
+
341
+ The default behavior of Arize's logs is: enabled, `INFO` level, and not structured.
342
+
276
343
  # Community
277
344
 
278
345
  Join our community to connect with thousands of AI builders.
@@ -39,7 +39,6 @@ dependencies = [
39
39
  # "requests_futures==1.0.0",
40
40
  # "googleapis_common_protos>=1.51.0,<2",
41
41
  # "protobuf>=4.21.0,<6",
42
- # "pandas>=0.25.3,<3",
43
42
  # "pyarrow>=0.15.0",
44
43
  # "tqdm>=4.60.0,<5",
45
44
  # "pydantic>=2.0.0,<3",
@@ -77,6 +76,14 @@ ml-batch = [
77
76
  mimic-explainer = [
78
77
  "interpret-community[mimic]>=0.22.0,<1",
79
78
  ]
79
+ auto-embeddings = [
80
+ "Pillow>=8.4.0, <11",
81
+ "datasets>=2.8, <3, !=2.14.*",
82
+ "pandas>=1.0.0,<3",
83
+ "tokenizers>=0.13, <1",
84
+ "torch>=1.13, <3",
85
+ "transformers>=4.25, <5",
86
+ ]
80
87
 
81
88
  [project.urls]
82
89
  Homepage = "https://arize.com"
@@ -12,11 +12,13 @@ if TYPE_CHECKING:
12
12
  from arize.spans.client import SpansClient
13
13
 
14
14
 
15
+ # TODO(Kiko): experimental/datasets must be adapted into the datasets subclient
16
+ # TODO(Kiko): experimental/prompt hub is missing
17
+ # TODO(Kiko): exporter/utils/schema_parser is missing
15
18
  # TODO(Kiko): Go through main APIs and add CtxAdapter where missing
16
19
  # TODO(Kiko): Search and handle other TODOs
17
20
  # TODO(Kiko): Go over **every file** and do not import anything at runtime, use `if TYPE_CHECKING`
18
21
  # with `from __future__ import annotations` (must include for Python < 3.11)
19
- # TODO(Kiko): MIMIC Explainer not done
20
22
  # TODO(Kiko): Go over docstrings
21
23
  class ArizeClient(LazySubclientsMixin):
22
24
  """
@@ -0,0 +1,4 @@
1
+ from arize.embeddings.auto_generator import EmbeddingGenerator
2
+ from arize.embeddings.usecases import UseCases
3
+
4
+ __all__ = ["EmbeddingGenerator", "UseCases"]
@@ -0,0 +1,108 @@
1
+ from typing import Any
2
+
3
+ import pandas as pd
4
+
5
+ from arize.embeddings import constants
6
+ from arize.embeddings.base_generators import BaseEmbeddingGenerator
7
+ from arize.embeddings.constants import (
8
+ CV_PRETRAINED_MODELS,
9
+ DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
10
+ DEFAULT_CV_OBJECT_DETECTION_MODEL,
11
+ DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
12
+ DEFAULT_NLP_SUMMARIZATION_MODEL,
13
+ DEFAULT_TABULAR_MODEL,
14
+ NLP_PRETRAINED_MODELS,
15
+ )
16
+ from arize.embeddings.cv_generators import (
17
+ EmbeddingGeneratorForCVImageClassification,
18
+ EmbeddingGeneratorForCVObjectDetection,
19
+ )
20
+ from arize.embeddings.nlp_generators import (
21
+ EmbeddingGeneratorForNLPSequenceClassification,
22
+ EmbeddingGeneratorForNLPSummarization,
23
+ )
24
+ from arize.embeddings.tabular_generators import (
25
+ EmbeddingGeneratorForTabularFeatures,
26
+ )
27
+ from arize.embeddings.usecases import UseCases
28
+
29
+ UseCaseLike = str | UseCases.NLP | UseCases.CV | UseCases.STRUCTURED
30
+
31
+
32
+ class EmbeddingGenerator:
33
+ def __init__(self, **kwargs: str):
34
+ raise OSError(
35
+ f"{self.__class__.__name__} is designed to be instantiated using the "
36
+ f"`{self.__class__.__name__}.from_use_case(use_case, **kwargs)` method."
37
+ )
38
+
39
+ @staticmethod
40
+ def from_use_case(
41
+ use_case: UseCaseLike, **kwargs: Any
42
+ ) -> BaseEmbeddingGenerator:
43
+ if use_case == UseCases.NLP.SEQUENCE_CLASSIFICATION:
44
+ return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)
45
+ elif use_case == UseCases.NLP.SUMMARIZATION:
46
+ return EmbeddingGeneratorForNLPSummarization(**kwargs)
47
+ elif use_case == UseCases.CV.IMAGE_CLASSIFICATION:
48
+ return EmbeddingGeneratorForCVImageClassification(**kwargs)
49
+ elif use_case == UseCases.CV.OBJECT_DETECTION:
50
+ return EmbeddingGeneratorForCVObjectDetection(**kwargs)
51
+ elif use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
52
+ return EmbeddingGeneratorForTabularFeatures(**kwargs)
53
+ else:
54
+ raise ValueError(f"Invalid use case {use_case}")
55
+
56
+ @classmethod
57
+ def list_default_models(cls) -> pd.DataFrame:
58
+ df = pd.DataFrame(
59
+ {
60
+ "Area": ["NLP", "NLP", "CV", "CV", "STRUCTURED"],
61
+ "Usecase": [
62
+ UseCases.NLP.SEQUENCE_CLASSIFICATION.name,
63
+ UseCases.NLP.SUMMARIZATION.name,
64
+ UseCases.CV.IMAGE_CLASSIFICATION.name,
65
+ UseCases.CV.OBJECT_DETECTION.name,
66
+ UseCases.STRUCTURED.TABULAR_EMBEDDINGS.name,
67
+ ],
68
+ "Model Name": [
69
+ DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
70
+ DEFAULT_NLP_SUMMARIZATION_MODEL,
71
+ DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
72
+ DEFAULT_CV_OBJECT_DETECTION_MODEL,
73
+ DEFAULT_TABULAR_MODEL,
74
+ ],
75
+ }
76
+ )
77
+ df.sort_values(
78
+ by=[col for col in df.columns], ascending=True, inplace=True
79
+ )
80
+ return df.reset_index(drop=True)
81
+
82
+ @classmethod
83
+ def list_pretrained_models(cls) -> pd.DataFrame:
84
+ data = {
85
+ "Task": ["NLP" for _ in NLP_PRETRAINED_MODELS]
86
+ + ["CV" for _ in CV_PRETRAINED_MODELS],
87
+ "Architecture": [
88
+ cls.__parse_model_arch(model)
89
+ for model in NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS
90
+ ],
91
+ "Model Name": NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS,
92
+ }
93
+ df = pd.DataFrame(data)
94
+ df.sort_values(
95
+ by=[col for col in df.columns], ascending=True, inplace=True
96
+ )
97
+ return df.reset_index(drop=True)
98
+
99
+ @staticmethod
100
+ def __parse_model_arch(model_name: str) -> str:
101
+ if constants.GPT.lower() in model_name.lower():
102
+ return constants.GPT
103
+ elif constants.BERT.lower() in model_name.lower():
104
+ return constants.BERT
105
+ elif constants.VIT.lower() in model_name.lower():
106
+ return constants.VIT
107
+ else:
108
+ raise ValueError("Invalid model_name, unknown architecture.")
@@ -0,0 +1,255 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from enum import Enum
4
+ from functools import partial
5
+ from typing import Dict, List, Union, cast
6
+
7
+ import pandas as pd
8
+
9
+ import arize.embeddings.errors as err
10
+ from arize.embeddings.constants import IMPORT_ERROR_MESSAGE
11
+
12
+ try:
13
+ import torch
14
+ from datasets import Dataset
15
+ from PIL import Image
16
+ from transformers import ( # type: ignore
17
+ AutoImageProcessor,
18
+ AutoModel,
19
+ AutoTokenizer,
20
+ BatchEncoding,
21
+ )
22
+ from transformers.utils import logging as transformer_logging
23
+ except ImportError as e:
24
+ raise ImportError(IMPORT_ERROR_MESSAGE) from e
25
+
26
+ import logging
27
+
28
+ logger = logging.getLogger(__name__)
29
+ transformer_logging.set_verbosity(50)
30
+ transformer_logging.enable_progress_bar()
31
+
32
+
33
+ class BaseEmbeddingGenerator(ABC):
34
+ def __init__(
35
+ self, use_case: Enum, model_name: str, batch_size: int = 100, **kwargs
36
+ ):
37
+ self.__use_case = self._parse_use_case(use_case=use_case)
38
+ self.__model_name = model_name
39
+ self.__device = self.select_device()
40
+ self.__batch_size = batch_size
41
+ logger.info(f"Downloading pre-trained model '{self.model_name}'")
42
+ try:
43
+ self.__model = AutoModel.from_pretrained(
44
+ self.model_name, **kwargs
45
+ ).to(self.device)
46
+ except OSError as e:
47
+ raise err.HuggingFaceRepositoryNotFound(model_name) from e
48
+ except Exception as e:
49
+ raise e
50
+
51
+ @abstractmethod
52
+ def generate_embeddings(self, **kwargs) -> pd.Series: ...
53
+
54
+ def select_device(self) -> torch.device:
55
+ if torch.cuda.is_available():
56
+ return torch.device("cuda")
57
+ elif torch.backends.mps.is_available():
58
+ return torch.device("mps")
59
+ else:
60
+ logger.warning(
61
+ "No available GPU has been detected. The use of GPU acceleration is "
62
+ "strongly recommended. You can check for GPU availability by running "
63
+ "`torch.cuda.is_available()` or `torch.backends.mps.is_available()`."
64
+ )
65
+ return torch.device("cpu")
66
+
67
+ @property
68
+ def use_case(self) -> str:
69
+ return self.__use_case
70
+
71
+ @property
72
+ def model_name(self) -> str:
73
+ return self.__model_name
74
+
75
+ @property
76
+ def model(self):
77
+ return self.__model
78
+
79
+ @property
80
+ def device(self) -> torch.device:
81
+ return self.__device
82
+
83
+ @property
84
+ def batch_size(self) -> int:
85
+ return self.__batch_size
86
+
87
+ @batch_size.setter
88
+ def batch_size(self, new_batch_size: int) -> None:
89
+ err_message = "New batch size should be an integer greater than 0."
90
+ if not isinstance(new_batch_size, int):
91
+ raise TypeError(err_message)
92
+ elif new_batch_size <= 0:
93
+ raise ValueError(err_message)
94
+ else:
95
+ self.__batch_size = new_batch_size
96
+ logger.info(f"Batch size has been set to {new_batch_size}.")
97
+
98
+ @staticmethod
99
+ def _parse_use_case(use_case: Enum) -> str:
100
+ uc_area = use_case.__class__.__name__.split("UseCases")[0]
101
+ uc_task = use_case.name
102
+ return f"{uc_area}.{uc_task}"
103
+
104
+ def _get_embedding_vector(
105
+ self, batch: Dict[str, torch.Tensor], method
106
+ ) -> Dict[str, torch.Tensor]:
107
+ with torch.no_grad():
108
+ outputs = self.model(**batch)
109
+ # (batch_size, seq_length/or/num_tokens, hidden_size)
110
+ if method == "cls_token": # Select CLS token vector
111
+ embeddings = outputs.last_hidden_state[:, 0, :]
112
+ elif method == "avg_token": # Select avg token vector
113
+ embeddings = torch.mean(outputs.last_hidden_state, 1)
114
+ else:
115
+ raise ValueError(f"Invalid method = {method}")
116
+ return {"embedding_vector": embeddings.cpu().numpy().astype(float)}
117
+
118
+ @staticmethod
119
+ def check_invalid_index(field: Union[pd.Series, pd.DataFrame]) -> None:
120
+ if (field.index != field.reset_index(drop=True).index).any():
121
+ if isinstance(field, pd.DataFrame):
122
+ raise err.InvalidIndexError("DataFrame")
123
+ else:
124
+ raise err.InvalidIndexError(str(field.name))
125
+
126
+ @abstractmethod
127
+ def __repr__(self) -> str:
128
+ pass
129
+
130
+
131
+ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
132
+ def __repr__(self) -> str:
133
+ return (
134
+ f"{self.__class__.__name__}(\n"
135
+ f" use_case={self.use_case},\n"
136
+ f" model_name='{self.model_name}',\n"
137
+ f" tokenizer_max_length={self.tokenizer_max_length},\n"
138
+ f" tokenizer={self.tokenizer.__class__},\n"
139
+ f" model={self.model.__class__},\n"
140
+ f" batch_size={self.batch_size},\n"
141
+ f")"
142
+ )
143
+
144
+ def __init__(
145
+ self,
146
+ use_case: Enum,
147
+ model_name: str,
148
+ tokenizer_max_length: int = 512,
149
+ **kwargs,
150
+ ):
151
+ super().__init__(use_case=use_case, model_name=model_name, **kwargs)
152
+ self.__tokenizer_max_length = tokenizer_max_length
153
+ # We don't check for the tokenizer's existence since it is coupled with the corresponding model
154
+ # We check the model's existence in `BaseEmbeddingGenerator`
155
+ logger.info(f"Downloading tokenizer for '{self.model_name}'")
156
+ self.__tokenizer = AutoTokenizer.from_pretrained(
157
+ self.model_name, model_max_length=self.tokenizer_max_length
158
+ )
159
+
160
+ @property
161
+ def tokenizer(self):
162
+ return self.__tokenizer
163
+
164
+ @property
165
+ def tokenizer_max_length(self) -> int:
166
+ return self.__tokenizer_max_length
167
+
168
+ def tokenize(
169
+ self, batch: Dict[str, List[str]], text_feat_name: str
170
+ ) -> BatchEncoding:
171
+ return self.tokenizer(
172
+ batch[text_feat_name],
173
+ padding=True,
174
+ truncation=True,
175
+ max_length=self.tokenizer_max_length,
176
+ return_tensors="pt",
177
+ ).to(self.device)
178
+
179
+
180
+ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
181
+ def __repr__(self) -> str:
182
+ return (
183
+ f"{self.__class__.__name__}(\n"
184
+ f" use_case={self.use_case},\n"
185
+ f" model_name='{self.model_name}',\n"
186
+ f" image_processor={self.image_processor.__class__},\n"
187
+ f" model={self.model.__class__},\n"
188
+ f" batch_size={self.batch_size},\n"
189
+ f")"
190
+ )
191
+
192
+ def __init__(self, use_case: Enum, model_name: str, **kwargs):
193
+ super().__init__(use_case=use_case, model_name=model_name, **kwargs)
194
+ logger.info("Downloading image processor")
195
+ # We don't check for the image processor's existence since it is coupled with the corresponding model
196
+ # We check the model's existence in `BaseEmbeddingGenerator`
197
+ self.__image_processor = AutoImageProcessor.from_pretrained(
198
+ self.model_name
199
+ )
200
+
201
+ @property
202
+ def image_processor(self):
203
+ return self.__image_processor
204
+
205
+ @staticmethod
206
+ def open_image(image_path: str) -> Image.Image:
207
+ if not os.path.exists(image_path):
208
+ raise ValueError(f"Cannot find image {image_path}")
209
+ return Image.open(image_path).convert("RGB")
210
+
211
+ def preprocess_image(
212
+ self, batch: Dict[str, List[str]], local_image_feat_name: str
213
+ ):
214
+ return self.image_processor(
215
+ [
216
+ self.open_image(image_path)
217
+ for image_path in batch[local_image_feat_name]
218
+ ],
219
+ return_tensors="pt",
220
+ ).to(self.device)
221
+
222
+ def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:
223
+ """
224
+ Obtain embedding vectors from your image data using pre-trained image models.
225
+
226
+ :param local_image_path_col: a pandas Series containing the local path to the images to
227
+ be used to generate the embedding vectors.
228
+ :return: a pandas Series containing the embedding vectors.
229
+ """
230
+ if not isinstance(local_image_path_col, pd.Series):
231
+ raise TypeError(
232
+ "local_image_path_col_name must be pandas Series object"
233
+ )
234
+ self.check_invalid_index(field=local_image_path_col)
235
+
236
+ # Validate that there are no null image paths
237
+ if local_image_path_col.isnull().any():
238
+ raise ValueError(
239
+ "There can't be any null values in the local_image_path_col series"
240
+ )
241
+
242
+ ds = Dataset.from_dict({"local_path": local_image_path_col})
243
+ ds.set_transform(
244
+ partial(
245
+ self.preprocess_image,
246
+ local_image_feat_name="local_path",
247
+ )
248
+ )
249
+ logger.info("Generating embedding vectors")
250
+ ds = ds.map(
251
+ lambda batch: self._get_embedding_vector(batch, "avg_token"),
252
+ batched=True,
253
+ batch_size=self.batch_size,
254
+ )
255
+ return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
@@ -0,0 +1,34 @@
1
+ DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL = "distilbert-base-uncased"
2
+ DEFAULT_NLP_SUMMARIZATION_MODEL = "distilbert-base-uncased"
3
+ DEFAULT_TABULAR_MODEL = "distilbert-base-uncased"
4
+ DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL = "google/vit-base-patch32-224-in21k"
5
+ DEFAULT_CV_OBJECT_DETECTION_MODEL = "facebook/detr-resnet-101"
6
+ NLP_PRETRAINED_MODELS = [
7
+ "bert-base-cased",
8
+ "bert-base-uncased",
9
+ "bert-large-cased",
10
+ "bert-large-uncased",
11
+ "distilbert-base-cased",
12
+ "distilbert-base-uncased",
13
+ "xlm-roberta-base",
14
+ "xlm-roberta-large",
15
+ ]
16
+
17
+ CV_PRETRAINED_MODELS = [
18
+ "google/vit-base-patch16-224-in21k",
19
+ "google/vit-base-patch16-384",
20
+ "google/vit-base-patch32-224-in21k",
21
+ "google/vit-base-patch32-384",
22
+ "google/vit-large-patch16-224-in21k",
23
+ "google/vit-large-patch16-384",
24
+ "google/vit-large-patch32-224-in21k",
25
+ "google/vit-large-patch32-384",
26
+ ]
27
+ IMPORT_ERROR_MESSAGE = (
28
+ "To enable embedding generation, the arize module must be installed with "
29
+ "extra dependencies. Run: pip install 'arize[auto-embeddings]'."
30
+ )
31
+
32
+ GPT = "GPT"
33
+ BERT = "BERT"
34
+ VIT = "ViT"
@@ -0,0 +1,28 @@
1
+ from arize.embeddings.base_generators import CVEmbeddingGenerator
2
+ from arize.embeddings.constants import (
3
+ DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
4
+ DEFAULT_CV_OBJECT_DETECTION_MODEL,
5
+ )
6
+ from arize.embeddings.usecases import UseCases
7
+
8
+
9
+ class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
10
+ def __init__(
11
+ self, model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL, **kwargs
12
+ ):
13
+ super().__init__(
14
+ use_case=UseCases.CV.IMAGE_CLASSIFICATION,
15
+ model_name=model_name,
16
+ **kwargs,
17
+ )
18
+
19
+
20
+ class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator):
21
+ def __init__(
22
+ self, model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL, **kwargs
23
+ ):
24
+ super().__init__(
25
+ use_case=UseCases.CV.OBJECT_DETECTION,
26
+ model_name=model_name,
27
+ **kwargs,
28
+ )
@@ -0,0 +1,41 @@
1
+ class InvalidIndexError(Exception):
2
+ def __repr__(self) -> str:
3
+ return "Invalid_Index_Error"
4
+
5
+ def __str__(self) -> str:
6
+ return self.error_message()
7
+
8
+ def __init__(self, field_name: str) -> None:
9
+ self.field_name = field_name
10
+
11
+ def error_message(self) -> str:
12
+ if self.field_name == "DataFrame":
13
+ return (
14
+ f"The index of the {self.field_name} is invalid; "
15
+ f"reset the index by using df.reset_index(drop=True, inplace=True)"
16
+ )
17
+ else:
18
+ return (
19
+ f"The index of the Series given by the column '{self.field_name}' is invalid; "
20
+ f"reset the index by using df.reset_index(drop=True, inplace=True)"
21
+ )
22
+
23
+
24
+ class HuggingFaceRepositoryNotFound(Exception):
25
+ def __repr__(self) -> str:
26
+ return "HuggingFace_Repository_Not_Found_Error"
27
+
28
+ def __str__(self) -> str:
29
+ return self.error_message()
30
+
31
+ def __init__(self, model_name: str) -> None:
32
+ self.model_name = model_name
33
+
34
+ def error_message(self) -> str:
35
+ return (
36
+ f"The given model name '{self.model_name}' is not a valid model identifier listed on "
37
+ "'https://huggingface.co/models'. "
38
+ "If this is a private repository, log in with `huggingface-cli login` or importing "
39
+ "`login` from `huggingface_hub` if you are using a notebook. "
40
+ "Learn more in https://huggingface.co/docs/huggingface_hub/quick-start#login"
41
+ )