intentkit 0.7.4__py3-none-any.whl → 0.7.4rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
intentkit/models/agent.py CHANGED
@@ -1,3 +1,4 @@
1
+ import hashlib
1
2
  import json
2
3
  import logging
3
4
  import re
@@ -14,10 +15,11 @@ from epyxid import XID
14
15
  from fastapi import HTTPException
15
16
  from intentkit.models.agent_data import AgentData
16
17
  from intentkit.models.base import Base
18
+ from intentkit.models.credit import CreditAccount
17
19
  from intentkit.models.db import get_session
18
20
  from intentkit.models.llm import LLMModelInfo, LLMModelInfoTable, LLMProvider
19
21
  from intentkit.models.skill import SkillTable
20
- from pydantic import BaseModel, ConfigDict, field_validator, model_validator
22
+ from pydantic import BaseModel, ConfigDict, field_validator
21
23
  from pydantic import Field as PydanticField
22
24
  from pydantic.json_schema import SkipJsonSchema
23
25
  from sqlalchemy import (
@@ -128,33 +130,6 @@ class AgentAutonomous(BaseModel):
128
130
  )
129
131
  return v
130
132
 
131
- @field_validator("name")
132
- @classmethod
133
- def validate_name(cls, v: Optional[str]) -> Optional[str]:
134
- if v is not None and len(v.encode()) > 50:
135
- raise ValueError("name must be at most 50 bytes")
136
- return v
137
-
138
- @field_validator("description")
139
- @classmethod
140
- def validate_description(cls, v: Optional[str]) -> Optional[str]:
141
- if v is not None and len(v.encode()) > 200:
142
- raise ValueError("description must be at most 200 bytes")
143
- return v
144
-
145
- @field_validator("prompt")
146
- @classmethod
147
- def validate_prompt(cls, v: Optional[str]) -> Optional[str]:
148
- if v is not None and len(v.encode()) > 20000:
149
- raise ValueError("prompt must be at most 20000 bytes")
150
- return v
151
-
152
- @model_validator(mode="after")
153
- def validate_schedule(self) -> "AgentAutonomous":
154
- # This validator is kept for backward compatibility
155
- # The actual validation now happens in AgentUpdate.validate_autonomous_schedule
156
- return self
157
-
158
133
 
159
134
  class AgentExample(BaseModel):
160
135
  """Agent example configuration."""
@@ -241,11 +216,6 @@ class AgentTable(Base):
241
216
  nullable=True,
242
217
  comment="Pool of the agent token",
243
218
  )
244
- mode = Column(
245
- String,
246
- nullable=True,
247
- comment="Mode of the agent, public or private",
248
- )
249
219
  fee_percentage = Column(
250
220
  Numeric(22, 4),
251
221
  nullable=True,
@@ -362,13 +332,6 @@ class AgentTable(Base):
362
332
  nullable=True,
363
333
  comment="Dict of skills and their corresponding configurations",
364
334
  )
365
-
366
- cdp_network_id = Column(
367
- String,
368
- nullable=True,
369
- default="base-mainnet",
370
- comment="Network identifier for CDP integration",
371
- )
372
335
  # if telegram_entrypoint_enabled, the telegram_entrypoint_enabled will be enabled, telegram_config will be checked
373
336
  telegram_entrypoint_enabled = Column(
374
337
  Boolean,
@@ -391,6 +354,31 @@ class AgentTable(Base):
391
354
  nullable=True,
392
355
  comment="Extra prompt for xmtp entrypoint",
393
356
  )
357
+ version = Column(
358
+ String,
359
+ nullable=True,
360
+ comment="Version hash of the agent",
361
+ )
362
+ statistics = Column(
363
+ JSON().with_variant(JSONB(), "postgresql"),
364
+ nullable=True,
365
+ comment="Statistics of the agent, update every 1 hour for query",
366
+ )
367
+ assets = Column(
368
+ JSON().with_variant(JSONB(), "postgresql"),
369
+ nullable=True,
370
+ comment="Assets of the agent, update every 1 hour for query",
371
+ )
372
+ account_snapshot = Column(
373
+ JSON().with_variant(JSONB(), "postgresql"),
374
+ nullable=True,
375
+ comment="Account snapshot of the agent, update every 1 hour for query",
376
+ )
377
+ extra = Column(
378
+ JSON().with_variant(JSONB(), "postgresql"),
379
+ nullable=True,
380
+ comment="Other helper data fields for query, come from agent and agent data",
381
+ )
394
382
  # auto timestamp
