genesis-flow 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/METADATA +32 -28
  2. {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/RECORD +29 -26
  3. mlflow/azure/config.py +13 -7
  4. mlflow/azure/connection_factory.py +15 -2
  5. mlflow/data/dataset_source_registry.py +8 -0
  6. mlflow/gateway/providers/bedrock.py +298 -0
  7. mlflow/genai/datasets/databricks_evaluation_dataset_source.py +77 -0
  8. mlflow/genai/datasets/evaluation_dataset.py +8 -5
  9. mlflow/genai/scorers/base.py +22 -14
  10. mlflow/langchain/utils/chat.py +10 -0
  11. mlflow/models/container/__init__.py +2 -2
  12. mlflow/spark/__init__.py +1286 -0
  13. mlflow/store/artifact/azure_blob_artifact_repo.py +1 -1
  14. mlflow/store/artifact/azure_data_lake_artifact_repo.py +1 -1
  15. mlflow/store/artifact/gcs_artifact_repo.py +1 -1
  16. mlflow/store/artifact/local_artifact_repo.py +2 -1
  17. mlflow/store/artifact/s3_artifact_repo.py +173 -3
  18. mlflow/tracing/client.py +139 -49
  19. mlflow/tracing/export/mlflow_v3.py +8 -11
  20. mlflow/tracing/provider.py +5 -1
  21. mlflow/tracking/_model_registry/client.py +5 -1
  22. mlflow/tracking/_tracking_service/utils.py +17 -5
  23. mlflow/utils/file_utils.py +2 -1
  24. mlflow/utils/rest_utils.py +4 -0
  25. mlflow/version.py +2 -2
  26. {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/WHEEL +0 -0
  27. {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/entry_points.txt +0 -0
  28. {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/licenses/LICENSE.txt +0 -0
  29. {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/top_level.txt +0 -0
@@ -41,7 +41,7 @@ class AzureBlobArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  - DefaultAzureCredential is configured
  """

- def __init__(self, artifact_uri: str, tracking_uri: Optional[str] = None, client=None) -> None:
+ def __init__(self, artifact_uri: str, client=None, tracking_uri: Optional[str] = None) -> None:
  super().__init__(artifact_uri, tracking_uri)

  _DEFAULT_TIMEOUT = 600 # 10 minutes
@@ -82,9 +82,9 @@ class AzureDataLakeArtifactRepository(CloudArtifactRepository):
  def __init__(
  self,
  artifact_uri: str,
- tracking_uri: Optional[str] = None,
  credential=None,
  credential_refresh_def=None,
+ tracking_uri: Optional[str] = None,
  ) -> None:
  super().__init__(artifact_uri, tracking_uri)
  _DEFAULT_TIMEOUT = 600 # 10 minutes
@@ -43,9 +43,9 @@ class GCSArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  def __init__(
  self,
  artifact_uri: str,
- tracking_uri: Optional[str] = None,
  client=None,
  credential_refresh_def=None,
+ tracking_uri: Optional[str] = None,
  ) -> None:
  super().__init__(artifact_uri, tracking_uri)
  from google.auth.exceptions import DefaultCredentialsError
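
Note: all three constructor hunks above move tracking_uri from the second position to the end of the parameter list, so a caller that passed tracking_uri positionally would now bind it to client or credential instead. Passing it by keyword is safe under either ordering. A minimal illustrative sketch (the URIs are placeholders, not taken from the diff):

    from mlflow.store.artifact.gcs_artifact_repo import GCSArtifactRepository

    # Keyword arguments are unaffected by the 1.0.1 -> 1.0.4 reordering.
    repo = GCSArtifactRepository(
        "gs://example-bucket/mlruns/artifacts",
        tracking_uri="http://localhost:5000",
    )
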
@@ -14,6 +14,7 @@ from mlflow.utils.file_utils import (
  local_file_uri_to_path,
  mkdir,
  relative_path_to_artifact_path,
+ shutil_copytree_without_file_permissions,
  )
  from mlflow.utils.uri import validate_path_is_safe

@@ -64,7 +65,7 @@ class LocalArtifactRepository(ArtifactRepository):
  )
  if not os.path.exists(artifact_dir):
  mkdir(artifact_dir)
- shutil.copytree(src=local_dir, dst=artifact_dir, dirs_exist_ok=True)
+ shutil_copytree_without_file_permissions(local_dir, artifact_dir)

  def download_artifacts(self, artifact_path, dst_path=None):
  """
@@ -122,16 +122,54 @@ def _get_s3_client(
 
 
  class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
- """Stores artifacts on Amazon S3."""
+ """
+ Stores artifacts on Amazon S3.
+
+ This repository provides MLflow artifact storage using Amazon S3 as the backend.
+ It supports both single-file uploads and multipart uploads for large files,
+ with automatic content type detection and configurable upload parameters.
+
+ The repository uses boto3 for S3 operations and supports various authentication
+ methods including AWS credentials, IAM roles, and environment variables.
+
+ Environment Variables:
+ AWS_ACCESS_KEY_ID: AWS access key ID for authentication
+ AWS_SECRET_ACCESS_KEY: AWS secret access key for authentication
+ AWS_SESSION_TOKEN: AWS session token for temporary credentials
+ AWS_DEFAULT_REGION: Default AWS region for S3 operations
+ MLFLOW_S3_ENDPOINT_URL: Custom S3 endpoint URL (for S3-compatible storage)
+ MLFLOW_S3_IGNORE_TLS: Set to 'true' to disable TLS verification
+ MLFLOW_S3_UPLOAD_EXTRA_ARGS: JSON string of extra arguments for S3 uploads
+ MLFLOW_BOTO_CLIENT_ADDRESSING_STYLE: S3 addressing style ('path' or 'virtual')
+
+ Note:
+ This class inherits from both ArtifactRepository and MultipartUploadMixin,
+ providing full artifact management capabilities including efficient large file uploads.
+ """

  def __init__(
  self,
  artifact_uri: str,
- tracking_uri: Optional[str] = None,
  access_key_id=None,
  secret_access_key=None,
  session_token=None,
+ tracking_uri: Optional[str] = None,
  ) -> None:
+ """
+ Initialize an S3 artifact repository.
+
+ Args:
+ artifact_uri: S3 URI in the format 's3://bucket-name/path/to/artifacts'.
+ The URI must be a valid S3 URI with a bucket that exists and is accessible.
+ access_key_id: Optional AWS access key ID. If None, uses default AWS credential
+ resolution (environment variables, IAM roles, etc.).
+ secret_access_key: Optional AWS secret access key. Must be provided if
+ access_key_id is provided.
+ session_token: Optional AWS session token for temporary credentials.
+ Used with STS tokens or IAM roles.
+ tracking_uri: Optional URI for the MLflow tracking server.
+ If None, uses the current tracking URI context.
+ """
  super().__init__(artifact_uri, tracking_uri)
  self._access_key_id = access_key_id
  self._secret_access_key = secret_access_key
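
The expanded class docstring above documents the environment variables the repository reads. A minimal configuration sketch for an S3-compatible endpoint (endpoint, credentials, and file name are made-up values):

    import os

    os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://localhost:9000"  # e.g. MinIO
    os.environ["AWS_ACCESS_KEY_ID"] = "example-key"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "example-secret"

    import mlflow

    # Assumes the tracking server's artifact root points at an s3:// bucket.
    mlflow.set_tracking_uri("http://localhost:5000")
    with mlflow.start_run():
        mlflow.log_artifact("model.pkl")
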
@@ -145,7 +183,17 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  )

  def parse_s3_compliant_uri(self, uri):
- """Parse an S3 URI, returning (bucket, path)"""
+ """
+ Parse an S3 URI into bucket and path components.
+
+ Args:
+ uri: S3 URI in the format 's3://bucket-name/path/to/object'
+
+ Returns:
+ A tuple containing (bucket_name, object_path) where:
+ - bucket_name: The S3 bucket name
+ - object_path: The path within the bucket (without leading slash)
+ """
  parsed = urllib.parse.urlparse(uri)
  if parsed.scheme != "s3":
  raise Exception(f"Not an S3 URI: {uri}")
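
The parse follows standard URI splitting, so the documented (bucket, path) behavior can be reproduced directly with urllib (illustrative values):

    from urllib.parse import urlparse

    parsed = urlparse("s3://example-bucket/path/to/object")
    bucket, key = parsed.netloc, parsed.path.lstrip("/")
    # bucket == "example-bucket", key == "path/to/object"
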
@@ -156,6 +204,17 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):

  @staticmethod
  def get_s3_file_upload_extra_args():
+ """
+ Get additional S3 upload arguments from environment variables.
+
+ Returns:
+ Dictionary of extra arguments for S3 uploads, or None if not configured.
+ These arguments are passed to boto3's upload_file method.
+
+ Environment Variables:
+ MLFLOW_S3_UPLOAD_EXTRA_ARGS: JSON string containing extra arguments
+ for S3 uploads (e.g., '{"ServerSideEncryption": "AES256"}')
+ """
  s3_file_upload_extra_args = MLFLOW_S3_UPLOAD_EXTRA_ARGS.get()
  if s3_file_upload_extra_args:
  return json.loads(s3_file_upload_extra_args)
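
As the new docstring notes, the extra arguments come from a JSON-encoded environment variable and are forwarded to boto3's upload_file. Setting it might look like this (the encryption setting is only an example):

    import json
    import os

    os.environ["MLFLOW_S3_UPLOAD_EXTRA_ARGS"] = json.dumps(
        {"ServerSideEncryption": "AES256"}
    )
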
@@ -175,6 +234,19 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  s3_client.upload_file(Filename=local_file, Bucket=bucket, Key=key, ExtraArgs=extra_args)

  def log_artifact(self, local_file, artifact_path=None):
+ """
+ Log a local file as an artifact to S3.
+
+ This method uploads a single file to S3 with automatic content type detection
+ and optional extra upload arguments from environment variables.
+
+ Args:
+ local_file: Absolute path to the local file to upload. The file must
+ exist and be readable.
+ artifact_path: Optional relative path within the S3 bucket where the
+ artifact should be stored. If None, the file is stored in the root
+ of the configured S3 path. Use forward slashes (/) for path separators.
+ """
  (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
  if artifact_path:
  dest_path = posixpath.join(dest_path, artifact_path)
@@ -184,6 +256,20 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  )

  def log_artifacts(self, local_dir, artifact_path=None):
+ """
+ Log all files in a local directory as artifacts to S3.
+
+ This method recursively uploads all files in the specified directory,
+ preserving the directory structure in S3. Each file is uploaded with
+ automatic content type detection.
+
+ Args:
+ local_dir: Absolute path to the local directory containing files to upload.
+ The directory must exist and be readable.
+ artifact_path: Optional relative path within the S3 bucket where the
+ artifacts should be stored. If None, files are stored in the root
+ of the configured S3 path. Use forward slashes (/) for path separators.
+ """
  (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
  if artifact_path:
  dest_path = posixpath.join(dest_path, artifact_path)
@@ -205,6 +291,25 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  )

  def list_artifacts(self, path=None):
+ """
+ List all artifacts directly under the specified S3 path.
+
+ This method uses S3's list_objects_v2 API with pagination to efficiently
+ list artifacts. It treats S3 prefixes as directories and returns both
+ files and directories as FileInfo objects.
+
+ Args:
+ path: Optional relative path within the S3 bucket to list. If None,
+ lists artifacts in the root of the configured S3 path. If the path
+ refers to a single file, returns an empty list per MLflow convention.
+
+ Returns:
+ A list of FileInfo objects representing artifacts directly under the
+ specified path. Each FileInfo contains:
+ - path: Relative path of the artifact from the repository root
+ - is_dir: True if the artifact represents a directory (S3 prefix)
+ - file_size: Size in bytes for files, None for directories
+ """
  (bucket, artifact_path) = self.parse_s3_compliant_uri(self.artifact_uri)
  dest_path = artifact_path
  if path:
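
The list_artifacts docstring above describes paging through list_objects_v2 and treating prefixes as directories. A hedged sketch of that pattern using boto3 directly (bucket and prefix are placeholders, not the method's actual code):

    import boto3

    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    # Delimiter="/" groups deeper keys under CommonPrefixes, which map to directories.
    for page in paginator.paginate(Bucket="example-bucket", Prefix="mlruns/artifacts/", Delimiter="/"):
        for prefix in page.get("CommonPrefixes", []):
            print("dir: ", prefix["Prefix"])
        for obj in page.get("Contents", []):
            print("file:", obj["Key"], obj["Size"])
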
@@ -247,6 +352,18 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  )

  def _download_file(self, remote_file_path, local_path):
+ """
+ Download a file from S3 to the local filesystem.
+
+ This method downloads a single file from S3 to the specified local path.
+ It's used internally by the download_artifacts method.
+
+ Args:
+ remote_file_path: Relative path of the file within the S3 bucket,
+ relative to the repository's root path.
+ local_path: Absolute path where the file should be saved locally.
+ The parent directory must exist.
+ """
  (bucket, s3_root_path) = self.parse_s3_compliant_uri(self.artifact_uri)
  s3_full_path = posixpath.join(s3_root_path, remote_file_path)
  s3_client = self._get_s3_client()
@@ -273,6 +390,28 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  s3_client.delete_objects(Bucket=bucket, Delete={"Objects": keys})

  def create_multipart_upload(self, local_file, num_parts=1, artifact_path=None):
+ """
+ Initiate a multipart upload for efficient large file uploads to S3.
+
+ This method creates a multipart upload session in S3 and generates
+ presigned URLs for uploading each part. This is more efficient than
+ single-part uploads for large files and provides better error recovery.
+
+ Args:
+ local_file: Absolute path to the local file to upload. The file must
+ exist and be readable.
+ num_parts: Number of parts to split the upload into. Must be between
+ 1 and 10,000 (S3 limit). More parts allow greater parallelism
+ but increase overhead.
+ artifact_path: Optional relative path within the S3 bucket where the
+ artifact should be stored. If None, the file is stored in the root
+ of the configured S3 path.
+
+ Returns:
+ CreateMultipartUploadResponse containing:
+ - credentials: List of MultipartUploadCredential objects with presigned URLs
+ - upload_id: S3 upload ID for tracking this multipart upload
+ """
  (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
  if artifact_path:
  dest_path = posixpath.join(dest_path, artifact_path)
@@ -307,6 +446,23 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  )

  def complete_multipart_upload(self, local_file, upload_id, parts=None, artifact_path=None):
+ """
+ Complete a multipart upload by combining all parts into a single S3 object.
+
+ This method should be called after all parts have been successfully uploaded
+ using the presigned URLs from create_multipart_upload. It tells S3 to combine
+ all the parts into the final object.
+
+ Args:
+ local_file: Absolute path to the local file that was uploaded. Must match
+ the local_file used in create_multipart_upload.
+ upload_id: The S3 upload ID returned by create_multipart_upload.
+ parts: List of MultipartUploadPart objects containing metadata for each
+ successfully uploaded part. Must include part_number and etag for each part.
+ Parts must be provided in order (part 1, part 2, etc.).
+ artifact_path: Optional relative path where the artifact should be stored.
+ Must match the artifact_path used in create_multipart_upload.
+ """
  (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
  if artifact_path:
  dest_path = posixpath.join(dest_path, artifact_path)
@@ -318,6 +474,20 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
  )

  def abort_multipart_upload(self, local_file, upload_id, artifact_path=None):
+ """
+ Abort a multipart upload and clean up any uploaded parts.
+
+ This method should be called if a multipart upload fails or is cancelled.
+ It cleans up any parts that were successfully uploaded and cancels the
+ multipart upload session in S3.
+
+ Args:
+ local_file: Absolute path to the local file that was being uploaded.
+ Must match the local_file used in create_multipart_upload.
+ upload_id: The S3 upload ID returned by create_multipart_upload.
+ artifact_path: Optional relative path where the artifact would have been stored.
+ Must match the artifact_path used in create_multipart_upload.
+ """
  (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
  if artifact_path:
  dest_path = posixpath.join(dest_path, artifact_path)
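
Taken together, create_multipart_upload, complete_multipart_upload, and abort_multipart_upload wrap S3's native multipart protocol: create a session, upload numbered parts, then either combine them or clean up. A sketch of that underlying flow using boto3 directly (bucket, key, and chunk size are placeholders; MLflow's own wrapper hands out presigned URLs instead of calling upload_part itself):

    import boto3

    s3 = boto3.client("s3")
    bucket, key = "example-bucket", "artifacts/big.bin"
    mpu = s3.create_multipart_upload(Bucket=bucket, Key=key)
    try:
        parts = []
        with open("big.bin", "rb") as f:
            for number, chunk in enumerate(iter(lambda: f.read(8 * 1024 * 1024), b""), start=1):
                resp = s3.upload_part(
                    Bucket=bucket, Key=key, UploadId=mpu["UploadId"],
                    PartNumber=number, Body=chunk,
                )
                parts.append({"PartNumber": number, "ETag": resp["ETag"]})
        s3.complete_multipart_upload(
            Bucket=bucket, Key=key, UploadId=mpu["UploadId"],
            MultipartUpload={"Parts": parts},
        )
    except Exception:
        # Mirrors abort_multipart_upload: discard any uploaded parts on failure.
        s3.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=mpu["UploadId"])
        raise
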
mlflow/tracing/client.py CHANGED
@@ -2,15 +2,19 @@ import json
  import logging
  from concurrent.futures import ThreadPoolExecutor
  from contextlib import nullcontext
- from typing import Optional, Sequence
+ from typing import Optional, Sequence, Union

  import mlflow
- from mlflow.entities.assessment import Assessment
+ from mlflow.entities.assessment import (
+ Assessment,
+ )
  from mlflow.entities.model_registry import PromptVersion
  from mlflow.entities.span import NO_OP_SPAN_TRACE_ID
  from mlflow.entities.trace import Trace
  from mlflow.entities.trace_data import TraceData
  from mlflow.entities.trace_info import TraceInfo
+ from mlflow.entities.trace_info_v2 import TraceInfoV2
+ from mlflow.entities.trace_status import TraceStatus
  from mlflow.environment_variables import MLFLOW_SEARCH_TRACES_MAX_THREADS
  from mlflow.exceptions import (
  MlflowException,
@@ -59,65 +63,128 @@ class TracingClient:
  def store(self):
  return _get_store(self.tracking_uri)

- def start_trace(self, trace_info: TraceInfo) -> TraceInfo:
+ def start_trace(
+ self,
+ experiment_id: str,
+ timestamp_ms: int,
+ request_metadata: dict[str, str],
+ tags: dict[str, str],
+ ):
  """
- Create a new trace in the backend.
+ Start an initial TraceInfo object in the backend store.

  Args:
- trace_info: The TraceInfo object to record in the backend.
+ experiment_id: String id of the experiment for this run.
+ timestamp_ms: Start time of the trace, in milliseconds since the UNIX epoch.
+ request_metadata: Metadata of the trace.
+ tags: Tags of the trace.
+
+ Returns:
+ The created TraceInfo object.
+ """
+ tags = exclude_immutable_tags(tags or {})
+ return self.store.start_trace(
+ experiment_id=experiment_id,
+ timestamp_ms=timestamp_ms,
+ request_metadata=request_metadata,
+ tags=tags,
+ )
+
+ def start_trace_v3(self, trace: Trace) -> TraceInfo:
+ """
+ Start a trace using the V3 API format.
+ NB: This method is named "Start" for internal reason in the backend, but actually
+ should be called at the end of the trace. We will migrate this to "CreateTrace"
+ API in the future to avoid confusion.
+
+ Args:
+ trace: The Trace object to create.

  Returns:
  The returned TraceInfoV3 object from the backend.
  """
- return self.store.start_trace(trace_info=trace_info)
+ return self.store.start_trace_v3(trace=trace)
+
+ def end_trace(
+ self,
+ request_id: str,
+ timestamp_ms: int,
+ status: TraceStatus,
+ request_metadata: dict[str, str],
+ tags: dict[str, str],
+ ) -> TraceInfoV2:
+ """
+ Update the TraceInfo object in the backend store with the completed trace info.
+
+ Args:
+ request_id: Unique string identifier of the trace.
+ timestamp_ms: End time of the trace, in milliseconds. The execution time field
+ in the TraceInfo will be calculated by subtracting the start time from this.
+ status: Status of the trace.
+ request_metadata: Metadata of the trace. This will be merged with the existing
+ metadata logged during the start_trace call.
+ tags: Tags of the trace. This will be merged with the existing tags logged
+ during the start_trace or set_trace_tag calls.
+
+ Returns:
+ The updated TraceInfo object.
+ """
+ tags = exclude_immutable_tags(tags or {})
+ return self.store.end_trace(
+ request_id=request_id,
+ timestamp_ms=timestamp_ms,
+ status=status,
+ request_metadata=request_metadata,
+ tags=tags,
+ )

  def delete_traces(
  self,
  experiment_id: str,
  max_timestamp_millis: Optional[int] = None,
  max_traces: Optional[int] = None,
- trace_ids: Optional[list[str]] = None,
+ request_ids: Optional[list[str]] = None,
  ) -> int:
  return self.store.delete_traces(
  experiment_id=experiment_id,
  max_timestamp_millis=max_timestamp_millis,
  max_traces=max_traces,
- trace_ids=trace_ids,
+ request_ids=request_ids,
  )

- def get_trace_info(self, trace_id: str) -> TraceInfo:
+ def get_trace_info(self, request_id, should_query_v3: bool = False) -> TraceInfoV2:
  """
- Get the trace info matching the ``trace_id``.
+ Get the trace info matching the ``request_id``.

  Args:
- trace_id: String id of the trace to fetch.
+ request_id: String id of the trace to fetch.
+ should_query_v3: If True, the backend store will query the V3 API for the trace info.
+ TODO: Remove this flag once the V3 API is the default in OSS.

  Returns:
  TraceInfo object, of type ``mlflow.entities.trace_info.TraceInfo``.
  """
- with InMemoryTraceManager.get_instance().get_trace(trace_id) as trace:
- if trace is not None:
- return trace.info
+ return self.store.get_trace_info(request_id, should_query_v3=should_query_v3)

- return self.store.get_trace_info(trace_id)
-
- def get_trace(self, trace_id: str) -> Trace:
+ def get_trace(self, request_id) -> Trace:
  """
- Get the trace matching the ``trace_id``.
+ Get the trace matching the ``request_id``.

  Args:
- trace_id: String id of the trace to fetch.
+ request_id: String id of the trace to fetch.

  Returns:
  The fetched Trace object, of type ``mlflow.entities.Trace``.
  """
- trace_info = self.get_trace_info(trace_id)
+ trace_info = self.get_trace_info(
+ request_id=request_id, should_query_v3=is_databricks_uri(self.tracking_uri)
+ )
  try:
  trace_data = self._download_trace_data(trace_info)
  except MlflowTraceDataNotFound:
  raise MlflowException(
  message=(
- f"Trace with ID {trace_id} cannot be loaded because it is missing span data."
+ f"Trace with ID {request_id} cannot be loaded because it is missing span data."
  " Please try creating or loading another trace."
  ),
  error_code=BAD_REQUEST,
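
The rewritten client above restores a two-call lifecycle (start_trace when the trace begins, end_trace when it completes) alongside the V3-style start_trace_v3. A hedged usage sketch (experiment id, tags, and client construction are illustrative; TracingClient is normally driven by MLflow's tracing internals rather than called directly):

    import time

    from mlflow.entities.trace_status import TraceStatus
    from mlflow.tracing.client import TracingClient

    client = TracingClient()  # assumed to default to the current tracking URI

    info = client.start_trace(
        experiment_id="0",
        timestamp_ms=int(time.time() * 1000),
        request_metadata={},
        tags={"example": "true"},
    )
    # ... run the traced work ...
    client.end_trace(
        request_id=info.request_id,
        timestamp_ms=int(time.time() * 1000),
        status=TraceStatus.OK,
        request_metadata={},
        tags={},
    )
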
@@ -125,7 +192,7 @@ class TracingClient:
  except MlflowTraceDataCorrupted:
  raise MlflowException(
  message=(
- f"Trace with ID {trace_id} cannot be loaded because its span data"
+ f"Trace with ID {request_id} cannot be loaded because its span data"
  " is corrupted. Please try creating or loading another trace."
  ),
  error_code=BAD_REQUEST,
@@ -253,23 +320,29 @@ class TracingClient:
  else:
  filter_string = additional_filter

- is_databricks = is_databricks_uri(self.tracking_uri)
-
- def download_trace_extra_fields(trace_info: TraceInfo) -> Optional[Trace]:
+ def download_trace_extra_fields(
+ trace_info: Union[TraceInfoV2, TraceInfo],
+ ) -> Optional[Trace]:
  """
  Download trace data and assessments for the given trace_info and returns a Trace object.
  If the download fails (e.g., the trace data is missing or corrupted), returns None.

  The trace_info parameter can be either TraceInfo or TraceInfoV3 object.
  """
- is_online_trace = is_uuid(trace_info.trace_id)
+ from mlflow.entities.trace_info import TraceInfo
+
+ # Determine if this is TraceInfo or TraceInfoV3
+ # Helps while transitioning to V3 traces for offline & online
+ is_v3 = isinstance(trace_info, TraceInfo)
+ trace_id = trace_info.trace_id if is_v3 else trace_info.request_id
+ is_online_trace = is_uuid(trace_id)

  # For online traces in Databricks, we need to get trace data from a different endpoint
  try:
  if is_databricks and is_online_trace:
  # For online traces, get data from the online API
  trace_data = self.get_online_trace_details(
- trace_id=trace_info.trace_id,
+ trace_id=trace_id,
  sql_warehouse_id=sql_warehouse_id,
  source_inference_table=trace_info.request_metadata.get(
  "mlflow.sourceTable"
@@ -285,7 +358,7 @@ class TracingClient:
  except MlflowTraceDataException as e:
  _logger.warning(
  (
- f"Failed to download trace data for trace {trace_info.trace_id!r} "
+ f"Failed to download trace data for trace {trace_id!r} "
  f"with {e.ctx}. For full traceback, set logging level to DEBUG."
  ),
  exc_info=_logger.isEnabledFor(logging.DEBUG),
@@ -299,11 +372,7 @@ class TracingClient:
  next_token = page_token

  max_workers = MLFLOW_SEARCH_TRACES_MAX_THREADS.get()
- executor = (
- ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="MlflowTracingSearch")
- if include_spans
- else nullcontext()
- )
+ executor = ThreadPoolExecutor(max_workers=max_workers) if include_spans else nullcontext()
  with executor:
  while len(traces) < max_results:
  trace_infos, next_token = self._search_traces(
@@ -332,24 +401,24 @@ class TracingClient:

  return PagedList(traces, next_token)

- def set_trace_tags(self, trace_id: str, tags: dict[str, str]):
+ def set_trace_tags(self, request_id, tags):
  """
- Set tags on the trace with the given trace_id.
+ Set tags on the trace with the given request_id.

  Args:
- trace_id: The ID of the trace.
+ request_id: The ID of the trace.
  tags: A dictionary of key-value pairs.
  """
  tags = exclude_immutable_tags(tags)
  for k, v in tags.items():
- self.set_trace_tag(trace_id, k, v)
+ self.set_trace_tag(request_id, k, v)

- def set_trace_tag(self, trace_id: str, key: str, value: str):
+ def set_trace_tag(self, request_id, key, value):
  """
  Set a tag on the trace with the given trace ID.

  Args:
- trace_id: The ID of the trace to set the tag on.
+ request_id: The ID of the trace to set the tag on.
  key: The string key of the tag. Must be at most 250 characters long, otherwise
  it will be truncated when stored.
  value: The string value of the tag. Must be at most 250 characters long, otherwise
@@ -362,7 +431,7 @@ class TracingClient:
  )

  # Trying to set the tag on the active trace first
- with InMemoryTraceManager.get_instance().get_trace(trace_id) as trace:
+ with InMemoryTraceManager.get_instance().get_trace(request_id) as trace:
  if trace:
  trace.info.tags[key] = str(value)
  return
@@ -370,33 +439,33 @@ class TracingClient:
  if key in IMMUTABLE_TAGS:
  _logger.warning(f"Tag '{key}' is immutable and cannot be set on a trace.")
  else:
- self.store.set_trace_tag(trace_id, key, str(value))
+ self.store.set_trace_tag(request_id, key, str(value))

- def delete_trace_tag(self, trace_id: str, key: str):
+ def delete_trace_tag(self, request_id, key):
  """
  Delete a tag on the trace with the given trace ID.

  Args:
- trace_id: The ID of the trace to delete the tag from.
+ request_id: The ID of the trace to delete the tag from.
  key: The string key of the tag. Must be at most 250 characters long, otherwise
  it will be truncated when stored.
  """
  # Trying to delete the tag on the active trace first
- with InMemoryTraceManager.get_instance().get_trace(trace_id) as trace:
+ with InMemoryTraceManager.get_instance().get_trace(request_id) as trace:
  if trace:
  if key in trace.info.tags:
  trace.info.tags.pop(key)
  return
  else:
  raise MlflowException(
- f"Tag with key {key} not found in trace with ID {trace_id}.",
+ f"Tag with key {key} not found in trace with ID {request_id}.",
  error_code=RESOURCE_DOES_NOT_EXIST,
  )

  if key in IMMUTABLE_TAGS:
  _logger.warning(f"Tag '{key}' is immutable and cannot be deleted on a trace.")
  else:
- self.store.delete_trace_tag(trace_id, key)
+ self.store.delete_trace_tag(request_id, key)

  def get_assessment(self, trace_id: str, assessment_id: str) -> Assessment:
  """
@@ -492,12 +561,12 @@ class TracingClient:

  self.store.delete_assessment(trace_id=trace_id, assessment_id=assessment_id)

- def _get_artifact_repo_for_trace(self, trace_info: TraceInfo):
+ def _get_artifact_repo_for_trace(self, trace_info: TraceInfoV2):
  artifact_uri = get_artifact_uri_for_trace(trace_info)
  artifact_uri = add_databricks_profile_info_to_artifact_uri(artifact_uri, self.tracking_uri)
  return get_artifact_repository(artifact_uri)

- def _download_trace_data(self, trace_info: TraceInfo) -> TraceData:
+ def _download_trace_data(self, trace_info: Union[TraceInfoV2, TraceInfo]) -> TraceData:
  """
  Download trace data from artifact repository.

@@ -510,11 +579,32 @@ class TracingClient:
  artifact_repo = self._get_artifact_repo_for_trace(trace_info)
  return TraceData.from_dict(artifact_repo.download_trace_data())

- def _upload_trace_data(self, trace_info: TraceInfo, trace_data: TraceData) -> None:
+ def _upload_trace_data(self, trace_info: TraceInfoV2, trace_data: TraceData) -> None:
  artifact_repo = self._get_artifact_repo_for_trace(trace_info)
  trace_data_json = json.dumps(trace_data.to_dict(), cls=TraceJSONEncoder, ensure_ascii=False)
  return artifact_repo.upload_trace_data(trace_data_json)

+ def _upload_ended_trace_info(
+ self,
+ trace_info: TraceInfoV2,
+ ) -> TraceInfoV2:
+ """
+ Update the TraceInfo object in the backend store with the completed trace info.
+
+ Args:
+ trace_info: Updated TraceInfo object to be stored in the backend store.
+
+ Returns:
+ The updated TraceInfo object.
+ """
+ return self.end_trace(
+ request_id=trace_info.request_id,
+ timestamp_ms=trace_info.timestamp_ms + trace_info.execution_time_ms,
+ status=trace_info.status,
+ request_metadata=trace_info.request_metadata,
+ tags=trace_info.tags or {},
+ )
+
  def link_prompt_versions_to_trace(
  self, trace_id: str, prompts: Sequence[PromptVersion]
  ) -> None: