atlan-application-sdk 2.3.1__py3-none-any.whl → 2.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,7 @@ from application_sdk.activities.common.utils import (
     get_workflow_id,
 )
 from application_sdk.clients.sql import BaseSQLClient
+from application_sdk.common.file_ops import SafeFileOps
 from application_sdk.constants import UPSTREAM_OBJECT_STORE_NAME
 from application_sdk.handlers import HandlerInterface
 from application_sdk.handlers.sql import BaseSQLHandler
@@ -412,7 +413,7 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
 
         # find the last marker from the parallel_markers
         last_marker = parallel_markers[-1]["end"]
-        with open(marker_file_path, "w") as f:
+        with SafeFileOps.open(marker_file_path, "w") as f:
            f.write(last_marker)
 
        logger.info(f"Last marker: {last_marker}")
@@ -453,10 +454,10 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
        )
 
        logger.info(f"Marker file downloaded to {marker_file_path}")
-        if not os.path.exists(marker_file_path):
+        if not SafeFileOps.exists(marker_file_path):
            logger.warning(f"Marker file does not exist at {marker_file_path}")
            return None
-        with open(marker_file_path, "r") as f:
+        with SafeFileOps.open(marker_file_path, "r") as f:
            current_marker = f.read()
        logger.info(f"Current marker: {current_marker}")
        return int(current_marker)
@@ -519,8 +520,8 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
        # Write the results to a metadata file
        output_path = os.path.join(workflow_args["output_path"], "raw", "query")
        metadata_file_path = os.path.join(output_path, "metadata.json.ignore")
-        os.makedirs(os.path.dirname(metadata_file_path), exist_ok=True)
-        with open(metadata_file_path, "w") as f:
+        SafeFileOps.makedirs(os.path.dirname(metadata_file_path), exist_ok=True)
+        with SafeFileOps.open(metadata_file_path, "w") as f:
            f.write(json.dumps(parallel_markers))
 
        await ObjectStore.upload_file(
@@ -0,0 +1,122 @@
+import os
+import shutil
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, List, Optional, Union
+
+from application_sdk.common.path import convert_to_extended_path
+
+
+class SafeFileOps:
+    """Safe file operations with Windows extended-length path support."""
+
+    @staticmethod
+    def rename(src: Union[str, Path], dst: Union[str, Path]) -> None:
+        """Safely rename a file or directory, supporting long paths on Windows."""
+        os.rename(convert_to_extended_path(src), convert_to_extended_path(dst))
+
+    @staticmethod
+    def remove(path: Union[str, Path]) -> None:
+        """Safely remove a file, supporting long paths on Windows."""
+        os.remove(convert_to_extended_path(path))
+
+    @staticmethod
+    def unlink(path: Union[str, Path], missing_ok: bool = False) -> None:
+        """Safely unlink a file, supporting long paths on Windows."""
+        try:
+            os.unlink(convert_to_extended_path(path))
+        except FileNotFoundError:
+            if not missing_ok:
+                raise
+
+    @staticmethod
+    def makedirs(
+        name: Union[str, Path], mode: int = 0o777, exist_ok: bool = False
+    ) -> None:
+        """Safely create directories, supporting long paths on Windows."""
+        os.makedirs(convert_to_extended_path(name), mode=mode, exist_ok=exist_ok)
+
+    @staticmethod
+    def mkdir(path: Union[str, Path], mode: int = 0o777) -> None:
+        """Safely create a directory, supporting long paths on Windows."""
+        os.mkdir(convert_to_extended_path(path), mode=mode)
+
+    @staticmethod
+    def rmdir(path: Union[str, Path]) -> None:
+        """Safely remove a directory, supporting long paths on Windows."""
+        os.rmdir(convert_to_extended_path(path))
+
+    @staticmethod
+    def exists(path: Union[str, Path]) -> bool:
+        """Safely check if a path exists, supporting long paths on Windows."""
+        return os.path.exists(convert_to_extended_path(path))
+
+    @staticmethod
+    def isfile(path: Union[str, Path]) -> bool:
+        """Safely check if a path is a file, supporting long paths on Windows."""
+        return os.path.isfile(convert_to_extended_path(path))
+
+    @staticmethod
+    def isdir(path: Union[str, Path]) -> bool:
+        """Safely check if a path is a directory, supporting long paths on Windows."""
+        return os.path.isdir(convert_to_extended_path(path))
+
+    @staticmethod
+    def rmtree(
+        path: Union[str, Path],
+        ignore_errors: bool = False,
+        onerror: Optional[Any] = None,
+    ) -> None:
+        """Safely remove a directory tree, supporting long paths on Windows."""
+        shutil.rmtree(
+            convert_to_extended_path(path), ignore_errors=ignore_errors, onerror=onerror
+        )
+
+    @staticmethod
+    def copy(
+        src: Union[str, Path], dst: Union[str, Path], follow_symlinks: bool = True
+    ) -> Union[str, Path]:
+        """Safely copy a file, supporting long paths on Windows."""
+        return shutil.copy(
+            convert_to_extended_path(src),
+            convert_to_extended_path(dst),
+            follow_symlinks=follow_symlinks,
+        )
+
+    @staticmethod
+    def move(src: Union[str, Path], dst: Union[str, Path]) -> Union[str, Path]:
+        """Safely move a file or directory, supporting long paths on Windows."""
+        return shutil.move(convert_to_extended_path(src), convert_to_extended_path(dst))
+
+    @staticmethod
+    @contextmanager
+    def open(
+        file: Union[str, Path],
+        mode: str = "r",
+        buffering: int = -1,
+        encoding: Optional[str] = None,
+        errors: Optional[str] = None,
+        newline: Optional[str] = None,
+        closefd: bool = True,
+        opener: Optional[Any] = None,
+    ):
+        """Safely open a file, supporting long paths on Windows."""
+        f = open(
+            convert_to_extended_path(file),
+            mode=mode,
+            buffering=buffering,
+            encoding=encoding,
+            errors=errors,
+            newline=newline,
+            closefd=closefd,
+            opener=opener,
+        )
+        try:
+            yield f
+        finally:
+            f.close()
+
+    @staticmethod
+    def listdir(path: Union[str, Path]) -> List[str]:
+        """Safely list directory contents, supporting long paths on Windows."""
+        return os.listdir(convert_to_extended_path(path))
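The new SafeFileOps wrapper routes every call through convert_to_extended_path before delegating to os, shutil, or open. A minimal usage sketch under that assumption (the output directory below is hypothetical): call sites that previously used the standard library directly can switch to SafeFileOps and gain extended-length path handling on Windows with no other changes.

    from application_sdk.common.file_ops import SafeFileOps

    # Hypothetical deep output directory; trees like this can exceed the
    # 260-character MAX_PATH limit on Windows.
    out_dir = "C:/temp/very/deep/output/tree/raw/query"

    SafeFileOps.makedirs(out_dir, exist_ok=True)
    with SafeFileOps.open(f"{out_dir}/marker.txt", "w") as f:
        f.write("12345")
    assert SafeFileOps.exists(f"{out_dir}/marker.txt")
    print(SafeFileOps.listdir(out_dir))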
@@ -0,0 +1,38 @@
+import os
+import sys
+from pathlib import Path
+from typing import Union
+
+from application_sdk.constants import WINDOWS_EXTENDED_PATH_PREFIX
+
+
+def convert_to_extended_path(path: Union[str, Path]) -> str:
+    """
+    Robust conversion to Windows extended-length path ({WINDOWS_EXTENDED_PATH_PREFIX}).
+
+    On Windows, this prefixes the path with {WINDOWS_EXTENDED_PATH_PREFIX} to bypass the 260-character limit.
+    It ensures the path is absolute and uses backslashes.
+    On non-Windows platforms, it returns the path as a string.
+
+    Args:
+        path: The path to convert (str or Path object).
+
+    Returns:
+        str: The converted path string (raises ValueError if the input is empty).
+    """
+    if not path:
+        raise ValueError("Path cannot be empty")
+
+    path_str = str(path)
+
+    if sys.platform != "win32":
+        return path_str
+
+    if path_str.startswith(WINDOWS_EXTENDED_PATH_PREFIX):
+        return path_str
+
+    # Use os.path.abspath for better Windows reliability than Path.absolute()
+    # It also handles normalization of separators to backslashes
+    abs_path = os.path.abspath(path_str)
+
+    return f"{WINDOWS_EXTENDED_PATH_PREFIX}{abs_path}"
@@ -42,6 +42,8 @@ APP_HOST = str(os.getenv("ATLAN_APP_HTTP_HOST", "0.0.0.0"))
 APP_PORT = int(os.getenv("ATLAN_APP_HTTP_PORT", "8000"))
 #: Tenant ID for multi-tenant applications
 APP_TENANT_ID = os.getenv("ATLAN_TENANT_ID", "default")
+# Domain Name of the tenant
+DOMAIN_NAME = os.getenv("ATLAN_DOMAIN_NAME", "atlan.com")
 #: Host address for the application's dashboard
 APP_DASHBOARD_HOST = str(os.getenv("ATLAN_APP_DASHBOARD_HOST", "localhost"))
 #: Port number for the application's dashboard
@@ -239,6 +241,24 @@ METRICS_CLEANUP_ENABLED = (
 )
 METRICS_RETENTION_DAYS = int(os.getenv("ATLAN_METRICS_RETENTION_DAYS", "30"))
 
+# Segment Configuration
+#: Segment API URL for sending events. Defaults to https://api.segment.io/v1/batch
+SEGMENT_API_URL = os.getenv("ATLAN_SEGMENT_API_URL", "https://api.segment.io/v1/batch")
+#: Segment write key for authentication
+SEGMENT_WRITE_KEY = os.getenv("ATLAN_SEGMENT_WRITE_KEY", "")
+#: Whether to enable Segment metrics export
+ENABLE_SEGMENT_METRICS = (
+    os.getenv("ATLAN_ENABLE_SEGMENT_METRICS", "false").lower() == "true"
+)
+#: Default user ID for Segment events
+SEGMENT_DEFAULT_USER_ID = "atlan.automation"
+#: Maximum batch size for Segment events
+SEGMENT_BATCH_SIZE = int(os.getenv("ATLAN_SEGMENT_BATCH_SIZE", "100"))
+#: Maximum time to wait before sending a batch (in seconds)
+SEGMENT_BATCH_TIMEOUT_SECONDS = float(
+    os.getenv("ATLAN_SEGMENT_BATCH_TIMEOUT_SECONDS", "10.0")
+)
+
 # Traces Configuration
 ENABLE_OTLP_TRACES = os.getenv("ATLAN_ENABLE_OTLP_TRACES", "false").lower() == "true"
 TRACES_BATCH_SIZE = int(os.getenv("ATLAN_TRACES_BATCH_SIZE", "100"))
@@ -287,6 +307,9 @@ LOCK_RETRY_INTERVAL_SECONDS = int(os.getenv("LOCK_RETRY_INTERVAL_SECONDS", "60")
 ENABLE_MCP = os.getenv("ENABLE_MCP", "false").lower() == "true"
 MCP_METADATA_KEY = "__atlan_application_sdk_mcp_metadata"
 
+#: Windows extended-length path prefix
+WINDOWS_EXTENDED_PATH_PREFIX = "\\\\?\\"
+
 
 class ApplicationMode(str, Enum):
     """Application execution mode.
@@ -66,6 +66,8 @@ class EventActivityInboundInterceptor(ActivityInboundInterceptor):
         Returns:
             Any: The result of the activity execution.
         """
+        import time
+
         start_event = Event(
             event_type=EventTypes.APPLICATION_EVENT.value,
             event_name=ApplicationEventNames.ACTIVITY_START.value,
@@ -73,16 +75,18 @@ class EventActivityInboundInterceptor(ActivityInboundInterceptor):
         )
         await EventStore.publish_event(start_event)
 
+        start_time = time.time()
         output = None
         try:
             output = await super().execute_activity(input)
         except Exception:
             raise
         finally:
+            duration_ms = (time.time() - start_time) * 1000
             end_event = Event(
                 event_type=EventTypes.APPLICATION_EVENT.value,
                 event_name=ApplicationEventNames.ACTIVITY_END.value,
-                data={},
+                data={"duration_ms": round(duration_ms, 2)},
             )
             await EventStore.publish_event(end_event)
 
@@ -106,6 +110,8 @@ class EventWorkflowInboundInterceptor(WorkflowInboundInterceptor):
         Returns:
             Any: The result of the workflow execution.
         """
+        # Record start time (use workflow.time() for deterministic time in workflows)
+        start_time = workflow.time()
 
         # Publish workflow start event via activity
         try:
@@ -138,7 +144,10 @@ class EventWorkflowInboundInterceptor(WorkflowInboundInterceptor):
             workflow_state = WorkflowStates.FAILED.value  # Keep as failed
             raise
         finally:
-            # Always publish workflow end event
+            # Calculate duration in milliseconds
+            duration_ms = (workflow.time() - start_time) * 1000
+
+            # Always publish workflow end event with duration
             try:
                 await workflow.execute_activity(
                     publish_event,
@@ -146,7 +155,7 @@ class EventWorkflowInboundInterceptor(WorkflowInboundInterceptor):
                         "metadata": EventMetadata(workflow_state=workflow_state),
                         "event_type": EventTypes.APPLICATION_EVENT.value,
                         "event_name": ApplicationEventNames.WORKFLOW_END.value,
-                        "data": {},
+                        "data": {"duration_ms": round(duration_ms, 2)},
                     },
                     schedule_to_close_timeout=timedelta(seconds=30),
                    retry_policy=RetryPolicy(maximum_attempts=3),
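Both interceptors now attach an elapsed-time measurement to the end event: the activity interceptor uses wall-clock time.time(), while the workflow interceptor uses workflow.time(), which Temporal keeps deterministic across replays. A minimal sketch of the shared pattern (do_work is a placeholder for the wrapped execution):

    import time

    def do_work() -> None:
        """Placeholder for the wrapped activity or workflow body."""
        time.sleep(0.01)

    start_time = time.time()
    try:
        do_work()
    finally:
        # Same computation the interceptors put into data={"duration_ms": ...}
        duration_ms = (time.time() - start_time) * 1000
        print({"duration_ms": round(duration_ms, 2)})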
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Unio
 import orjson
 from temporalio import activity
 
+from application_sdk.common.file_ops import SafeFileOps
 from application_sdk.common.types import DataframeType
 from application_sdk.constants import DAPR_MAX_GRPC_MESSAGE_LENGTH
 from application_sdk.io.utils import (
@@ -322,7 +323,7 @@ class JsonFileWriter(Writer):
 
         if typename:
             self.path = os.path.join(self.path, typename)
-        os.makedirs(self.path, exist_ok=True)
+        SafeFileOps.makedirs(self.path, exist_ok=True)
 
         if self.chunk_start:
             self.chunk_count = self.chunk_start + self.chunk_count
@@ -395,7 +396,7 @@ class JsonFileWriter(Writer):
                     and self.total_record_count % self.chunk_size == 0
                 ):
                     output_file_name = f"{self.path}/{path_gen(self.chunk_count, self.chunk_part, self.start_marker, self.end_marker, extension=self.extension)}"
-                    if os.path.exists(output_file_name):
+                    if SafeFileOps.exists(output_file_name):
                         await self._upload_file(output_file_name)
                         self.chunk_part += 1
 
@@ -403,6 +404,13 @@ class JsonFileWriter(Writer):
             if self.current_buffer_size > 0:
                 await self._flush_daft_buffer(buffer, self.chunk_part)
 
+            # Upload the final file (matching pandas behavior)
+            if self.current_buffer_size_bytes > 0:
+                output_file_name = f"{self.path}/{path_gen(self.chunk_count, self.chunk_part, self.start_marker, self.end_marker, extension=self.extension)}"
+                if os.path.exists(output_file_name):
+                    await self._upload_file(output_file_name)
+                    self.chunk_part += 1
+
             # Record metrics for successful write
             self.metrics.record_metric(
                 name="json_write_records",
@@ -411,6 +419,11 @@ class JsonFileWriter(Writer):
                 labels={"type": "daft"},
                 description="Number of records written to JSON files from daft DataFrame",
             )
+
+            # Increment chunk_count and record partitions (matching pandas behavior)
+            if self.chunk_start is None:
+                self.chunk_count += 1
+            self.partitions.append(self.chunk_part)
         except Exception as e:
             # Record metrics for failed write
             self.metrics.record_metric(
@@ -430,7 +443,7 @@ class JsonFileWriter(Writer):
         )
         and uploads the file to the object store.
         output_file_name = f"{self.path}/{path_gen(self.chunk_count, chunk_part, self.start_marker, self.end_marker, extension=self.extension)}"
-        with open(output_file_name, "ab+") as f:
+        with SafeFileOps.open(output_file_name, "ab+") as f:
            f.writelines(buffer)
        buffer.clear()  # Clear the buffer
 
@@ -450,8 +463,8 @@ class JsonFileWriter(Writer):
 
        This method writes a chunk to a JSON file and uploads the file to the object store.
        """
-        mode = "w" if not os.path.exists(file_name) else "a"
-        with open(file_name, mode=mode) as f:
+        mode = "w" if not SafeFileOps.exists(file_name) else "a"
+        with SafeFileOps.open(file_name, mode=mode) as f:
            chunk.to_json(f, orient="records", lines=True)
 
    async def _finalize(self) -> None:
@@ -462,7 +475,7 @@ class JsonFileWriter(Writer):
        # Upload the final file if there's remaining buffered data
        if self.current_buffer_size_bytes > 0:
            output_file_name = f"{self.path}/{path_gen(self.chunk_count, self.chunk_part, self.start_marker, self.end_marker, extension=self.extension)}"
-            if os.path.exists(output_file_name):
+            if SafeFileOps.exists(output_file_name):
                await self._upload_file(output_file_name)
                self.chunk_part += 1
 
@@ -1,6 +1,5 @@
 import inspect
 import os
-import shutil
 from typing import (
     TYPE_CHECKING,
     AsyncGenerator,
@@ -15,6 +14,7 @@ from typing import (
 from temporalio import activity
 
 from application_sdk.activities.common.utils import get_object_store_prefix
+from application_sdk.common.file_ops import SafeFileOps
 from application_sdk.constants import (
     DAPR_MAX_GRPC_MESSAGE_LENGTH,
     ENABLE_ATLAN_UPLOAD,
@@ -494,7 +494,7 @@ class ParquetFileWriter(Writer):
         # Create output directory
         if self.typename:
             self.path = os.path.join(self.path, self.typename)
-        os.makedirs(self.path, exist_ok=True)
+        SafeFileOps.makedirs(self.path, exist_ok=True)
 
     async def _write_batched_dataframe(
         self,
@@ -729,7 +729,7 @@ class ParquetFileWriter(Writer):
         )
 
         # Create the directory
-        os.makedirs(self.current_temp_folder_path, exist_ok=True)
+        SafeFileOps.makedirs(self.current_temp_folder_path, exist_ok=True)
 
     async def _write_chunk_to_temp_folder(self, chunk: "pd.DataFrame"):
         """Write a chunk to the current temp folder."""
@@ -740,7 +740,7 @@ class ParquetFileWriter(Writer):
         existing_files = len(
             [
                 f
-                for f in os.listdir(self.current_temp_folder_path)
+                for f in SafeFileOps.listdir(self.current_temp_folder_path)
                 if f.endswith(self.extension)
             ]
         )
@@ -780,7 +780,7 @@ class ParquetFileWriter(Writer):
                 folder_index=self.chunk_count,
                 chunk_part=i,
             )
-            os.rename(file_path, consolidated_file_path)
+            SafeFileOps.rename(file_path, consolidated_file_path)
 
         # Upload consolidated file to object store
         await ObjectStore.upload_file(
@@ -789,7 +789,7 @@ class ParquetFileWriter(Writer):
         )
 
         # Clean up temp consolidated dir
-        shutil.rmtree(temp_consolidated_dir, ignore_errors=True)
+        SafeFileOps.rmtree(temp_consolidated_dir, ignore_errors=True)
 
         # Update statistics
         self.chunk_count += 1
@@ -825,13 +825,15 @@ class ParquetFileWriter(Writer):
         # Clean up all temp folders
         for folder_index in self.temp_folders_created:
             temp_folder = self._get_temp_folder_path(folder_index)
-            if os.path.exists(temp_folder):
-                shutil.rmtree(temp_folder, ignore_errors=True)
+            if SafeFileOps.exists(temp_folder):
+                SafeFileOps.rmtree(temp_folder, ignore_errors=True)
 
         # Clean up base temp directory if it exists and is empty
         temp_base_path = os.path.join(self.path, "temp_accumulation")
-        if os.path.exists(temp_base_path) and not os.listdir(temp_base_path):
-            os.rmdir(temp_base_path)
+        if SafeFileOps.exists(temp_base_path) and not SafeFileOps.listdir(
+            temp_base_path
+        ):
+            SafeFileOps.rmdir(temp_base_path)
 
         # Reset state
         self.temp_folders_created.clear()
@@ -168,9 +168,10 @@ logging.basicConfig(
     level=logging.getLevelNamesMapping()[LOG_LEVEL], handlers=[InterceptHandler()]
 )
 
-DEPENDENCY_LOGGERS = ["daft_io.stats", "tracing.span"]
+DEPENDENCY_LOGGERS = ["daft_io.stats", "tracing.span", "httpx"]
 
 # Configure external dependency loggers to reduce noise
+# Set httpx to WARNING to reduce verbose HTTP request logs (200 OK messages)
 for logger_name in DEPENDENCY_LOGGERS:
     logging.getLogger(logger_name).setLevel(logging.WARNING)
 
@@ -1,7 +1,7 @@
 import asyncio
+import atexit
 import logging
 import threading
-from enum import Enum
 from time import time
 from typing import Any, Dict, Optional
 
@@ -10,10 +10,10 @@ from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExp
 from opentelemetry.sdk.metrics import MeterProvider
 from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
 from opentelemetry.sdk.resources import Resource
-from pydantic import BaseModel
 
 from application_sdk.constants import (
     ENABLE_OTLP_METRICS,
+    ENABLE_SEGMENT_METRICS,
     METRICS_BATCH_SIZE,
     METRICS_CLEANUP_ENABLED,
     METRICS_FILE_NAME,
@@ -24,128 +24,37 @@ from application_sdk.constants import (
     OTEL_EXPORTER_TIMEOUT_SECONDS,
     OTEL_RESOURCE_ATTRIBUTES,
     OTEL_WF_NODE_NAME,
+    SEGMENT_API_URL,
+    SEGMENT_BATCH_SIZE,
+    SEGMENT_BATCH_TIMEOUT_SECONDS,
+    SEGMENT_DEFAULT_USER_ID,
+    SEGMENT_WRITE_KEY,
     SERVICE_NAME,
     SERVICE_VERSION,
 )
 from application_sdk.observability.logger_adaptor import get_logger
+from application_sdk.observability.models import MetricRecord, MetricType
 from application_sdk.observability.observability import AtlanObservability
+from application_sdk.observability.segment_client import SegmentClient
 from application_sdk.observability.utils import (
     get_observability_dir,
     get_workflow_context,
 )
 
-
-class MetricType(Enum):
-    """Enum for metric types."""
-
-    COUNTER = "counter"
-    GAUGE = "gauge"
-    HISTOGRAM = "histogram"
-
-
-class MetricRecord(BaseModel):
-    """A Pydantic model representing a metric record in the system.
-
-    This model defines the structure for metric data with fields for timestamp,
-    name, value, type, labels, and optional description and unit.
-
-    Attributes:
-        timestamp (float): Unix timestamp when the metric was recorded
-        name (str): Name of the metric
-        value (float): Numeric value of the metric
-        type (str): Type of metric (counter, gauge, or histogram)
-        labels (Dict[str, str]): Key-value pairs for metric dimensions
-        description (Optional[str]): Optional description of the metric
-        unit (Optional[str]): Optional unit of measurement
-    """
-
-    timestamp: float
-    name: str
-    value: float
-    type: MetricType  # counter, gauge, histogram
-    labels: Dict[str, str]
-    description: Optional[str] = None
-    unit: Optional[str] = None
-
-    class Config:
-        """Configuration for the MetricRecord Pydantic model.
-
-        Provides custom parsing logic to ensure consistent data types and structure
-        for metric records, including validation and type conversion for all fields.
-        """
-
-        @classmethod
-        def parse_obj(cls, obj):
-            if isinstance(obj, dict):
-                # Ensure labels is a dictionary with consistent structure
-                if "labels" in obj:
-                    # Create a new labels dict with only the expected fields
-                    new_labels = {}
-                    expected_fields = [
-                        "database",
-                        "status",
-                        "type",
-                        "mode",
-                        "workflow_id",
-                        "workflow_type",
-                    ]
-
-                    # Copy only the expected fields if they exist
-                    for field in expected_fields:
-                        if field in obj["labels"]:
-                            new_labels[field] = str(obj["labels"][field])
-
-                    obj["labels"] = new_labels
-
-                # Ensure value is float
-                if "value" in obj:
-                    try:
-                        obj["value"] = float(obj["value"])
-                    except (ValueError, TypeError):
-                        obj["value"] = 0.0
-
-                # Ensure timestamp is float
-                if "timestamp" in obj:
-                    try:
-                        obj["timestamp"] = float(obj["timestamp"])
-                    except (ValueError, TypeError):
-                        obj["timestamp"] = time()
-
-                # Ensure type is MetricType
-                if "type" in obj:
-                    try:
-                        obj["type"] = MetricType(obj["type"])
-                    except ValueError:
-                        obj["type"] = MetricType.COUNTER
-
-                # Ensure name is string
-                if "name" in obj:
-                    obj["name"] = str(obj["name"])
-
-                # Ensure description is string or None
-                if "description" in obj:
-                    obj["description"] = (
-                        str(obj["description"])
-                        if obj["description"] is not None
-                        else None
-                    )
-
-                # Ensure unit is string or None
-                if "unit" in obj:
-                    obj["unit"] = str(obj["unit"]) if obj["unit"] is not None else None
-
-            return super().parse_obj(obj)
+# MetricRecord and MetricType are imported from models.py to avoid circular dependencies
+logger = get_logger(__name__)
 
 
 class AtlanMetricsAdapter(AtlanObservability[MetricRecord]):
     """A metrics adapter for Atlan that extends AtlanObservability.
 
     This adapter provides functionality for recording, processing, and exporting
-    metrics to various backends including OpenTelemetry and parquet files.
+    metrics to various backends including OpenTelemetry, Segment API, and parquet files.
 
     Features:
     - Metric recording with labels and units
     - OpenTelemetry integration
+    - Segment API integration
     - Periodic metric flushing
    - Console logging
    - Parquet file storage
@@ -160,6 +69,7 @@ class AtlanMetricsAdapter(AtlanObservability[MetricRecord]):
        - Sets up base observability configuration
        - Configures date-based file settings
        - Initializes OpenTelemetry metrics if enabled
+        - Initializes Segment API client if enabled
        - Starts periodic flush task for metric buffering
        """
        super().__init__(
@@ -175,6 +85,18 @@ class AtlanMetricsAdapter(AtlanObservability[MetricRecord]):
        if ENABLE_OTLP_METRICS:
            self._setup_otel_metrics()
 
+        # Initialize Segment client (it handles enable/disable internally)
+        self.segment_client = SegmentClient(
+            enabled=ENABLE_SEGMENT_METRICS,
+            write_key=SEGMENT_WRITE_KEY,
+            api_url=SEGMENT_API_URL,
+            default_user_id=SEGMENT_DEFAULT_USER_ID,
+            batch_size=SEGMENT_BATCH_SIZE,
+            batch_timeout_seconds=SEGMENT_BATCH_TIMEOUT_SECONDS,
+        )
+        # Register cleanup handler to close SegmentClient on shutdown
+        atexit.register(self.segment_client.close)
+
        # Start periodic flush task if not already started
        if not AtlanMetricsAdapter._flush_task_started:
            try:
@@ -319,6 +241,7 @@ class AtlanMetricsAdapter(AtlanObservability[MetricRecord]):
        This method:
        - Validates the record is a MetricRecord
        - Sends to OpenTelemetry if enabled
+        - Sends to Segment API if enabled
        - Logs to console
        """
        if not isinstance(record, MetricRecord):
@@ -328,6 +251,9 @@ class AtlanMetricsAdapter(AtlanObservability[MetricRecord]):
        if ENABLE_OTLP_METRICS:
            self._send_to_otel(record)
 
+        # Send to Segment (client handles enable/disable internally)
+        self.segment_client.send_metric(record)
+
        # Log to console
        self._log_to_console(record)
 
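For orientation, a hedged sketch of the record shape that now fans out to OTLP, Segment, and the console, assuming the model relocated to application_sdk.observability.models keeps the fields shown in the removed definition above; the adapter wiring described in the comment is illustrative, not the SDK's documented API.

    from time import time

    from application_sdk.observability.models import MetricRecord, MetricType

    record = MetricRecord(
        timestamp=time(),
        name="json_write_records",
        value=42.0,
        type=MetricType.COUNTER,
        labels={"type": "daft"},
        description="Number of records written to JSON files from daft DataFrame",
    )
    # With Segment enabled, the adapter also calls segment_client.send_metric(record)
    # in addition to the OTLP export and console logging.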