langmigrate 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,62 @@
1
+ """LangMigrate — declarative schema migrations for LangGraph state persistence.
2
+
3
+ See ``CLAUDE.md`` for architecture and conventions.
4
+ """
5
+
6
+ from .core.engine import HEAD, MigrationEngine
7
+ from .core.exceptions import (
8
+ ChannelRemovalUnsupportedError,
9
+ CyclicHistoryError,
10
+ DuplicateRevisionError,
11
+ IrreversibleMigrationError,
12
+ LangMigrateError,
13
+ MissingRequiredFieldError,
14
+ MultipleHeadsError,
15
+ RevisionNotAncestorError,
16
+ RevisionNotFoundError,
17
+ TopologyMismatchError,
18
+ UnsafeMigrationError,
19
+ )
20
+ from .core.migration import BaseMigration, FunctionMigration, migration
21
+ from .core.registry import MigrationRegistry, new_revision_id
22
+ from .core.topology import NodeRemap
23
+ from .core.types import REVISION_METADATA_KEY, RevisionMeta, StateEnvelope
24
+ from .integrations.state import migrate_state_update
25
+ from .runtime.batch import BatchResult, run_batch_downgrade, run_batch_upgrade
26
+ from .runtime.factory import setup_langmigrate
27
+ from .runtime.interceptor import MigrationInterceptor
28
+
29
__version__ = "1.0.0"