395
383
  created_at = Column(
396
384
  DateTime(timezone=True),
@@ -407,16 +395,8 @@ class AgentTable(Base):
407
395
  )
408
396
 
409
397
 
410
- class AgentUpdate(BaseModel):
411
- """Agent update model."""
412
-
413
- model_config = ConfigDict(
414
- title="Agent",
415
- from_attributes=True,
416
- json_schema_extra={
417
- "required": ["name", "purpose", "personality", "principles"],
418
- },
419
- )
398
+ class AgentCore(BaseModel):
399
+ """Agent core model."""
420
400
 
421
401
  name: Annotated[
422
402
  Optional[str],
@@ -431,19 +411,6 @@ class AgentUpdate(BaseModel):
431
411
  },
432
412
  ),
433
413
  ]
434
- slug: Annotated[
435
- Optional[str],
436
- PydanticField(
437
- default=None,
438
- description="Slug of the agent, used for URL generation",
439
- max_length=30,
440
- min_length=2,
441
- json_schema_extra={
442
- "x-group": "internal",
443
- "readOnly": True,
444
- },
445
- ),
446
- ]
447
414
  description: Annotated[
448
415
  Optional[str],
449
416
  PydanticField(
@@ -515,16 +482,6 @@ class AgentUpdate(BaseModel):
515
482
  },
516
483
  ),
517
484
  ]
518
- mode: Annotated[
519
- Optional[Literal["public", "private"]],
520
- PydanticField(
521
- default=None,
522
- description="Mode of the agent, public or private",
523
- json_schema_extra={
524
- "x-group": "basic",
525
- },
526
- ),
527
- ]
528
485
  fee_percentage: Annotated[
529
486
  Optional[Decimal],
530
487
  PydanticField(
@@ -584,38 +541,6 @@ class AgentUpdate(BaseModel):
584
541
  },
585
542
  ),
586
543
  ]
587
- owner: Annotated[
588
- Optional[str],
589
- PydanticField(
590
- default=None,
591
- description="Owner identifier of the agent, used for access control",
592
- max_length=50,
593
- json_schema_extra={
594
- "x-group": "internal",
595
- },
596
- ),
597
- ]
598
- upstream_id: Annotated[
599
- Optional[str],
600
- PydanticField(
601
- default=None,
602
- description="External reference ID for idempotent operations",
603
- max_length=100,
604
- json_schema_extra={
605
- "x-group": "internal",
606
- },
607
- ),
608
- ]
609
- upstream_extra: Annotated[
610
- Optional[Dict[str, Any]],
611
- PydanticField(
612
- default=None,
613
- description="Additional data store for upstream use",
614
- json_schema_extra={
615
- "x-group": "internal",
616
- },
617
- ),
618
- ]
619
544
  # AI part
620
545
  model: Annotated[
621
546
  str,
@@ -693,6 +618,89 @@ class AgentUpdate(BaseModel):
693
618
  },
694
619
  ),
695
620
  ]
