ygg 0.1.34__py3-none-any.whl → 0.1.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ygg
- Version: 0.1.34
+ Version: 0.1.35
  Summary: Type-friendly utilities for moving data between Python objects, Arrow, Polars, Pandas, Spark, and Databricks
  Author: Yggdrasil contributors
  License: Apache License
@@ -1,15 +1,15 @@
- ygg-0.1.34.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+ ygg-0.1.35.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
  yggdrasil/__init__.py,sha256=PfH7Xwt6uue6oqe6S5V8NhDJcVQClkKrBE1KXhdelZc,117
- yggdrasil/version.py,sha256=cIz48TZT2Xc-LLdWHdfAlxnIA0OSZqt42ZJcukkGo6s,22
+ yggdrasil/version.py,sha256=dvUFqQgabIeithBNW7swqwxg6R59T0-i289dJYgomYQ,22
  yggdrasil/databricks/__init__.py,sha256=skctY2c8W-hI81upx9F_PWRe5ishL3hrdiTuizgDjdw,152
  yggdrasil/databricks/compute/__init__.py,sha256=NvdzmaJSNYY1uJthv1hHdBuNu3bD_-Z65DWnaJt9yXg,289
- yggdrasil/databricks/compute/cluster.py,sha256=KUyGcpEKiA5XgAbeX1iHzuhJ4pucFqch_galZwYJlnc,39599
- yggdrasil/databricks/compute/execution_context.py,sha256=Z0EvkhdR803Kh1UOh4wR0oyyLXzAJo4Lj5CRNmxW4q4,22287
- yggdrasil/databricks/compute/remote.py,sha256=rrqLMnzI0KvhXghtOrve3W-rudi-cTjS-8dJXKjHM3A,2266
+ yggdrasil/databricks/compute/cluster.py,sha256=yx3xEJ5Vgg-IPeyYxxdf74_9DSBPHvg2-FYs8ptJrl0,40404
+ yggdrasil/databricks/compute/execution_context.py,sha256=nxrNXoarq_JAB-Cpj0udHhq2jx-DmMbRWJdAezLrPis,22347
+ yggdrasil/databricks/compute/remote.py,sha256=nEN_Fr1Ouul_iKOf4B5QjEGscYAcl7nHjGsl2toRzrU,2874
  yggdrasil/databricks/jobs/__init__.py,sha256=snxGSJb0M5I39v0y3IR-uEeSlZR248cQ_4DJ1sYs-h8,154
  yggdrasil/databricks/jobs/config.py,sha256=9LGeHD04hbfy0xt8_6oobC4moKJh4_DTjZiK4Q2Tqjk,11557
  yggdrasil/databricks/sql/__init__.py,sha256=y1n5yg-drZ8QVZbEgznsRG24kdJSnFis9l2YfYCsaCM,234
- yggdrasil/databricks/sql/engine.py,sha256=weYHosCVc9CZYaVooexEphNw6W_Ex0dphuGbfA48mEI,41104
+ yggdrasil/databricks/sql/engine.py,sha256=kUFBddJJQC0AgDqH0l7GFs7d_Ony5rc8fOv4inLU6Vw,41051
  yggdrasil/databricks/sql/exceptions.py,sha256=Jqd_gT_VyPL8klJEHYEzpv5eHtmdY43WiQ7HZBaEqSk,53
  yggdrasil/databricks/sql/statement_result.py,sha256=VlHXhTcvTVya_2aJ-uUfUooZF_MqQuOZ8k7g6PBDhOM,17227
  yggdrasil/databricks/sql/types.py,sha256=5G-BM9_eOsRKEMzeDTWUsWW5g4Idvs-czVCpOCrMhdA,6412
@@ -31,6 +31,7 @@ yggdrasil/libs/extensions/polars_extensions.py,sha256=RTkGi8llhPJjX7x9egix7-yXWo
  yggdrasil/libs/extensions/spark_extensions.py,sha256=E64n-3SFTDgMuXwWitX6vOYP9ln2lpGKb0htoBLEZgc,16745
  yggdrasil/pyutils/__init__.py,sha256=tl-LapAc71TV7RMgf2ftKwrzr8iiLOGHeJgA3RvO93w,293
  yggdrasil/pyutils/callable_serde.py,sha256=euY7Kiy04i1tpWKuB0b2qQ1FokLC3nq0cv7PObWYUBE,21809
+ yggdrasil/pyutils/equality.py,sha256=Xyf8D1dLUCm3spDEir8Zyj7O4US_fBJwEylJCfJ9slI,3080
  yggdrasil/pyutils/exceptions.py,sha256=ssKNm-rjhavHUOZmGA7_1Gq9tSHDrb2EFI-cnBuWgng,3388
  yggdrasil/pyutils/expiring_dict.py,sha256=q9gb09-2EUN-jQZumUw5BXOQGYcj1wb85qKtQlciSxg,5825
  yggdrasil/pyutils/modules.py,sha256=B7IP99YqUMW6-DIESFzBx8-09V1d0a8qrIJUDFhhL2g,11424
@@ -54,8 +55,8 @@ yggdrasil/types/cast/registry.py,sha256=_zdFGmUBB7P-e_LIcJlOxMcxAkXoA-UXB6HqLMgT
  yggdrasil/types/cast/spark_cast.py,sha256=_KAsl1DqmKMSfWxqhVE7gosjYdgiL1C5bDQv6eP3HtA,24926
  yggdrasil/types/cast/spark_pandas_cast.py,sha256=BuTiWrdCANZCdD_p2MAytqm74eq-rdRXd-LGojBRrfU,5023
  yggdrasil/types/cast/spark_polars_cast.py,sha256=btmZNHXn2NSt3fUuB4xg7coaE0RezIBdZD92H8NK0Jw,9073
- ygg-0.1.34.dist-info/METADATA,sha256=iGQcUq6tGnBBLiVo9jPak9PE-Ma8wWPxY2BsWKLGC2w,19204
- ygg-0.1.34.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- ygg-0.1.34.dist-info/entry_points.txt,sha256=6q-vpWG3kvw2dhctQ0LALdatoeefkN855Ev02I1dKGY,70
- ygg-0.1.34.dist-info/top_level.txt,sha256=iBe9Kk4VIVbLpgv_p8OZUIfxgj4dgJ5wBg6vO3rigso,10
- ygg-0.1.34.dist-info/RECORD,,
+ ygg-0.1.35.dist-info/METADATA,sha256=qSh-Cd0LtjU_CV9Je0E1ns1AnuRbIUPMIpvVJoJtIUA,19204
+ ygg-0.1.35.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ ygg-0.1.35.dist-info/entry_points.txt,sha256=6q-vpWG3kvw2dhctQ0LALdatoeefkN855Ev02I1dKGY,70
+ ygg-0.1.35.dist-info/top_level.txt,sha256=iBe9Kk4VIVbLpgv_p8OZUIfxgj4dgJ5wBg6vO3rigso,10
+ ygg-0.1.35.dist-info/RECORD,,
yggdrasil/databricks/compute/cluster.py CHANGED
@@ -24,6 +24,7 @@ from .execution_context import ExecutionContext
  from ..workspaces.workspace import WorkspaceService, Workspace
  from ... import retry, CallableSerde
  from ...libs.databrickslib import databricks_sdk
+ from ...pyutils.equality import dicts_equal, dict_diff
  from ...pyutils.expiring_dict import ExpiringDict
  from ...pyutils.modules import PipIndexSettings
  from ...pyutils.python_env import PythonEnv
@@ -110,7 +111,7 @@ class Cluster(WorkspaceService):

  _details: Optional["ClusterDetails"] = dataclasses.field(default=None, repr=False)
  _details_refresh_time: float = dataclasses.field(default=0, repr=False)
- _system_context: Optional[ExecutionContext] = None
+ _system_context: Optional[ExecutionContext] = dataclasses.field(default=None, repr=False)

  # host → Cluster instance
  _env_clusters: ClassVar[Dict[str, "Cluster"]] = {}
@@ -658,8 +659,6 @@ class Cluster(WorkspaceService):
  Returns:
  The updated Cluster instance.
  """
- self.install_libraries(libraries=libraries, wait_timeout=None, raise_error=False)
-
  existing_details = {
  k: v
  for k, v in self.details.as_shallow_dict().items()
@@ -672,22 +671,35 @@ class Cluster(WorkspaceService):
  if k in _EDIT_ARG_NAMES
  }

- if update_details != existing_details:
+ same = dicts_equal(
+ existing_details,
+ update_details,
+ keys=_EDIT_ARG_NAMES,
+ treat_missing_as_none=True,
+ float_tol=0.0, # set e.g. 1e-6 if you have float-y stuff
+ )
+
+ if not same:
+ diff = {
+ k: v[1]
+ for k, v in dict_diff(existing_details, update_details, keys=_EDIT_ARG_NAMES).items()
+ }
+
  logger.debug(
  "Updating %s with %s",
- self, update_details
+ self, diff
  )

  self.wait_for_status()
- self.details = retry(tries=4, delay=0.5, max_delay=2)(
- self.clusters_client().edit_and_wait
- )(**update_details)
+ self.details = self.clusters_client().edit_and_wait(**update_details)

  logger.info(
  "Updated %s",
  self
  )

+ self.install_libraries(libraries=libraries, wait_timeout=None, raise_error=False)
+
  return self

  def list_clusters(self) -> Iterator["Cluster"]:
@@ -760,16 +772,18 @@ class Cluster(WorkspaceService):

  def ensure_running(
  self,
+ wait_timeout: Optional[dt.timedelta] = dt.timedelta(minutes=20)
  ) -> "Cluster":
  """Ensure the cluster is running.

  Returns:
  The current Cluster instance.
  """
- return self.start()
+ return self.start(wait_timeout=wait_timeout)

  def start(
  self,
+ wait_timeout: Optional[dt.timedelta] = dt.timedelta(minutes=20)
  ) -> "Cluster":
  """Start the cluster if it is not already running.

@@ -780,8 +794,15 @@ class Cluster(WorkspaceService):

  if not self.is_running:
  logger.info("Starting %s", self)
- self.details = self.clusters_client().start_and_wait(cluster_id=self.cluster_id)
- return self.wait_installed_libraries()
+
+ if wait_timeout:
+ self.details = (
+ self.clusters_client()
+ .start_and_wait(cluster_id=self.cluster_id, timeout=wait_timeout)
+ )
+ self.wait_installed_libraries(timeout=wait_timeout)
+ else:
+ self.clusters_client().start(cluster_id=self.cluster_id)

  return self

@@ -1124,7 +1145,7 @@ class Cluster(WorkspaceService):
  "Waiting %s to install libraries timed out" % self
  )

- time.sleep(10)
+ time.sleep(5)
  statuses = list(self.installed_library_statuses())

  return self
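
The `start`/`ensure_running` hunks above add an optional `wait_timeout` (default 20 minutes) and a non-blocking path when it is `None`. A minimal usage sketch against the new signatures; how the `Cluster` instance is obtained is out of scope here and the function name is illustrative:

import datetime as dt

from yggdrasil.databricks.compute.cluster import Cluster  # module path per the RECORD above


def warm_up(cluster: Cluster) -> Cluster:
    """Illustrative only: exercise the new wait_timeout parameter."""
    # Non-blocking: just issue the start request and return immediately.
    cluster.ensure_running(wait_timeout=None)
    # Blocking with a custom bound, forwarded to start_and_wait and
    # wait_installed_libraries; omit the argument for the 20-minute default.
    return cluster.ensure_running(wait_timeout=dt.timedelta(minutes=5))
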
yggdrasil/databricks/compute/execution_context.py CHANGED
@@ -78,8 +78,8 @@ class ExecutionContext:
  language: Optional["Language"] = None
  context_id: Optional[str] = None

- _was_connected: Optional[bool] = None
- _remote_metadata: Optional[RemoteMetadata] = None
+ _was_connected: Optional[bool] = dc.field(default=None, repr=False)
+ _remote_metadata: Optional[RemoteMetadata] = dc.field(default=None, repr=False)

  _lock: threading.RLock = dc.field(default_factory=threading.RLock, init=False, repr=False)

yggdrasil/databricks/compute/remote.py CHANGED
@@ -2,11 +2,12 @@

  import datetime as dt
  import logging
+ import os
  from typing import (
  Callable,
  Optional,
  TypeVar,
- List, TYPE_CHECKING,
+ List, TYPE_CHECKING, Union,
  )

  if TYPE_CHECKING:
@@ -25,10 +26,15 @@ ReturnType = TypeVar("ReturnType")
  logger = logging.getLogger(__name__)


+ def identity(x):
+ return x
+
+
  def databricks_remote_compute(
+ _func: Optional[Callable] = None,
  cluster_id: Optional[str] = None,
  cluster_name: Optional[str] = None,
- workspace: Optional[Workspace] = None,
+ workspace: Optional[Union[Workspace, str]] = None,
  cluster: Optional["Cluster"] = None,
  timeout: Optional[dt.timedelta] = None,
  env_keys: Optional[List[str]] = None,
@@ -38,6 +44,7 @@ def databricks_remote_compute(
  """Return a decorator that executes functions on a remote cluster.

  Args:
+ _func: function to decorate
  cluster_id: Optional cluster id to target.
  cluster_name: Optional cluster name to target.
  workspace: Workspace instance or host string for lookup.
@@ -51,13 +58,19 @@ def databricks_remote_compute(
  A decorator that runs functions on the resolved Databricks cluster.
  """
  if force_local or Workspace.is_in_databricks_environment():
- def identity(x):
- return x
+ return identity if _func is None else _func
+
+ if workspace is None:
+ workspace = os.getenv("DATABRICKS_HOST")

- return identity
+ if workspace is None:
+ return identity if _func is None else _func

- if isinstance(workspace, str):
- workspace = Workspace(host=workspace)
+ if not isinstance(workspace, Workspace):
+ if isinstance(workspace, str):
+ workspace = Workspace(host=workspace).connect(clone=False)
+ else:
+ raise ValueError("Cannot initialize databricks workspace with %s" % type(workspace))

  if cluster is None:
  if cluster_id or cluster_name:
@@ -68,10 +81,14 @@ def databricks_remote_compute(
  else:
  cluster = workspace.clusters().replicated_current_environment(
  workspace=workspace,
- cluster_name=cluster_name
+ cluster_name=cluster_name,
+ single_user_name=workspace.current_user.user_name
  )

+ cluster.ensure_running(wait_timeout=None)
+
  return cluster.execution_decorator(
+ _func=_func,
  env_keys=env_keys,
  timeout=timeout,
  **options
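
With the new leading `_func` parameter, `databricks_remote_compute` can be applied either bare or with arguments; when `workspace` is omitted it now falls back to the `DATABRICKS_HOST` environment variable and returns the function unchanged if neither is available. A sketch of both forms; the cluster name and host below are placeholders:

from yggdrasil.databricks.compute.remote import databricks_remote_compute


# Bare form: the function itself is passed as _func. It runs locally when
# DATABRICKS_HOST is unset or when already inside a Databricks environment.
@databricks_remote_compute
def add(a: int, b: int) -> int:
    return a + b


# Parameterized form: _func stays None and a configured decorator is returned.
@databricks_remote_compute(
    cluster_name="my-cluster",                          # placeholder name
    workspace="https://example.cloud.databricks.com",   # placeholder host
)
def multiply(a: int, b: int) -> int:
    return a * b
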
yggdrasil/databricks/sql/engine.py CHANGED
@@ -198,8 +198,7 @@ class SQLEngine(WorkspaceService):
  """Short, single-line preview for logs (avoids spewing giant SQL)."""
  if not sql:
  return ""
- one_line = " ".join(sql.split())
- return one_line[:limit] + ("…" if len(one_line) > limit else "")
+ return sql[:limit] + ("…" if len(sql) > limit else "")

  def execute(
  self,
@@ -218,7 +217,6 @@ class SQLEngine(WorkspaceService):
  schema_name: Optional[str] = None,
  table_name: Optional[str] = None,
  wait_result: bool = True,
- **kwargs,
  ) -> "StatementResult":
  """Execute a SQL statement via Spark or Databricks SQL Statement Execution API.

@@ -245,7 +243,6 @@ class SQLEngine(WorkspaceService):
  schema_name: Optional schema override for API engine.
  table_name: Optional table override used when `statement` is None.
  wait_result: Whether to block until completion (API engine).
- **kwargs: Extra params forwarded to Databricks SDK execute_statement.

  Returns:
  StatementResult.
@@ -263,9 +260,12 @@ class SQLEngine(WorkspaceService):
  if spark_session is None:
  raise ValueError("No spark session found to run sql query")

- t0 = time.time()
- df = spark_session.sql(statement)
- logger.info("Spark SQL executed in %.3fs: %s", time.time() - t0, self._sql_preview(statement))
+ df: SparkDataFrame = spark_session.sql(statement)
+
+ if row_limit:
+ df = df.limit(row_limit)
+
+ logger.info("Spark SQL executed: %s", self._sql_preview(statement))

  # Avoid Disposition dependency if SDK imports are absent
  spark_disp = disposition if disposition is not None else getattr(globals().get("Disposition", object), "EXTERNAL_LINKS", None)
@@ -287,7 +287,6 @@ class SQLEngine(WorkspaceService):
  if not statement:
  full_name = self.table_full_name(catalog_name=catalog_name, schema_name=schema_name, table_name=table_name)
  statement = f"SELECT * FROM {full_name}"
- logger.debug("Autogenerated statement: %s", self._sql_preview(statement))

  if not warehouse_id:
  warehouse_id = self._get_or_default_warehouse_id()
@@ -314,7 +313,11 @@ class SQLEngine(WorkspaceService):
  disposition=disposition,
  )

- # BUGFIX: previously returned `wait_result` (a bool) on wait_result=False 🤦
+ logger.info(
+ "API SQL executed: %s",
+ self._sql_preview(statement)
+ )
+
  return execution.wait() if wait_result else execution

  def spark_table(
@@ -465,15 +468,7 @@ class SQLEngine(WorkspaceService):
  safe_chars=True,
  )

- logger.info(
- "Arrow insert into %s (mode=%s, match_by=%s, zorder_by=%s)",
- location,
- mode,
- match_by,
- zorder_by,
- )
-
- with self as connected:
+ with self.connect() as connected:
  if existing_schema is None:
  try:
  existing_schema = connected.get_table_schema(
@@ -482,7 +477,6 @@ class SQLEngine(WorkspaceService):
  table_name=table_name,
  to_arrow_schema=True,
  )
- logger.debug("Fetched existing schema for %s (columns=%d)", location, len(existing_schema.names))
  except ValueError as exc:
  data_tbl = convert(data, pa.Table)
  existing_schema = data_tbl.schema
@@ -527,7 +521,20 @@ class SQLEngine(WorkspaceService):

  transaction_id = self._random_suffix()

- data_tbl = convert(data, pa.Table, options=cast_options, target_field=existing_schema)
+ data_tbl = convert(
+ data, pa.Table,
+ options=cast_options, target_field=existing_schema
+ )
+ num_rows = data_tbl.num_rows
+
+ logger.debug(
+ "Arrow inserting %s rows into %s (mode=%s, match_by=%s, zorder_by=%s)",
+ num_rows,
+ location,
+ mode,
+ match_by,
+ zorder_by,
+ )

  # Write in temp volume
  temp_volume_path = connected.dbfs_path(
@@ -545,7 +552,6 @@ class SQLEngine(WorkspaceService):
  statements: list[str] = []

  if match_by:
- logger.info("Using MERGE INTO (match_by=%s)", match_by)
  on_condition = " AND ".join([f"T.`{k}` = S.`{k}`" for k in match_by])

  update_cols = [c for c in columns if c not in match_by]
@@ -588,6 +594,15 @@ FROM parquet.`{temp_volume_path}`"""
  except Exception:
  logger.exception("Failed cleaning temp volume: %s", temp_volume_path)

+ logger.info(
+ "Arrow inserted %s rows into %s (mode=%s, match_by=%s, zorder_by=%s)",
+ num_rows,
+ location,
+ mode,
+ match_by,
+ zorder_by,
+ )
+
  if zorder_by:
  zcols = ", ".join([f"`{c}`" for c in zorder_by])
  optimize_sql = f"OPTIMIZE {location} ZORDER BY ({zcols})"
@@ -675,7 +690,6 @@ FROM parquet.`{temp_volume_path}`"""
  table_name=table_name,
  to_arrow_schema=False,
  )
- logger.debug("Fetched destination Spark schema for %s", location)
  except ValueError:
  logger.warning("Destination table missing; creating table %s via overwrite write", location)
  data = convert(data, pyspark.sql.DataFrame)
@@ -704,10 +718,8 @@ FROM parquet.`{temp_volume_path}`"""


  if match_by:
  cond = " AND ".join([f"t.`{k}` <=> s.`{k}`" for k in match_by])
- logger.info("Running Delta MERGE (cond=%s)", cond)

  if mode.casefold() == "overwrite":
- logger.info("Overwrite-by-key mode: delete matching keys then append")
  data = data.cache()
  distinct_keys = data.select([f"`{k}`" for k in match_by]).distinct()
@@ -815,6 +827,7 @@ FROM parquet.`{temp_volume_path}`"""
  optimize_write: bool = True,
  auto_compact: bool = True,
  execute: bool = True,
+ wait_result: bool = True
  ) -> Union[str, "StatementResult"]:
  """Generate (and optionally execute) CREATE TABLE DDL from an Arrow schema/field.

@@ -832,6 +845,7 @@ FROM parquet.`{temp_volume_path}`"""
  optimize_write: Sets delta.autoOptimize.optimizeWrite table property.
  auto_compact: Sets delta.autoOptimize.autoCompact table property.
  execute: If True, executes DDL and returns StatementResult; otherwise returns SQL string.
+ wait_result: Waits execution to complete

  Returns:
  StatementResult if execute=True, else the DDL SQL string.
@@ -897,11 +911,13 @@ FROM parquet.`{temp_volume_path}`"""

  statement = "\n".join(sql)

- logger.info("Generated CREATE TABLE DDL for %s", location)
- logger.debug("DDL:\n%s", statement)
+ logger.debug(
+ "Generated CREATE TABLE DDL for %s:\n%s",
+ location, statement
+ )

  if execute:
- return self.execute(statement)
+ return self.execute(statement, wait_result=wait_result)
  return statement

  def _check_location_params(
yggdrasil/pyutils/equality.py ADDED
@@ -0,0 +1,107 @@
+ from __future__ import annotations
+
+ import math
+ from typing import Any, Dict, Iterable, Tuple
+
+ _MISSING = object()
+
+
+ __all__ = [
+ "dicts_equal",
+ "dict_diff"
+ ]
+
+
+ def _normalize(obj: Any) -> Any:
+ """
+ Normalize nested structures so equality is stable:
+ - dict: sort keys + normalize values
+ - list/tuple: normalize items (keeps order)
+ - set: sort normalized items (orderless)
+ - float: keep as float (handled separately for tolerance)
+ """
+ if isinstance(obj, dict):
+ return {k: _normalize(obj[k]) for k in sorted(obj.keys())}
+ if isinstance(obj, (list, tuple)):
+ return [_normalize(x) for x in obj]
+ if isinstance(obj, set):
+ return sorted(_normalize(x) for x in obj)
+ return obj
+
+ def _equal(a: Any, b: Any, float_tol: float = 0.0) -> bool:
+ # Float tolerance (optional)
+ if isinstance(a, float) or isinstance(b, float):
+ if a is None or b is None:
+ return a is b
+ try:
+ return math.isclose(float(a), float(b), rel_tol=float_tol, abs_tol=float_tol)
+ except Exception:
+ pass
+
+ # Deep normalize compare for dict/list/set
+ return _normalize(a) == _normalize(b)
+
+ def dicts_equal(
+ a: Dict[str, Any],
+ b: Dict[str, Any],
+ *,
+ keys: Iterable[str] | None = None,
+ treat_missing_as_none: bool = True,
+ float_tol: float = 0.0,
+ ) -> bool:
+ """
+ Equality check for two dicts with options:
+ - keys: only compare these keys
+ - treat_missing_as_none: missing key == None if other side is None
+ - float_tol: tolerance for float comparisons
+ """
+ if keys is None:
+ keys = set(a.keys()) | set(b.keys())
+
+ for k in keys:
+ av = a.get(k, _MISSING)
+ bv = b.get(k, _MISSING)
+
+ if treat_missing_as_none:
+ if av is _MISSING and bv is None:
+ continue
+ if bv is _MISSING and av is None:
+ continue
+ if av is _MISSING and bv is _MISSING:
+ continue
+
+ if not _equal(av, bv, float_tol=float_tol):
+ return False
+
+ return True
+
+ def dict_diff(
+ a: Dict[str, Any],
+ b: Dict[str, Any],
+ *,
+ keys: Iterable[str] | None = None,
+ treat_missing_as_none: bool = True,
+ float_tol: float = 0.0,
+ ) -> Dict[str, Tuple[Any, Any]]:
+ """
+ Returns {key: (a_val, b_val)} for all keys that differ.
+ """
+ if keys is None:
+ keys = set(a.keys()) | set(b.keys())
+
+ out: Dict[str, Tuple[Any, Any]] = {}
+ for k in keys:
+ av = a.get(k, _MISSING)
+ bv = b.get(k, _MISSING)
+
+ if treat_missing_as_none:
+ if av is _MISSING and bv is None:
+ continue
+ if bv is _MISSING and av is None:
+ continue
+ if av is _MISSING and bv is _MISSING:
+ continue
+
+ if not _equal(av, bv, float_tol=float_tol):
+ out[k] = (None if av is _MISSING else av, None if bv is _MISSING else bv)
+ return out
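
A short usage sketch of the new `yggdrasil.pyutils.equality` helpers, as they are used in the cluster-edit path above; the dictionaries below are illustrative cluster settings, not real defaults:

from yggdrasil.pyutils.equality import dict_diff, dicts_equal

existing = {"num_workers": 2, "autotermination_minutes": None, "spark_version": "15.4.x-scala2.12"}
requested = {"num_workers": 2, "spark_version": "15.4.x-scala2.12"}  # key omitted entirely

# A key that is missing on one side and None on the other counts as equal,
# so an unchanged configuration does not trigger a spurious edit call.
assert dicts_equal(existing, requested, treat_missing_as_none=True)

# dict_diff reports only the keys whose values actually differ, as (a, b) pairs.
requested["num_workers"] = 4
assert dict_diff(existing, requested) == {"num_workers": (2, 4)}
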
yggdrasil/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.34"
+ __version__ = "0.1.35"