genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645) hide show
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,888 @@
1
+ """
2
+ The ``mlflow.johnsnowlabs`` module provides an API for logging and loading Spark NLP and NLU models.
3
+ This module exports the following flavors:
4
+
5
+ Johnsnowlabs (native) format
6
+ Allows models to be loaded as Spark Transformers for scoring in a Spark session.
7
+ Models with this flavor can be loaded as NluPipelines, with underlying Spark MLlib PipelineModel
8
+ This is the main flavor and is always produced.
9
+ :py:mod:`mlflow.pyfunc`
10
+ Supports deployment outside of Spark by instantiating a SparkContext and reading
11
+ input data as a Spark DataFrame prior to scoring. Also supports deployment in Spark
12
+ as a Spark UDF. Models with this flavor can be loaded as Python functions
13
+ for performing inference. This flavor is always produced.
14
+
15
+ This flavor gives you access to `20,000+ state-of-the-art enterprise NLP models in 200+ languages
16
+ <https://nlp.johnsnowlabs.com/models>`_ for medical, finance, legal and many more domains.
17
+ Features include: LLMs, Text Summarization, Question Answering, Named Entity Recognition, Relation
18
+ Extraction, Sentiment Analysis, Spell Checking, Image Classification, Automatic Speech Recognition
19
+ and much more, powered by the latest Transformer Architectures. The models are provided by
20
+ `John Snow Labs <https://www.johnsnowlabs.com/>`_ and require a `John Snow Labs
21
+ <https://www.johnsnowlabs.com/>`_ Enterprise NLP License. `You can reach out to us
22
+ <https://www.johnsnowlabs.com/schedule-a-demo/>`_ for a research or industry license.
23
+
24
+ These keys must be present in your license json:
25
+
26
+ 1. ``SECRET``: The secret for the John Snow Labs Enterprise NLP Library
27
+ 2. ``SPARK_NLP_LICENSE``: Your John Snow Labs Enterprise NLP License
28
+ 3. ``AWS_ACCESS_KEY_ID``: Your AWS access key ID for accessing John Snow Labs Enterprise Models
29
+ 4. ``AWS_SECRET_ACCESS_KEY``: Your AWS Secret key for accessing John Snow Labs Enterprise Models
30
+
31
+ You can set them using the following code:
32
+
33
+ .. code-block:: python
34
+
35
+ import os
36
+ import json
37
+
38
+ # Write your raw license.json string into the 'JOHNSNOWLABS_LICENSE_JSON' env variable
39
+ creds = {
40
+ "AWS_ACCESS_KEY_ID": "...",
41
+ "AWS_SECRET_ACCESS_KEY": "...",
42
+ "SPARK_NLP_LICENSE": "...",
43
+ "SECRET": "...",
44
+ }
45
+ os.environ["JOHNSNOWLABS_LICENSE_JSON"] = json.dumps(creds)
46
+ """
47
+
48
+ import json
49
+ import logging
50
+ import os
51
+ import posixpath
52
+ import shutil
53
+ import sys
54
+ from pathlib import Path
55
+ from typing import Any, Optional
56
+
57
+ import yaml
58
+
59
+ import mlflow
60
+ from mlflow import pyfunc
61
+ from mlflow.environment_variables import MLFLOW_DFS_TMP
62
+ from mlflow.models import Model
63
+ from mlflow.models.model import MLMODEL_FILE_NAME
64
+ from mlflow.models.signature import ModelSignature
65
+ from mlflow.models.utils import ModelInputExample, _save_example
66
+ from mlflow.spark import (
67
+ _HadoopFileSystem,
68
+ _maybe_save_model,
69
+ _mlflowdbfs_path,
70
+ _should_use_mlflowdbfs,
71
+ )
72
+ from mlflow.store.artifact.databricks_artifact_repo import DatabricksArtifactRepository
73
+ from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
74
+ from mlflow.tracking.artifact_utils import (
75
+ _download_artifact_from_uri,
76
+ _get_root_uri_and_artifact_path,
77
+ )
78
+ from mlflow.utils import databricks_utils
79
+ from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
80
+ from mlflow.utils.environment import (
81
+ _CONDA_ENV_FILE_NAME,
82
+ _CONSTRAINTS_FILE_NAME,
83
+ _PYTHON_ENV_FILE_NAME,
84
+ _REQUIREMENTS_FILE_NAME,
85
+ _mlflow_conda_env,
86
+ _process_conda_env,
87
+ _process_pip_requirements,
88
+ _PythonEnv,
89
+ )
90
+ from mlflow.utils.file_utils import (
91
+ TempDir,
92
+ get_total_file_size,
93
+ shutil_copytree_without_file_permissions,
94
+ write_to,
95
+ )
96
+ from mlflow.utils.model_utils import (
97
+ _add_code_from_conf_to_system_path,
98
+ _get_flavor_configuration_from_uri,
99
+ _validate_and_copy_code_paths,
100
+ )
101
+ from mlflow.utils.requirements_utils import _get_pinned_requirement
102
+ from mlflow.utils.uri import (
103
+ append_to_uri_path,
104
+ dbfs_hdfs_uri_to_fuse_path,
105
+ generate_tmp_dfs_path,
106
+ get_databricks_profile_uri_from_artifact_uri,
107
+ is_local_uri,
108
+ is_valid_dbfs_uri,
109
+ )
110
+
111
# Flavor key written into the MLmodel file by ``mlflow_model.add_flavor``.
FLAVOR_NAME = "johnsnowlabs"
# Environment variable expected to hold the raw ``license.json`` string.
_JOHNSNOWLABS_ENV_JSON_LICENSE_KEY = "JOHNSNOWLABS_LICENSE_JSON"
# Environment variable whose value is substituted into the spark-nlp-jsl wheel URL.
_JOHNSNOWLABS_ENV_HEALTHCARE_SECRET = "HEALTHCARE_SECRET"
# Environment variable whose value is substituted into the spark-ocr wheel URL.
_JOHNSNOWLABS_ENV_VISUAL_SECRET = "VISUAL_SECRET"
# Sub-directory of the MLflow model directory that holds the Spark model data and ``jars.jsl``.
_JOHNSNOWLABS_MODEL_PATH_SUB = "jsl-model"
# Module-level logger.
_logger = logging.getLogger(__name__)
117
+
118
+
119
def _validate_env_vars():
    """Require the JSON license environment variable, then export its entries.

    Raises:
        Exception: If the ``JOHNSNOWLABS_LICENSE_JSON`` environment variable is not set.
    """
    license_is_configured = _JOHNSNOWLABS_ENV_JSON_LICENSE_KEY in os.environ
    if not license_is_configured:
        raise Exception(
            f"Please set the {_JOHNSNOWLABS_ENV_JSON_LICENSE_KEY}"
            f" environment variable as the raw license.json string from John Snow Labs"
        )
    # The license is present: copy its individual entries into the process environment.
    _set_env_vars()