621
+ wallet_provider: Annotated[
622
+ Optional[Literal["cdp", "readonly"]],
623
+ PydanticField(
624
+ default="cdp",
625
+ description="Provider of the agent's wallet",
626
+ json_schema_extra={
627
+ "x-group": "onchain",
628
+ },
629
+ ),
630
+ ]
631
+ readonly_wallet_address: Annotated[
632
+ Optional[str],
633
+ PydanticField(
634
+ default=None,
635
+ description="Address of the agent's wallet, only used when wallet_provider is readonly. Agent will not be able to sign transactions.",
636
+ ),
637
+ ]
638
+ network_id: Annotated[
639
+ Optional[
640
+ Literal[
641
+ "ethereum-mainnet",
642
+ "ethereum-sepolia",
643
+ "polygon-mainnet",
644
+ "polygon-mumbai",
645
+ "base-mainnet",
646
+ "base-sepolia",
647
+ "arbitrum-mainnet",
648
+ "arbitrum-sepolia",
649
+ "optimism-mainnet",
650
+ "optimism-sepolia",
651
+ "solana",
652
+ ]
653
+ ],
654
+ PydanticField(
655
+ default="base-mainnet",
656
+ description="Network identifier",
657
+ json_schema_extra={
658
+ "x-group": "onchain",
659
+ },
660
+ ),
661
+ ]
662
+ skills: Annotated[
663
+ Optional[Dict[str, Any]],
664
+ PydanticField(
665
+ default=None,
666
+ description="Dict of skills and their corresponding configurations",
667
+ json_schema_extra={
668
+ "x-group": "skills",
669
+ "x-inline": True,
670
+ },
671
+ ),
672
+ ]
673
+
674
+ def hash(self) -> str:
675
+ """
676
+ Generate a fixed-length hash based on the agent's content.
677
+
678
+ The hash remains unchanged if the content is the same and changes if the content changes.
679
+ This method serializes only AgentCore fields to JSON and generates a SHA-256 hash.
680
+ When called from subclasses, it will only use AgentCore fields, not subclass fields.
681
+
682
+ Returns:
683
+ str: A 64-character hexadecimal hash string
684
+ """
685
+ # Create a dictionary with only AgentCore fields for hashing
686
+ hash_data = {}
687
+
688
+ # Get only AgentCore field values, excluding None values for consistency
689
+ for field_name in AgentCore.model_fields:
690
+ value = getattr(self, field_name)
691
+ if value is not None:
692
+ hash_data[field_name] = value
693
+
694
+ # Convert to JSON string with sorted keys for consistent ordering
695
+ json_str = json.dumps(hash_data, sort_keys=True, default=str, ensure_ascii=True)
696
+
697
+ # Generate SHA-256 hash
698
+ return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
699
+
700
+
701
+ class AgentUserInput(AgentCore):
702
+ """Agent update model."""
703
+
696
704
  short_term_memory_strategy: Annotated[
697
705
  Optional[Literal["trim", "summarize"]],
698
706
  PydanticField(
@@ -751,82 +759,6 @@ class AgentUpdate(BaseModel):
751
759
  },
752
760
  ),
753
761
  ]
754
- # skills
755
- skills: Annotated[
756
- Optional[Dict[str, Any]],
757
- PydanticField(
758
- default=None,
759
- description="Dict of skills and their corresponding configurations",
760
- json_schema_extra={
761
- "x-group": "skills",
762
- "x-inline": True,
763
- },
764
- ),
765
- ]
766
- wallet_provider: Annotated[
767
- Optional[Literal["cdp", "readonly"]],
768
- PydanticField(
769
- default="cdp",
770
- description="Provider of the agent's wallet",
771
- json_schema_extra={
772
- "x-group": "onchain",
773
- },
774
- ),
775
- ]
776
- readonly_wallet_address: Annotated[
777
- Optional[str],
778
- PydanticField(
779
- default=None,
780
- description="Address of the agent's wallet, only used when wallet_provider is readonly. Agent will not be able to sign transactions.",
781
- ),
782
- ]
783
- network_id: Annotated[
784
- Optional[
785
- Literal[
786
- "ethereum-mainnet",
787
- "ethereum-sepolia",
788
- "polygon-mainnet",
789
- "polygon-mumbai",
790
- "base-mainnet",
791
- "base-sepolia",
792
- "arbitrum-mainnet",
793
- "arbitrum-sepolia",
794
- "optimism-mainnet",
795
- "optimism-sepolia",
796
- "solana",
797
- ]
798
- ],
799
- PydanticField(
800
- default="base-mainnet",
801
- description="Network identifier",
802
- json_schema_extra={
803
- "x-group": "onchain",
804
- },
805
- ),
806
- ]
807
- cdp_network_id: Annotated[
808
- Optional[
809
- Literal[
810
- "ethereum-mainnet",
811
- "ethereum-sepolia",
812
- "polygon-mainnet",
813
- "polygon-mumbai",
814
- "base-mainnet",
815
- "base-sepolia",
816
- "arbitrum-mainnet",
817
- "arbitrum-sepolia",
818
- "optimism-mainnet",
819
- "optimism-sepolia",
820
- ]
821
- ],
822
- PydanticField(
823
- default="base-mainnet",
824
- description="Network identifier for CDP integration",
825
- json_schema_extra={
826
- "x-group": "deprecated",
827
- },
828
- ),
829
- ]
830
762
  # if telegram_entrypoint_enabled, the telegram_entrypoint_enabled will be enabled, telegram_config will be checked
831
763
  telegram_entrypoint_enabled: Annotated[
832
764
  Optional[bool],
@@ -871,6 +803,37 @@ class AgentUpdate(BaseModel):
871
803
  ),