# Public API of the package. Kept as a plain literal list so that
# ``from langmigrate import *`` and static tooling can read it directly.
__all__ = [
    "__version__",
    # engine, migrations and registry (from ``.core``)
    "HEAD",
    "BaseMigration",
    "FunctionMigration",
    "migration",
    "MigrationEngine",
    "MigrationRegistry",
    # runtime wiring (from ``.runtime``)
    "MigrationInterceptor",
    "setup_langmigrate",
    "new_revision_id",
    # topology and shared types (from ``.core``)
    "NodeRemap",
    "StateEnvelope",
    "RevisionMeta",
    "REVISION_METADATA_KEY",
    # batch runner (from ``.runtime.batch``) and state integration
    "BatchResult",
    "run_batch_upgrade",
    "run_batch_downgrade",
    "migrate_state_update",
    # exceptions
    "LangMigrateError",
    "UnsafeMigrationError",
    "MissingRequiredFieldError",
    "RevisionNotFoundError",
    "RevisionNotAncestorError",
    "DuplicateRevisionError",
    "CyclicHistoryError",
    "MultipleHeadsError",
    "IrreversibleMigrationError",
    "TopologyMismatchError",
    "ChannelRemovalUnsupportedError",
]
@@ -0,0 +1,5 @@
1
+ """Database-specific bulk adapters for the proactive batch CLI.
2
+
3
+ Database client imports are confined to this package and loaded lazily so that
4
+ importing ``langmigrate`` never requires an optional backend extra to be installed.
5
+ """
@@ -0,0 +1,48 @@
1
+ """The adapter contract for proactive (batch) migration.
2
+
3
+ An adapter exposes a database's checkpoints to the batch CLI: it enumerates the
4
+ checkpoints whose stored revision is behind the target (ideally via an indexed
5
+ metadata query) and provides the underlying saver used to read/write them.
6
+
7
+ This module is pure — it declares a :class:`Protocol` only. Concrete adapters
8
+ (``postgres``, ``redis``) live alongside and import their DB client lazily.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from collections.abc import Iterator
14
+ from typing import Protocol, runtime_checkable
15
+
16
+ from langchain_core.runnables import RunnableConfig
17
+ from langgraph.checkpoint.base import BaseCheckpointSaver
18
+
19
+
20
@runtime_checkable
class CheckpointAdapter(Protocol):
    """Structural contract a backend must satisfy for the batch migration runner.

    Any object exposing ``saver``, ``count_stale`` and ``iter_stale_configs``
    conforms; no inheritance from this class is required.
    """

    @property
    def saver(self) -> BaseCheckpointSaver:
        """Return the LangGraph checkpointer that performs the actual reads/writes."""

    def count_stale(self, head: str) -> int:
        """Return how many checkpoints carry a stored revision other than ``head``."""

    def iter_stale_configs(self, head: str) -> Iterator[RunnableConfig]:
        """Yield one full ``RunnableConfig`` (incl. ``checkpoint_id``) per stale checkpoint."""
36
+
37
+
38
@runtime_checkable
class BatchCheckpointAdapter(CheckpointAdapter, Protocol):
    """A :class:`CheckpointAdapter` that is additionally able to list *every* checkpoint.

    Downgrades target a revision below the current head, so enumerating only
    the stale checkpoints is not enough for them; the full listing is required.
    """

    def iter_all_configs(self) -> Iterator[RunnableConfig]:
        """Yield one full ``RunnableConfig`` for each checkpoint held by the store."""
@@ -0,0 +1,126 @@
1
+ """PostgreSQL adapter for proactive batch migration.
2
+
3
+ Finds stale checkpoints with a single indexed query against the ``metadata``
4
+ JSONB column — no need to deserialize every row to discover its revision, which
5
+ is what makes the batch path scale to large databases.
6
+
7
+ The ``psycopg`` / ``langgraph-checkpoint-postgres`` imports are done lazily so the
8
+ rest of LangMigrate stays importable without the ``[postgres]`` extra installed.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from collections.abc import Iterator
14
+ from typing import TYPE_CHECKING, Any
15
+
16
+ from langchain_core.runnables import RunnableConfig
17
+
18
+ from ..core.types import REVISION_METADATA_KEY
19
+
20
+ if TYPE_CHECKING: # pragma: no cover
21
+ from langgraph.checkpoint.postgres import PostgresSaver
22
+
23
# SQL predicate selecting stale checkpoints; its single ``%s`` placeholder is
# bound to the head revision. ``IS DISTINCT FROM`` is used instead of ``<>`` so
# that a row with no revision tag at all (NULL) also counts as stale.
_STALE_WHERE = f"metadata->>'{REVISION_METADATA_KEY}' IS DISTINCT FROM %s"
25
+
26
+
27
class PostgresAdapter:
    """Adapter over a ``PostgresSaver`` connection for batch enumeration.

    Every query reads its result columns by name, so ``conn`` must hand out
    cursors whose rows are mappings (:meth:`from_conn_string` configures
    ``psycopg.rows.dict_row`` for exactly that reason).
    """

    def __init__(self, conn: Any, saver: PostgresSaver) -> None:
        """Wrap an already-open connection and the saver built on top of it."""
        self._conn = conn
        self._saver = saver

    @classmethod
    def from_conn_string(cls, conn_string: str) -> PostgresAdapter:
        """Open a connection and build the adapter (and its ``PostgresSaver``).

        If constructing the saver or the adapter fails, the freshly opened
        connection is closed before the error propagates, so it is not leaked.
        """
        import psycopg
        from langgraph.checkpoint.postgres import PostgresSaver
        from psycopg.rows import dict_row

        conn = psycopg.connect(conn_string, autocommit=True, row_factory=dict_row)
        try:
            return cls(conn, PostgresSaver(conn))
        except BaseException:
            conn.close()
            raise

    @property
    def saver(self) -> PostgresSaver:
        """The underlying LangGraph checkpointer for reads/writes."""
        return self._saver

    def setup(self) -> None:
        """Create the checkpoint tables if they do not yet exist."""
        self._saver.setup()

    # -- enumeration --------------------------------------------------------

    @staticmethod
    def _config_from_row(row: Any) -> RunnableConfig:
        """Build the ``RunnableConfig`` addressing the checkpoint described by ``row``."""
        return {
            "configurable": {
                "thread_id": row["thread_id"],
                "checkpoint_ns": row["checkpoint_ns"],
                "checkpoint_id": row["checkpoint_id"],
            }
        }

    def _iter_configs(
        self, sql: str, params: tuple[Any, ...] | None = None
    ) -> Iterator[RunnableConfig]:
        """Run ``sql`` and yield one config per returned row."""
        # Materialize first so the cursor is closed before the saver reuses the
        # connection during migration of each checkpoint.
        with self._conn.cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()
        for row in rows:
            yield self._config_from_row(row)

    def count_stale(self, head: str) -> int:
        """Number of checkpoints whose stored revision differs from ``head``."""
        with self._conn.cursor() as cur:
            cur.execute(f"SELECT count(*) AS c FROM checkpoints WHERE {_STALE_WHERE}", (head,))
            return int(cur.fetchone()["c"])

    def iter_stale_configs(self, head: str) -> Iterator[RunnableConfig]:
        """Yield a full ``RunnableConfig`` for each checkpoint not tagged with ``head``."""
        yield from self._iter_configs(
            "SELECT thread_id, checkpoint_ns, checkpoint_id "
            f"FROM checkpoints WHERE {_STALE_WHERE} "
            "ORDER BY thread_id, checkpoint_ns, checkpoint_id",
            (head,),
        )

    def iter_all_configs(self) -> Iterator[RunnableConfig]:
        """Yield a full ``RunnableConfig`` for every checkpoint in the table."""
        yield from self._iter_configs(
            "SELECT thread_id, checkpoint_ns, checkpoint_id FROM checkpoints "
            "ORDER BY thread_id, checkpoint_ns, checkpoint_id"
        )

    # -- stamping / reporting -----------------------------------------------

    def stamp_all(self, revision: str) -> int:
        """Set the revision tag on every checkpoint without running migrations.

        Returns the number of rows updated. Use when adopting LangMigrate on a
        database whose state already matches a known revision. ``COALESCE`` guards
        against a row whose ``metadata`` is SQL/JSON ``null`` (``jsonb_set`` of a
        null base returns null and would silently drop the tag).
        """
        with self._conn.cursor() as cur:
            cur.execute(
                "UPDATE checkpoints SET metadata = "
                "jsonb_set(COALESCE(NULLIF(metadata, 'null'::jsonb), '{}'::jsonb), "
                f"'{{{REVISION_METADATA_KEY}}}', to_jsonb(%s::text))",
                (revision,),
            )
            return int(cur.rowcount or 0)

    @staticmethod
    def _tally_revisions(rows: Any) -> dict[str, int]:
        """Fold ``(rev, c)`` rows into a ``{revision: count}`` mapping.

        A missing tag (NULL) and an empty-string tag are both reported as
        ``"<untagged>"``; their counts are added together rather than one
        replacing the other.
        """
        counts: dict[str, int] = {}
        for row in rows:
            label = row["rev"] or "<untagged>"
            counts[label] = counts.get(label, 0) + int(row["c"])
        return counts

    def revision_counts(self) -> dict[str, int]:
        """Distribution of stored revision tags across all checkpoints (for ``current --db``)."""
        with self._conn.cursor() as cur:
            cur.execute(
                f"SELECT metadata->>'{REVISION_METADATA_KEY}' AS rev, count(*) AS c "
                "FROM checkpoints GROUP BY rev"
            )
            return self._tally_revisions(cur.fetchall())

    # -- lifecycle ----------------------------------------------------------

    def close(self) -> None:
        """Close the underlying database connection."""
        self._conn.close()

    def __enter__(self) -> PostgresAdapter:
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()
@@ -0,0 +1,152 @@
1
+ """Redis adapter for proactive batch migration.
2
+
3
+ ``langgraph-checkpoint-redis`` stores each checkpoint as a RedisJSON document at
4
+ ``checkpoint:<thread>:<ns>:<id>`` with the LangGraph metadata kept as a serialized
5
+ JSON string under ``$.metadata``. Our revision tag is *not* part of the RediSearch
6
+ index, so batch enumeration scans the checkpoint keys and reads each document's
7
+ metadata (an O(n) sweep — inherent to Redis without a custom index).
8
+
9
+ Lazy *online* migration needs none of this: wrap a ``RedisSaver`` with
10
+ :class:`~langmigrate.runtime.interceptor.MigrationInterceptor` and it works today.
11
+ The ``redis`` client imports are done lazily so the rest of LangMigrate stays
12
+ importable without the ``[redis]`` extra.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import contextlib
18
+ import json
19
+ from collections.abc import Iterator
20
+ from typing import TYPE_CHECKING, Any, cast
21
+
22
+ from langchain_core.runnables import RunnableConfig
23
+ from langgraph.checkpoint.base import CheckpointMetadata
24
+
25
+ from ..core.version import read_revision, stamp_metadata
26
+
27
+ if TYPE_CHECKING: # pragma: no cover
28
+ from langgraph.checkpoint.redis import RedisSaver
29
+
30
# Glob pattern handed to ``SCAN``. The trailing colon in the prefix restricts it
# to checkpoint documents ("checkpoint:..."), excluding "checkpoint_write:..." keys.
_CHECKPOINT_MATCH = "checkpoint:*"
32
+
33
+
34
+ def _first(fields: dict[str, Any], path: str) -> Any:
35
+ """First value at a RedisJSON ``$.path`` result (paths return a list)."""
36
+ value = fields.get(path)
37
+ return value[0] if value else None
38
+
39
+
40
class RedisAdapter:
    """Adapter over a ``RedisSaver`` for batch enumeration and stamping."""

    def __init__(self, saver: RedisSaver) -> None:
        """Wrap an existing ``RedisSaver``; its Redis client is reused for scans."""
        self._saver = saver

    @classmethod
    def from_conn_string(cls, conn_string: str) -> RedisAdapter:
        """Open a connection and build the adapter (and its ``RedisSaver``)."""
        from langgraph.checkpoint.redis import RedisSaver

        saver = RedisSaver(conn_string)
        saver.setup()
        return cls(saver)

    @property
    def saver(self) -> RedisSaver:
        """The underlying LangGraph checkpointer for reads/writes."""
        return self._saver

    def setup(self) -> None:
        """Delegate to the saver's own ``setup()``."""
        self._saver.setup()

    # -- enumeration --------------------------------------------------------

    def _client(self) -> Any:
        """The Redis client held by the saver (private attribute of ``RedisSaver``)."""
        return self._saver._redis

    @staticmethod
    def _dedupe(keys: Iterator[Any]) -> Iterator[Any]:
        """Yield each key once, in first-seen order.

        ``SCAN`` may return the same key more than once during a full
        iteration; without this filter a checkpoint could be counted or
        stamped twice. Costs one set entry per distinct key.
        """
        seen: set[Any] = set()
        for key in keys:
            if key in seen:
                continue
            seen.add(key)
            yield key

    def _iter_docs(self) -> Iterator[tuple[RunnableConfig, str | None]]:
        """Yield ``(config, stored_revision)`` once for every checkpoint document."""
        from langgraph.checkpoint.redis.util import (
            from_storage_safe_id,
            from_storage_safe_str,
            safely_decode,
        )

        client = self._client()
        raw_keys = client.scan_iter(match=_CHECKPOINT_MATCH, count=200)
        for raw_key in self._dedupe(raw_keys):
            key = safely_decode(raw_key)
            fields = client.json().get(
                key, "$.thread_id", "$.checkpoint_ns", "$.checkpoint_id", "$.metadata"
            )
            if not fields:
                # Nothing readable under this key; skip it.
                continue

            thread_id = from_storage_safe_id(_first(fields, "$.thread_id") or "")
            checkpoint_ns = from_storage_safe_str(_first(fields, "$.checkpoint_ns") or "")
            checkpoint_id = from_storage_safe_id(_first(fields, "$.checkpoint_id") or "")
            revision = self._revision_from_metadata(_first(fields, "$.metadata"))
            config: RunnableConfig = {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint_id,
                }
            }
            yield config, revision

    @staticmethod
    def _revision_from_metadata(metadata: Any) -> str | None:
        """Extract the revision tag from a document's ``metadata`` value.

        Returns ``None`` when the metadata is a string that is not valid JSON;
        any non-dict result is passed to ``read_revision`` as ``None``.
        """
        # Stored as a serialized JSON string (occasionally already a dict).
        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                return None
        return read_revision(metadata if isinstance(metadata, dict) else None)

    def count_stale(self, head: str) -> int:
        """Number of checkpoints whose stored revision differs from ``head``."""
        return sum(1 for _, rev in self._iter_docs() if rev != head)

    def iter_stale_configs(self, head: str) -> Iterator[RunnableConfig]:
        """Yield a full ``RunnableConfig`` for each checkpoint not tagged with ``head``."""
        for config, rev in self._iter_docs():
            if rev != head:
                yield config

    def iter_all_configs(self) -> Iterator[RunnableConfig]:
        """Yield a full ``RunnableConfig`` for every checkpoint document."""
        for config, _ in self._iter_docs():
            yield config

    def revision_counts(self) -> dict[str, int]:
        """Distribution of stored revision tags; untagged documents are keyed ``"<untagged>"``."""
        counts: dict[str, int] = {}
        for _, rev in self._iter_docs():
            key = rev or "<untagged>"
            counts[key] = counts.get(key, 0) + 1
        return counts

    # -- stamping -----------------------------------------------------------

    def stamp_all(self, revision: str) -> int:
        """Set the revision tag on every checkpoint without running migrations.

        Each checkpoint is re-read through the saver and written back with the
        stamped metadata. Returns the number of checkpoints rewritten.
        """
        updated = 0
        for config, _ in self._iter_docs():
            tup = self._saver.get_tuple(config)
            if tup is None:
                continue
            metadata = stamp_metadata(dict(tup.metadata or {}), revision)
            # Write against the parent config when there is one; otherwise
            # address the thread/namespace only.
            put_config = tup.parent_config or {
                "configurable": {
                    "thread_id": config["configurable"]["thread_id"],
                    "checkpoint_ns": config["configurable"]["checkpoint_ns"],
                }
            }
            self._saver.put(put_config, tup.checkpoint, cast(CheckpointMetadata, metadata), {})
            updated += 1
        return updated

    # -- lifecycle ----------------------------------------------------------

    def close(self) -> None:
        """Close the Redis client, ignoring any error raised while doing so."""
        with contextlib.suppress(Exception):  # best effort
            self._client().close()

    def __enter__(self) -> RedisAdapter:
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()
@@ -0,0 +1 @@
1
+ """Typer-based command line interface."""