126
+
127
+
128
def _set_env_vars():
    """Parse the JSON license from the environment and export each non-null entry.

    Every value is converted with ``str`` before being stored in ``os.environ``;
    entries whose value is ``None`` are skipped.
    """
    raw_license = os.environ[_JOHNSNOWLABS_ENV_JSON_LICENSE_KEY]
    license_entries = json.loads(raw_license)
    exported = {}
    for key, value in license_entries.items():
        if value is not None:
            exported[key] = str(value)
    os.environ.update(exported)
132
+
133
+
134
def get_default_pip_requirements():
    """
    Returns:
        A list of default pip requirements for MLflow Models produced by this flavor.
        Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
        that, at minimum, contains these requirements.

    Raises:
        Exception: If neither the ``HEALTHCARE_SECRET`` nor the ``VISUAL_SECRET``
            environment variable is set.
    """
    from johnsnowlabs import settings

    healthcare_secret = os.environ.get(_JOHNSNOWLABS_ENV_HEALTHCARE_SECRET)
    visual_secret = os.environ.get(_JOHNSNOWLABS_ENV_VISUAL_SECRET)
    if healthcare_secret is None and visual_secret is None:
        raise Exception(
            f"You need to set either {_JOHNSNOWLABS_ENV_HEALTHCARE_SECRET} "
            f"or {_JOHNSNOWLABS_ENV_VISUAL_SECRET} environment variable. "
            f"Please contact John Snow Labs to get one."
        )

    # ``{secret}`` is left as a literal placeholder here and filled in via ``str.format`` below.
    healthcare_wheel_template = (
        "https://pypi.johnsnowlabs.com/{secret}/spark-nlp-jsl/spark_nlp_jsl-"
        f"{settings.raw_version_medical}-py3-none-any.whl"
    )
    visual_wheel_template = (
        "https://pypi.johnsnowlabs.com/{secret}/spark-ocr/"
        f"spark_ocr-{settings.raw_version_ocr}-py3-none-any.whl"
    )

    requirements = [
        f"johnsnowlabs_for_databricks=={settings.raw_version_jsl_lib}",
        _get_pinned_requirement("pyspark"),
    ]

    # Add one wheel URL per secret that is present, healthcare first, then visual.
    secret_and_template = (
        (healthcare_secret, healthcare_wheel_template),
        (visual_secret, visual_wheel_template),
    )
    for secret, template in secret_and_template:
        if secret is not None:
            requirements.append(template.format(secret=secret))

    return requirements