872
804
  ]
873
805
 
806
+
807
+ class AgentUpdate(AgentUserInput):
808
+ """Agent update model."""
809
+
810
+ model_config = ConfigDict(
811
+ title="Agent",
812
+ from_attributes=True,
813
+ json_schema_extra={
814
+ "required": ["name"],
815
+ },
816
+ )
817
+
818
+ upstream_id: Annotated[
819
+ Optional[str],
820
+ PydanticField(
821
+ default=None,
822
+ description="External reference ID for idempotent operations",
823
+ max_length=100,
824
+ ),
825
+ ]
826
+ upstream_extra: Annotated[
827
+ Optional[Dict[str, Any]],
828
+ PydanticField(
829
+ default=None,
830
+ description="Additional data store for upstream use",
831
+ json_schema_extra={
832
+ "x-group": "internal",
833
+ },
834
+ ),
835
+ ]
836
+
874
837
  @field_validator("purpose", "personality", "principles", "prompt", "prompt_append")
875
838
  @classmethod
876
839
  def validate_no_level1_level2_headings(cls, v: Optional[str]) -> Optional[str]:
@@ -960,6 +923,7 @@ class AgentUpdate(BaseModel):
960
923
  detail="The shortest execution interval is 5 minutes",
961
924
  )
962
925
 
926
+ # deprecated, use override instead
963
927
  async def update(self, id: str) -> "Agent":
964
928
  # Validate autonomous schedule settings if present
965
929
  if "autonomous" in self.model_dump(exclude_unset=True):
@@ -969,12 +933,6 @@ class AgentUpdate(BaseModel):
969
933
  db_agent = await db.get(AgentTable, id)
970
934
  if not db_agent:
971
935
  raise HTTPException(status_code=404, detail="Agent not found")
972
- # check owner
973
- if self.owner and db_agent.owner != self.owner:
974
- raise HTTPException(
975
- status_code=403,
976
- detail="You do not have permission to update this agent",
977
- )
978
936
  # update
979
937
  for key, value in self.model_dump(exclude_unset=True).items():
980
938
  setattr(db_agent, key, value)
@@ -991,15 +949,11 @@ class AgentUpdate(BaseModel):
991
949
  db_agent = await db.get(AgentTable, id)
992
950
  if not db_agent:
993
951
  raise HTTPException(status_code=404, detail="Agent not found")
994
- # check owner
995
- if db_agent.owner and db_agent.owner != self.owner:
996
- raise HTTPException(
997
- status_code=403,
998
- detail="You do not have permission to update this agent",
999
- )
1000
952
  # update
1001
953
  for key, value in self.model_dump().items():
1002
954
  setattr(db_agent, key, value)
955
+ # version
956
+ db_agent.version = self.hash()
1003
957
  await db.commit()
1004
958
  await db.refresh(db_agent)
1005
959
  return Agent.model_validate(db_agent)
@@ -1018,6 +972,14 @@ class AgentCreate(AgentUpdate):
1018
972
  max_length=67,
1019
973
  ),
1020
974
  ]
975
+ owner: Annotated[
976
+ Optional[str],
977
+ PydanticField(
978
+ default=None,
979
+ description="Owner identifier of the agent, used for access control",
980
+ max_length=50,
981
+ ),
982
+ ]
1021
983
 
1022
984
  async def check_upstream_id(self) -> None:
1023
985
  if not self.upstream_id:
@@ -1050,45 +1012,56 @@ class AgentCreate(AgentUpdate):
1050
1012
 
1051
1013
  async with get_session() as db:
1052
1014
  db_agent = AgentTable(**self.model_dump())
1015
+ db_agent.version = self.hash()
1053
1016
  db.add(db_agent)
1054
1017
  await db.commit()
1055
1018
  await db.refresh(db_agent)
1056
1019
  return Agent.model_validate(db_agent)
1057
1020
 
1058
- async def create_or_update(self) -> ("Agent", bool):
1059
- # Validation is now handled by field validators
1060
- await self.check_upstream_id()
1061
-
1062
- # Validate autonomous schedule settings if present
1063
- if self.autonomous:
1064
- self.validate_autonomous_schedule()
1065
-
1066
- is_new = False
1067
- async with get_session() as db:
1068
- db_agent = await db.get(AgentTable, self.id)
1069
- if not db_agent:
1070
- db_agent = AgentTable(**self.model_dump())
1071
- db.add(db_agent)
1072
- is_new = True
1073
- else:
1074
- # check owner
1075
- if self.owner and db_agent.owner != self.owner:
1076
- raise HTTPException(
1077
- status_code=403,
1078
- detail="You do not have permission to update this agent",
1079
- )
1080
- for key, value in self.model_dump(exclude_unset=True).items():
1081
- setattr(db_agent, key, value)
1082
- await db.commit()
1083
- await db.refresh(db_agent)
1084
- return Agent.model_validate(db_agent), is_new
1085
-
1086
1021
 
1087
1022
  class Agent(AgentCreate):
1088
1023
  """Agent model."""
1089
1024
 
1090
1025
  model_config = ConfigDict(from_attributes=True)
1091
1026
 
1027
+ slug: Annotated[
1028
+ Optional[str],
1029
+ PydanticField(
1030
+ default=None,
1031
+ description="Slug of the agent, used for URL generation",
1032
+ max_length=20,
1033
+ min_length=2,
1034
+ ),
1035
+ ]
1036
+ version: Annotated[
1037
+ Optional[str],
1038
+ PydanticField(
1039
+ default=None,
1040
+ description="Version hash of the agent",
1041
+ ),
1042
+ ]
1043
+ statistics: Annotated[
1044
+ Optional[Dict[str, Any]],
1045
+ PydanticField(
1046
+ description="Statistics of the agent, update every 1 hour for query"
1047
+ ),
1048
+ ]
1049
+ assets: Annotated[
1050
+ Optional[Dict[str, Any]],
1051
+ PydanticField(description="Assets of the agent, update every 1 hour for query"),
1052
+ ]
1053
+ account_snapshot: Annotated[
1054
+ Optional[CreditAccount],
1055
+ PydanticField(
1056
+ description="Account snapshot of the agent, update every 1 hour for query"
1057
+ ),
1058
+ ]
1059
+ extra: Annotated[
1060
+ Optional[Dict[str, Any]],
1061
+ PydanticField(
1062
+ description="Other helper data fields for query, come from agent and agent data"
1063
+ ),
1064
+ ]
1092
1065
  # auto timestamp
1093
1066
  created_at: Annotated[
1094
1067
  datetime,
@@ -1555,13 +1528,6 @@ class AgentResponse(BaseModel):
1555
1528
  description="Pool of the agent token",
1556
1529
  ),
1557
1530
  ]
1558
- mode: Annotated[
1559
- Optional[Literal["public", "private"]],
1560
- PydanticField(
1561
- default=None,
1562
- description="Mode of the agent, public or private",
1563
- ),
1564
- ]
1565
1531
  fee_percentage: Annotated[
1566
1532
  Optional[Decimal],
1567
1533
  PydanticField(
@@ -1650,13 +1616,6 @@ class AgentResponse(BaseModel):
1650
1616
  description="Network identifier",
1651
1617
  ),
1652
1618
  ]
1653
- cdp_network_id: Annotated[
1654
- Optional[str],
1655
- PydanticField(
1656
- default="base-mainnet",
1657
- description="Network identifier for CDP integration",
1658
- ),
1659
- ]
1660
1619
  # telegram entrypoint
1661
1620
  telegram_entrypoint_enabled: Annotated[
1662
1621
  Optional[bool],
@@ -67,20 +67,6 @@
67
67
  "x-group": "basic",
68
68
  "x-placeholder": "Enter agent name"
69
69
  },
70
- "mode": {
71
- "title": "Usage Type",
72
- "type": "string",
73
- "description": "Mode of the agent, Public App or Personal Assistant",
74
- "enum": [
75
- "public",
76
- "private"
77
- ],
78
- "x-enum-title": [
79
- "Public App",
80
- "Personal Assistant"
81
- ],
82
- "x-group": "deprecated"
83
- },
84
70
  "fee_percentage": {
85
71
  "title": "Service Fee",
86
72
  "type": "number",
@@ -116,31 +102,6 @@
116
102
  "x-group": "experimental",
117
103
  "x-placeholder": "Upload a picture of your agent"
118
104
  },