181
+
182
+
183
def get_default_conda_env():
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to
        :func:`save_model()` and :func:`log_model()`.
    """
    default_pip_deps = get_default_pip_requirements()
    return _mlflow_conda_env(additional_pip_deps=default_pip_deps)
190
+
191
+
192
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="johnsnowlabs"))
def log_model(
    spark_model,
    artifact_path: Optional[str] = None,
    conda_env=None,
    code_paths=None,
    dfs_tmpdir=None,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
    store_license=False,
    name: Optional[str] = None,
    params: Optional[dict[str, Any]] = None,
    tags: Optional[dict[str, Any]] = None,
    model_type: Optional[str] = None,
    step: int = 0,
    model_id: Optional[str] = None,
):
    """
    Log a ``Johnsnowlabs NLUPipeline`` created via `nlp.load()
    <https://nlp.johnsnowlabs.com/docs/en/jsl/load_api>`_, as an MLflow artifact for the current
    run. This uses the MLlib persistence format and produces an MLflow Model with the
    ``johnsnowlabs`` flavor.

    Note: If no run is active, it will instantiate a run to obtain a run_id.

    Args:
        spark_model: NLUPipeline obtained via `nlp.load()
            <https://nlp.johnsnowlabs.com/docs/en/jsl/load_api>`_
        artifact_path: Deprecated. Use `name` instead.
        conda_env: Either a dictionary representation of a Conda environment or the path to a
            Conda environment yaml file. If provided, this describes the environment
            this model should be run in. At minimum, it should specify the dependencies
            contained in :func:`get_default_conda_env()`. If `None`, the default
            :func:`get_default_conda_env()` environment is added to the model.
            The following is an *example* dictionary representation of a Conda
            environment::

                {
                    'name': 'mlflow-env',
                    'channels': ['defaults'],
                    'dependencies': [
                        'python=3.8.15',
                        'johnsnowlabs'
                    ]
                }
        code_paths: {{ code_paths }}
        dfs_tmpdir: Temporary directory path on Distributed (Hadoop) File System (DFS) or local
            filesystem if running in local mode. The model is written in this
            destination and then copied into the model's artifact directory. This is
            necessary as Spark ML models read from and write to DFS if running on a
            cluster. If this operation completes successfully, all temporary files
            created on the DFS are removed. Defaults to ``/tmp/mlflow``.
        registered_model_name: If given, create a model version under
            ``registered_model_name``, also creating a registered model if one
            with the given name does not exist.
        signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
            describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
            The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
            from datasets with valid model input (e.g. the training dataset with target
            column omitted) and valid model output (e.g. model predictions generated on
            the training dataset), for example:

            .. code-block:: python

                from mlflow.models.signature import infer_signature

                train = df.drop_column("target_label")
                predictions = ...  # compute model predictions
                signature = infer_signature(train, predictions)

        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and is in ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        store_license: If True, the license will be stored with the model and used when
            re-loading it.
        name: {{ name }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
        metadata of the logged model.

    .. code-block:: python
        :caption: Example

        import os
        import json
        import pandas as pd
        import mlflow
        from johnsnowlabs import nlp

        # Write your raw license.json string into the 'JOHNSNOWLABS_LICENSE_JSON' env variable
        creds = {
            "AWS_ACCESS_KEY_ID": "...",
            "AWS_SECRET_ACCESS_KEY": "...",
            "SPARK_NLP_LICENSE": "...",
            "SECRET": "...",
        }
        os.environ["JOHNSNOWLABS_LICENSE_JSON"] = json.dumps(creds)

        # Download & Install Jars/Wheels if missing and Start a spark Session
        nlp.start()

        # For more details on trainable models and parameterization like embedding choice see
        # https://nlp.johnsnowlabs.com/docs/en/jsl/training
        trainable_classifier = nlp.load("train.classifier")

        # Create a sample training dataset
        data = pd.DataFrame(
            {"text": ["I hate covid ", "I love covid"], "y": ["negative", "positive"]}
        )

        # Fit and get a trained classifier
        trained_classifier = trainable_classifier.fit(data)
        trained_classifier.predict("He hates covid")

        # Log it
        mlflow.johnsnowlabs.log_model(trained_classifier, name="my_trained_model")
    """

    _validate_env_vars()
    run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
    run_root_artifact_uri = mlflow.get_artifact_uri()
    remote_model_path = None

    # If the artifact URI is a local filesystem path, defer to Model.log() to persist the model,
    # since Spark may not be able to write directly to the driver's filesystem. For example,
    # writing to `file:/uri` will write to the local filesystem from each executor, which will
    # be incorrect on multi-node clusters.
    # If the artifact URI is not a local filesystem path we attempt to write directly to the
    # artifact repo via Spark. If this fails, we defer to Model.log().
    if is_local_uri(run_root_artifact_uri) or not _maybe_save_model(
        spark_model,
        append_to_uri_path(run_root_artifact_uri, name),
    ):
        return Model.log(
            artifact_path=artifact_path,
            name=name,
            flavor=mlflow.johnsnowlabs,
            spark_model=spark_model,
            conda_env=conda_env,
            code_paths=code_paths,
            dfs_tmpdir=dfs_tmpdir,
            registered_model_name=registered_model_name,
            signature=signature,
            input_example=input_example,
            await_registration_for=await_registration_for,
            pip_requirements=pip_requirements,
            extra_pip_requirements=extra_pip_requirements,
            metadata=metadata,
            # Forwarded to this flavor's ``save_model`` so the caller's choice is not
            # silently dropped on this code path.
            store_license=store_license,
            params=params,
            tags=tags,
            model_type=model_type,
            step=step,
            model_id=model_id,
        )
    # Otherwise, override the default model log behavior and save model directly to artifact repo
    logged_model = mlflow.initialize_logged_model(
        name=name,
        source_run_id=run.info.run_id if (run := mlflow.active_run()) else None,
        model_type=model_type,
        params=params,
        tags=tags,
    )
    mlflow_model = Model(artifact_path=logged_model.artifact_location, run_id=run_id)
    with TempDir() as tmp:
        tmp_model_metadata_dir = tmp.path()
        _save_model_metadata(
            tmp_model_metadata_dir,
            spark_model,
            mlflow_model,
            conda_env,
            code_paths,
            signature=signature,
            input_example=input_example,
            pip_requirements=pip_requirements,
            extra_pip_requirements=extra_pip_requirements,
            remote_model_path=remote_model_path,
            store_license=store_license,
        )
        mlflow.tracking.fluent.log_artifacts(tmp_model_metadata_dir, name)
        mlflow.tracking.fluent._record_logged_model(mlflow_model)
        if registered_model_name is not None:
            mlflow.register_model(
                f"runs:/{logged_model.model_id}",
                registered_model_name,
                await_registration_for,
            )
        return mlflow_model.get_model_info()
394
+
395
+
396
def _save_model_metadata(
    dst_dir,
    spark_model,
    mlflow_model,
    conda_env,
    code_paths,
    signature=None,
    input_example=None,
    pip_requirements=None,
    extra_pip_requirements=None,
    remote_model_path=None,
    store_license=False,
):
    """
    Saves model metadata into the passed-in directory.

    The persisted metadata points at the model data through a path relative to the
    metadata file (currently hard-coded to "jsl-model").

    Args:
        dst_dir: Directory that receives the MLmodel file, environment files, code paths,
            jars and (optionally) the license.
        spark_model: The model being saved. Not read by this function.
        mlflow_model: MLflow model config that the ``johnsnowlabs`` and ``pyfunc`` flavors
            are added to.
        conda_env: Conda environment (dict or path), or ``None`` to derive one from the pip
            requirements.
        code_paths: Code paths copied next to the model.
        signature: Optional model signature stored on ``mlflow_model``.
        input_example: Optional input example saved into ``dst_dir``.
        pip_requirements: Optional pip requirements overriding the flavor defaults.
        extra_pip_requirements: Optional pip requirements added on top of the defaults.
        remote_model_path: Accepted for interface compatibility; not read by this function.
        store_license: If True, the license is written next to the jars.
    """

    if signature is not None:
        mlflow_model.signature = signature
    if input_example is not None:
        _save_example(mlflow_model, input_example, dst_dir)

    code_dir_subpath = _validate_and_copy_code_paths(code_paths, dst_dir)

    # add the johnsnowlabs flavor
    import pyspark

    mlflow_model.add_flavor(
        FLAVOR_NAME,
        pyspark_version=pyspark.__version__,
        model_data=_JOHNSNOWLABS_MODEL_PATH_SUB,
        code=code_dir_subpath,
    )

    # add the pyfunc flavor
    pyfunc.add_to_model(
        mlflow_model,
        loader_module="mlflow.johnsnowlabs",
        data=_JOHNSNOWLABS_MODEL_PATH_SUB,
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        code=code_dir_subpath,
    )
    if size := get_total_file_size(dst_dir):
        mlflow_model.model_size_bytes = size
    mlflow_model.save(str(Path(dst_dir) / MLMODEL_FILE_NAME))

    if conda_env is None:
        # Fall back to the flavor defaults only when the caller gave no explicit requirements.
        default_reqs = get_default_pip_requirements() if pip_requirements is None else None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(str(Path(dst_dir) / _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(str(Path(dst_dir) / _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))
    write_to(str(Path(dst_dir) / _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(str(Path(dst_dir) / _PYTHON_ENV_FILE_NAME))

    # Pass the flag through; without it the license would never be stored.
    _save_jars_and_lic(dst_dir, store_license=store_license)


def _save_jars_and_lic(dst_dir, store_license=False):
    """
    Copy the installed John Snow Labs jars, and optionally the license, next to the model.

    Files are placed in ``<dst_dir>/jsl-model/jars.jsl``, which is created if missing.

    Args:
        dst_dir: Root directory of the MLflow model being saved.
        store_license: If True and the discovered secrets carry an ``HC_LICENSE``, the
            secrets are serialized to ``license.json`` in the same directory as the jars.
    """
    from johnsnowlabs.auto_install.jsl_home import get_install_suite_from_jsl_home
    from johnsnowlabs.py_models.jsl_secrets import JslSecrets

    deps_data_path = Path(dst_dir) / _JOHNSNOWLABS_MODEL_PATH_SUB / "jars.jsl"
    deps_data_path.mkdir(parents=True, exist_ok=True)

    suite = get_install_suite_from_jsl_home(
        False,
        visual=_JOHNSNOWLABS_ENV_VISUAL_SECRET in os.environ,
    )
    # Target file names are fixed so the jars can be found again regardless of version.
    jar_sources = (
        (suite.hc, "hc_jar.jar"),
        (suite.nlp, "os_jar.jar"),
        (suite.ocr, "visual_nlp.jar"),
    )
    for component, jar_name in jar_sources:
        java_path = component.get_java_path()
        if java_path:
            shutil.copy2(java_path, deps_data_path / jar_name)

    if store_license:
        # Read the secrets from env vars and write to license.json
        secrets = JslSecrets.build_or_try_find_secrets()
        if secrets.HC_LICENSE:
            # ``pathlib.Path`` has no ``write`` method; ``write_text`` is the correct call.
            deps_data_path.joinpath("license.json").write_text(secrets.json())
493
+
494
+
495
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="johnsnowlabs"))
def save_model(
    spark_model,
    path,
    mlflow_model=None,
    conda_env=None,
    code_paths=None,
    dfs_tmpdir=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
    store_license=False,
):
    """
    Save a Spark johnsnowlabs Model to a local path.

    By default, this function saves models using the Spark MLlib persistence mechanism.

    Args:
        spark_model: Either a pyspark.ml.pipeline.PipelineModel or nlu.NLUPipeline object to be
            saved. `Every johnsnowlabs model <https://nlp.johnsnowlabs.com/models>`_
            is a PipelineModel and loadable as nlu.NLUPipeline.
        path: Local path where the model is to be saved.
        mlflow_model: MLflow model config this flavor is being added to.
        conda_env: Either a dictionary representation of a Conda environment or the path to a
            Conda environment yaml file. If provided, this describes the environment
            this model should be run in. At minimum, it should specify the dependencies
            contained in :func:`get_default_conda_env()`. If `None`, the default
            :func:`get_default_conda_env()` environment is added to the model.
            The following is an *example* dictionary representation of a Conda
            environment::

                {
                    'name': 'mlflow-env',
                    'channels': ['defaults'],
                    'dependencies': [
                        'python=3.8.15',
                        'johnsnowlabs'
                    ]
                }
        code_paths: {{ code_paths }}
        dfs_tmpdir: Temporary directory path on Distributed (Hadoop) File System (DFS) or local
            filesystem if running in local mode. The model is written in this
            destination and then copied to the requested local path. This is necessary
            as Spark ML models read from and write to DFS if running on a cluster. All
            temporary files created on the DFS are removed if this operation
            completes successfully. Defaults to ``/tmp/mlflow``.
        signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
            describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
            The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
            from datasets with valid model input (e.g. the training dataset with target
            column omitted) and valid model output (e.g. model predictions generated on
            the training dataset), for example:

            .. code-block:: python

                from mlflow.models.signature import infer_signature

                train = df.drop_column("target_label")
                predictions = ...  # compute model predictions
                signature = infer_signature(train, predictions)
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        store_license: If True, the license will be stored with the model and used when
            re-loading it.

    .. code-block:: python
        :caption: Example

        from johnsnowlabs import nlp
        import json
        import mlflow
        import os

        # Write your raw license.json string into the 'JOHNSNOWLABS_LICENSE_JSON' env variable
        creds = {
            "AWS_ACCESS_KEY_ID": "...",
            "AWS_SECRET_ACCESS_KEY": "...",
            "SPARK_NLP_LICENSE": "...",
            "SECRET": "...",
        }
        os.environ["JOHNSNOWLABS_LICENSE_JSON"] = json.dumps(creds)

        # Download & Install Jars/Wheels if missing and Start a spark Session
        nlp.start()

        # load a model
        model = nlp.load("en.classify.bert_sequence.covid_sentiment")
        model.predict(["I hate covid", "I love covid"])

        # Save model as pyfunc and johnsnowlabs format
        mlflow.johnsnowlabs.save_model(model, "saved_model")
        model = mlflow.johnsnowlabs.load_model("saved_model")
        # Predict with reloaded model,
        # supports datatypes defined in https://nlp.johnsnowlabs.com/docs/en/jsl/predict_api#supported-data-types
        model.predict(["I hate covid", "I love covid"])
    """
    _validate_env_vars()
    mlflow_model = Model() if mlflow_model is None else mlflow_model
    if metadata is not None:
        mlflow_model.metadata = metadata

    # Spark ML stores the model on DFS if running on a cluster, so the model is first
    # written to a temporary DFS location and then brought over to the local path.
    staging_root = MLFLOW_DFS_TMP.get() if dfs_tmpdir is None else dfs_tmpdir
    staging_path = generate_tmp_dfs_path(staging_root)
    _unpack_and_save_model(spark_model, staging_path)

    local_data_dir = os.path.abspath(str(Path(path) / _JOHNSNOWLABS_MODEL_PATH_SUB))

    # The staged model lives on DBFS when (a) the staging URI is a DBFS URI
    # ("dbfs:/my-directory"), or (b) we are running on a Databricks cluster and the URI is
    # schemeless (i.e. looks like an absolute filesystem path such as "/my-directory").
    staged_on_dbfs = is_valid_dbfs_uri(staging_path) or (
        databricks_utils.is_in_cluster() and posixpath.abspath(staging_path) == staging_path
    )
    move_via_fuse = staged_on_dbfs and databricks_utils.is_dbfs_fuse_available()
    if not move_via_fuse:
        _HadoopFileSystem.copy_to_local_file(staging_path, local_data_dir, remove_src=True)
    else:
        fuse_staging_path = dbfs_hdfs_uri_to_fuse_path(staging_path)
        shutil.move(src=fuse_staging_path, dst=local_data_dir)

    _save_model_metadata(
        dst_dir=path,
        spark_model=spark_model,
        mlflow_model=mlflow_model,
        conda_env=conda_env,
        code_paths=code_paths,
        signature=signature,
        input_example=input_example,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        store_license=store_license,
    )
632
+
633
+
634
def _load_model_databricks(dfs_tmpdir, local_model_path):
    """Copy a local model onto DBFS through the FUSE mount and load it from there.

    Args:
        dfs_tmpdir: Temporary DFS location the model is copied to and loaded from.
        local_model_path: Local directory containing the model to copy.

    Returns:
        The object returned by ``nlp.load`` for ``dfs_tmpdir``.
    """
    from johnsnowlabs import nlp

    # Spark ML expects the model to be stored on DFS, so it is staged there first.
    # The staged copy is intentionally left in place, as Spark may read from it at any point.
    fuse_target_dir = dbfs_hdfs_uri_to_fuse_path(dfs_tmpdir)
    os.makedirs(fuse_target_dir)
    # shutil.copytree cannot be used with DBFS FUSE: copying permission bits for directories
    # fails with permission-denied errors on passthrough-enabled clusters.
    shutil_copytree_without_file_permissions(
        src_dir=local_model_path,
        dst_dir=fuse_target_dir,
    )
    loaded_pipeline = nlp.load(path=dfs_tmpdir)
    return loaded_pipeline
646
+
647
+
648
def _load_model(model_uri, dfs_tmpdir_base=None, local_model_path=None):
    """Load the model data through ``nlp.load``, staging it on DBFS when on Databricks.

    Args:
        model_uri: URI of the model data; downloaded when no local copy is supplied.
        dfs_tmpdir_base: Base directory for the temporary DFS path. Falls back to the
            ``MLFLOW_DFS_TMP`` setting when falsy.
        local_model_path: Optional local path of an already-downloaded model.

    Returns:
        The object returned by ``nlp.load``.
    """
    from johnsnowlabs import nlp

    tmp_base = dfs_tmpdir_base or MLFLOW_DFS_TMP.get()
    dfs_tmpdir = generate_tmp_dfs_path(tmp_base)

    on_databricks_with_fuse = (
        databricks_utils.is_in_cluster() and databricks_utils.is_dbfs_fuse_available()
    )
    if on_databricks_with_fuse:
        source_dir = local_model_path or _download_artifact_from_uri(model_uri)
        return _load_model_databricks(dfs_tmpdir, source_dir)

    if model_uri and not local_model_path:
        local_model_path = _download_artifact_from_uri(model_uri)
    # Make sure a Spark session exists (started with the model's jars if needed).
    _get_or_create_sparksession(local_model_path)

    # Point at the model sub-directory unless the path already refers to it.
    if _JOHNSNOWLABS_MODEL_PATH_SUB in local_model_path:
        model_dir = local_model_path
    else:
        model_dir = str(Path(local_model_path) / _JOHNSNOWLABS_MODEL_PATH_SUB)

    return nlp.load(path=model_dir)
665
+
666
+
667
def load_model(model_uri, dfs_tmpdir=None, dst_path=None):
    """
    Load the Johnsnowlabs MLflow model from the path.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.
        dfs_tmpdir: Temporary directory path on Distributed (Hadoop) File System (DFS) or local
            filesystem if running in local mode. The model is loaded from this
            destination. Defaults to ``/tmp/mlflow``.
        dst_path: The local filesystem path to which to download the model artifact.
            This directory must already exist. If unspecified, a local output
            path will be created.

    Returns:
        A
        `nlu.NLUPipeline <https://nlp.johnsnowlabs.com/docs/en/jsl/predict_api>`_.

    .. code-block:: python
        :caption: Example

        import json
        import mlflow
        from johnsnowlabs import nlp
        import os

        # Write your raw license.json string into the 'JOHNSNOWLABS_LICENSE_JSON' env variable
        creds = {
            "AWS_ACCESS_KEY_ID": "...",
            "AWS_SECRET_ACCESS_KEY": "...",
            "SPARK_NLP_LICENSE": "...",
            "SECRET": "...",
        }
        os.environ["JOHNSNOWLABS_LICENSE_JSON"] = json.dumps(creds)

        # start a spark session
        nlp.start()
        # Load your MLflow Model
        model = mlflow.johnsnowlabs.load_model("johnsnowlabs_model")

        # Make predictions on test documents
        # supports datatypes defined in https://nlp.johnsnowlabs.com/docs/en/jsl/predict_api#supported-data-types
        prediction = model.transform(["I love Covid", "I hate Covid"])
    """
    _validate_env_vars()
    # The root URI / artifact path split MUST be taken from the original `model_uri`, before
    # the flavor's model-data suffix is appended, so that `artifact_path` has the correct
    # value for model loading via mlflowdbfs.
    root_uri, artifact_path = _get_root_uri_and_artifact_path(model_uri)

    flavor_conf = _get_flavor_configuration_from_uri(model_uri, FLAVOR_NAME, _logger)
    local_model_dir = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
    _add_code_from_conf_to_system_path(local_model_dir, flavor_conf)

    if not _should_use_mlflowdbfs(model_uri):
        model_data_uri = append_to_uri_path(model_uri, flavor_conf["model_data"])
        local_model_data_dir = str(Path(local_model_dir) / flavor_conf["model_data"])
        return _load_model(
            model_uri=model_data_uri,
            dfs_tmpdir_base=dfs_tmpdir,
            local_model_path=local_model_data_dir,
        )

    from pyspark.ml.pipeline import PipelineModel

    source_run_id = DatabricksArtifactRepository._extract_run_id(model_uri)
    mlflowdbfs_path = _mlflowdbfs_path(source_run_id, artifact_path)
    profile_uri = get_databricks_profile_uri_from_artifact_uri(root_uri)
    with databricks_utils.MlflowCredentialContext(profile_uri):
        return PipelineModel.load(mlflowdbfs_path)
749
+
750
+
751
def _load_pyfunc(path, spark=None):
    """Load PyFunc implementation. Called by ``pyfunc.load_model``.

    Args:
        path: Local filesystem path to the MLflow Model with the ``johnsnowlabs`` flavor.
        spark: Optionally pass spark context when using pyfunc as UDF. Required, because
            we cannot fetch the SparkContext inside of the worker node which executes the UDF.

    Returns:
        A ``_PyFuncModelWrapper`` around the loaded model and the Spark session in use.

    """
    # Load the model first; the session is only looked up or created afterwards.
    loaded_pipeline = _load_model(model_uri=path)
    session = spark or _get_or_create_sparksession(path)
    return _PyFuncModelWrapper(loaded_pipeline, session)
767
+
768
+
769
def _get_or_create_sparksession(model_path=None):  # noqa: D417
    """Check if SparkSession running and get it.

    If none exists, create a new one using jars in model_path. If model_path not defined, rely on
    nlp.start() to create a new one using johnsnowlabs Jar resolution method. See
    https://nlp.johnsnowlabs.com/docs/en/jsl/start-a-sparksession and
    https://nlp.johnsnowlabs.com/docs/en/jsl/install_advanced.

    Args:
        model_path: Optional local model path whose bundled jars (and license, if present)
            are used to start the session.

    Returns:
        The active Spark session, or the session returned by ``nlp.start``.

    """
    from johnsnowlabs import nlp

    from mlflow.utils._spark_utils import _get_active_spark_session

    _validate_env_vars()

    active_session = _get_active_spark_session()
    if active_session is not None:
        return active_session

    session_conf = {"spark.python.worker.reuse": "true"}
    # Point both the PySpark workers and the driver at the current interpreter.
    for interpreter_var in ("PYSPARK_PYTHON", "PYSPARK_DRIVER_PYTHON"):
        os.environ[interpreter_var] = sys.executable

    if not model_path:
        return nlp.start()

    jar_paths, license_path = _fetch_deps_from_path(model_path)
    if license_path:
        with open(license_path) as f:
            loaded_license = json.load(f)
        # Export every non-null license entry, then the healthcare license under the
        # ``JSL_NLP_LICENSE`` name.
        non_null_entries = {k: str(v) for k, v in loaded_license.items() if v is not None}
        os.environ.update(non_null_entries)
        os.environ["JSL_NLP_LICENSE"] = loaded_license["HC_LICENSE"]

    _logger.info("Starting a new Session with Jars: %s", jar_paths)
    return nlp.start(
        nlp=False,
        spark_nlp=False,
        jar_paths=jar_paths,
        json_license_path=license_path,
        create_jsl_home_if_missing=False,
        spark_conf=session_conf,
    )
816
+
817
+
818
def _fetch_deps_from_path(local_model_path):
    """Collect the jar files and the optional license file bundled with a saved model.

    The dependencies are looked up in the ``jars.jsl`` directory, which lives below the
    ``_JOHNSNOWLABS_MODEL_PATH_SUB`` sub-directory of the model. If ``local_model_path``
    already contains that sub-directory name, ``jars.jsl`` is resolved directly below it.

    Args:
        local_model_path: Local path of the model, as ``str`` or ``pathlib.Path``.

    Returns:
        A tuple ``(jar_paths, license_path)``: ``jar_paths`` is a list of string paths of
        all ``*.jar`` files in the dependency directory, ``license_path`` is the string
        path of ``license.json`` in that directory, or ``None`` if there is none.

    """
    # str() keeps the substring test working when a pathlib.Path is passed in.
    if _JOHNSNOWLABS_MODEL_PATH_SUB not in str(local_model_path):
        deps_dir = Path(local_model_path) / _JOHNSNOWLABS_MODEL_PATH_SUB / "jars.jsl"
    else:
        deps_dir = Path(local_model_path) / "jars.jsl"

    jar_paths = []
    license_path = None
    # iterdir() already yields paths that include deps_dir; joining them with deps_dir
    # again would duplicate the directory prefix whenever the input path is relative.
    for entry in deps_dir.iterdir():
        if entry.suffix == ".jar":
            jar_paths.append(str(entry))
        elif entry.name == "license.json":
            license_path = str(entry)

    return jar_paths, license_path
835
+
836
+
837
def _unpack_and_save_model(spark_model, dst):
    """Persist ``spark_model`` to ``dst``, unwrapping a pyfunc wrapper first.

    Args:
        spark_model: A ``_PyFuncModelWrapper``, a ``pyspark.ml.PipelineModel`` or any other
            object, which is then treated as an NLU pipeline.
        dst: Destination path handed to the underlying ``save`` call.

    """
    from pyspark.ml import PipelineModel

    is_wrapped = isinstance(spark_model, _PyFuncModelWrapper)
    pipeline = spark_model.spark_model if is_wrapped else spark_model

    if isinstance(pipeline, PipelineModel):
        pipeline.write().overwrite().save(dst)
        return

    # Anything else is handled as an NLU pipeline. One prediction is run before saving;
    # presumably this initializes ``vanilla_transformer_pipe`` -- TODO confirm.
    pipeline.predict("Init")
    try:
        pipeline.vanilla_transformer_pipe.write().overwrite().save(dst)
    except Exception:
        # for mlflowdbfs_path we cannot use overwrite, gives
        # org.apache.hadoop.fs.UnsupportedFileSystemException: No FileSystem for scheme
        # "mlflowdbfs"
        pipeline.save(dst)
854
+
855
+
856
class _PyFuncModelWrapper:
    """
    Wrapper around NLUPipeline providing the pyfunc scoring interface.
    """

    def __init__(
        self,
        spark_model,
        spark=None,
    ):
        """Store the model together with a Spark session.

        Args:
            spark_model: The model to wrap; it must expose ``predict``.
            spark: Spark session to use. When omitted (or falsy), one is obtained from
                ``_get_or_create_sparksession()``, which allows constructing the wrapper
                as ``_PyFuncModelWrapper(nlu_ref)``.
        """
        self.spark = spark if spark else _get_or_create_sparksession()
        self.spark_model = spark_model

    def get_raw_model(self):
        """
        Returns the underlying model.
        """
        return self.spark_model

    def predict(self, text, params: Optional[dict[str, Any]] = None):
        """Generate predictions for the given input.

        Args:
            text: Input data, forwarded unchanged to the wrapped model's ``predict``.
            params: Additional parameters to pass to the model for inference. Only the
                ``"output_level"`` key is read; it defaults to ``""``.

        Returns:
            The model predictions as a JSON string: the result of the wrapped model's
            ``predict`` with its index reset, serialized via ``to_json()``.

        """
        output_level = (params or {}).get("output_level", "")
        predictions = self.spark_model.predict(text, output_level=output_level)
        return predictions.reset_index().to_json()