119
- "slug": {
120
- "title": "Slug",
121
- "type": "string",
122
- "description": "Slug of the agent, used for URL generation",
123
- "maxLength": 30,
124
- "minLength": 2,
125
- "readOnly": true,
126
- "x-group": "internal"
127
- },
128
- "owner": {
129
- "title": "Owner",
130
- "type": "string",
131
- "description": "Owner identifier of the agent, used for access control",
132
- "readOnly": true,
133
- "maxLength": 50,
134
- "x-group": "internal"
135
- },
136
- "upstream_id": {
137
- "title": "Upstream ID",
138
- "type": "string",
139
- "description": "External reference ID for idempotent operations",
140
- "readOnly": true,
141
- "maxLength": 100,
142
- "x-group": "internal"
143
- },
144
105
  "model": {
145
106
  "title": "AI Model",
146
107
  "type": "string",
intentkit/models/db.py CHANGED
@@ -102,7 +102,7 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
102
102
 
103
103
 
104
104
  @asynccontextmanager
105
- async def get_session() -> AsyncSession:
105
+ async def get_session() -> AsyncGenerator[AsyncSession, None]:
106
106
  """Get a database session using an async context manager.
107
107
 
108
108
  This function is designed to be used with the 'async with' statement,
intentkit/models/user.py CHANGED
@@ -170,7 +170,7 @@ class UserUpdate(BaseModel):
170
170
  note=note,
171
171
  )
172
172
 
173
- async def patch(self, id: str) -> "User":
173
+ async def patch(self, id: str) -> UserModelType:
174
174
  """Update only the provided fields of a user in the database.
175
175
  If the user doesn't exist, create a new one with the provided ID and fields.
176
176
  If nft_count changes, update the daily quota accordingly.
@@ -182,7 +182,9 @@ class UserUpdate(BaseModel):
182
182
  Updated or newly created User model
183
183
  """
184
184
  user_model_class = user_model_registry.get_user_model_class()
185
+ assert issubclass(user_model_class, User)
185
186
  user_table_class = user_model_registry.get_user_table_class()
187
+ assert issubclass(user_table_class, UserTable)
186
188
  async with get_session() as db:
187
189
  db_user = await db.get(user_table_class, id)
188
190
  old_nft_count = 0 # Default for new users
@@ -208,7 +210,7 @@ class UserUpdate(BaseModel):
208
210
 
209
211
  return user_model_class.model_validate(db_user)
210
212
 
211
- async def put(self, id: str) -> "User":
213
+ async def put(self, id: str) -> UserModelType:
212
214
  """Replace all fields of a user in the database with the provided values.
213
215
  If the user doesn't exist, create a new one with the provided ID and fields.
214
216
  If nft_count changes, update the daily quota accordingly.
@@ -220,7 +222,9 @@ class UserUpdate(BaseModel):
220
222
  Updated or newly created User model
221
223
  """
222
224
  user_model_class = user_model_registry.get_user_model_class()
225
+ assert issubclass(user_model_class, User)
223
226
  user_table_class = user_model_registry.get_user_table_class()
227
+ assert issubclass(user_table_class, UserTable)
224
228
  async with get_session() as db:
225
229
  db_user = await db.get(user_table_class, id)
226
230
  old_nft_count = 0 # Default for new users
@@ -261,7 +265,7 @@ class User(UserUpdate):
261
265
  ]
262
266
 
263
267
  @classmethod
264
- async def get(cls, user_id: str) -> Optional["User"]:
268
+ async def get(cls, user_id: str) -> Optional[UserModelType]:
265
269
  """Get a user by ID.
266
270
 
267
271
  Args:
@@ -276,7 +280,7 @@ class User(UserUpdate):
276
280
  @classmethod
277
281
  async def get_in_session(
278
282
  cls, session: AsyncSession, user_id: str
279
- ) -> Optional["User"]:
283
+ ) -> Optional[UserModelType]:
280
284
  """Get a user by ID using the provided session.
281
285
 
282
286
  Args:
@@ -286,11 +290,14 @@ class User(UserUpdate):
286
290
  Returns:
287
291
  User model or None if not found
288
292
  """
293
+ user_model_class = user_model_registry.get_user_model_class()
294
+ assert issubclass(user_model_class, User)
289
295
  user_table_class = user_model_registry.get_user_table_class()
296
+ assert issubclass(user_table_class, UserTable)
290
297
  result = await session.execute(
291
298
  select(user_table_class).where(user_table_class.id == user_id)
292
299
  )
293
300
  user = result.scalars().first()
294
301
  if user is None:
295
302
  return None
296
- return cls.model_validate(user)
303
+ return user_model_class.model_validate(user)