letta-nightly 0.5.1.dev20241025104130__py3-none-any.whl → 0.5.1.dev20241026104101__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/agent.py +1 -6
- letta/cli/cli.py +6 -4
- letta/client/client.py +75 -92
- letta/config.py +0 -6
- letta/constants.py +0 -9
- letta/functions/functions.py +24 -0
- letta/functions/helpers.py +4 -3
- letta/functions/schema_generator.py +10 -2
- letta/metadata.py +2 -99
- letta/o1_agent.py +2 -2
- letta/orm/__all__.py +15 -0
- letta/orm/mixins.py +16 -1
- letta/orm/organization.py +2 -0
- letta/orm/sqlalchemy_base.py +17 -18
- letta/orm/tool.py +54 -0
- letta/orm/user.py +7 -1
- letta/schemas/block.py +6 -9
- letta/schemas/memory.py +27 -29
- letta/schemas/tool.py +62 -61
- letta/schemas/user.py +2 -2
- letta/server/rest_api/admin/users.py +1 -1
- letta/server/rest_api/routers/v1/tools.py +19 -23
- letta/server/server.py +42 -327
- letta/services/organization_manager.py +19 -12
- letta/services/tool_manager.py +193 -0
- letta/services/user_manager.py +16 -11
- {letta_nightly-0.5.1.dev20241025104130.dist-info → letta_nightly-0.5.1.dev20241026104101.dist-info}/METADATA +1 -1
- {letta_nightly-0.5.1.dev20241025104130.dist-info → letta_nightly-0.5.1.dev20241026104101.dist-info}/RECORD +31 -29
- {letta_nightly-0.5.1.dev20241025104130.dist-info → letta_nightly-0.5.1.dev20241026104101.dist-info}/LICENSE +0 -0
- {letta_nightly-0.5.1.dev20241025104130.dist-info → letta_nightly-0.5.1.dev20241026104101.dist-info}/WHEEL +0 -0
- {letta_nightly-0.5.1.dev20241025104130.dist-info → letta_nightly-0.5.1.dev20241026104101.dist-info}/entry_points.txt +0 -0
letta/metadata.py
CHANGED
|
@@ -14,8 +14,6 @@ from sqlalchemy import (
|
|
|
14
14
|
Integer,
|
|
15
15
|
String,
|
|
16
16
|
TypeDecorator,
|
|
17
|
-
asc,
|
|
18
|
-
or_,
|
|
19
17
|
)
|
|
20
18
|
from sqlalchemy.sql import func
|
|
21
19
|
|
|
@@ -32,7 +30,6 @@ from letta.schemas.llm_config import LLMConfig
|
|
|
32
30
|
from letta.schemas.memory import Memory
|
|
33
31
|
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
|
34
32
|
from letta.schemas.source import Source
|
|
35
|
-
from letta.schemas.tool import Tool
|
|
36
33
|
from letta.schemas.user import User
|
|
37
34
|
from letta.settings import settings
|
|
38
35
|
from letta.utils import enforce_types, get_utc_time, printd
|
|
@@ -309,9 +306,9 @@ class BlockModel(Base):
|
|
|
309
306
|
id = Column(String, primary_key=True, nullable=False)
|
|
310
307
|
value = Column(String, nullable=False)
|
|
311
308
|
limit = Column(BIGINT)
|
|
312
|
-
name = Column(String
|
|
309
|
+
name = Column(String)
|
|
313
310
|
template = Column(Boolean, default=False) # True: listed as possible human/persona
|
|
314
|
-
label = Column(String)
|
|
311
|
+
label = Column(String, nullable=False)
|
|
315
312
|
metadata_ = Column(JSON)
|
|
316
313
|
description = Column(String)
|
|
317
314
|
user_id = Column(String)
|
|
@@ -359,37 +356,6 @@ class BlockModel(Base):
|
|
|
359
356
|
)
|
|
360
357
|
|
|
361
358
|
|
|
362
|
-
class ToolModel(Base):
|
|
363
|
-
__tablename__ = "tools"
|
|
364
|
-
__table_args__ = {"extend_existing": True}
|
|
365
|
-
|
|
366
|
-
id = Column(String, primary_key=True)
|
|
367
|
-
name = Column(String, nullable=False)
|
|
368
|
-
user_id = Column(String)
|
|
369
|
-
description = Column(String)
|
|
370
|
-
source_type = Column(String)
|
|
371
|
-
source_code = Column(String)
|
|
372
|
-
json_schema = Column(JSON)
|
|
373
|
-
module = Column(String)
|
|
374
|
-
tags = Column(JSON)
|
|
375
|
-
|
|
376
|
-
def __repr__(self) -> str:
|
|
377
|
-
return f"<Tool(id='{self.id}', name='{self.name}')>"
|
|
378
|
-
|
|
379
|
-
def to_record(self) -> Tool:
|
|
380
|
-
return Tool(
|
|
381
|
-
id=self.id,
|
|
382
|
-
name=self.name,
|
|
383
|
-
user_id=self.user_id,
|
|
384
|
-
description=self.description,
|
|
385
|
-
source_type=self.source_type,
|
|
386
|
-
source_code=self.source_code,
|
|
387
|
-
json_schema=self.json_schema,
|
|
388
|
-
module=self.module,
|
|
389
|
-
tags=self.tags,
|
|
390
|
-
)
|
|
391
|
-
|
|
392
|
-
|
|
393
359
|
class JobModel(Base):
|
|
394
360
|
__tablename__ = "jobs"
|
|
395
361
|
__table_args__ = {"extend_existing": True}
|
|
@@ -516,14 +482,6 @@ class MetadataStore:
|
|
|
516
482
|
session.add(BlockModel(**vars(block)))
|
|
517
483
|
session.commit()
|
|
518
484
|
|
|
519
|
-
@enforce_types
|
|
520
|
-
def create_tool(self, tool: Tool):
|
|
521
|
-
with self.session_maker() as session:
|
|
522
|
-
if self.get_tool(tool_id=tool.id, tool_name=tool.name, user_id=tool.user_id) is not None:
|
|
523
|
-
raise ValueError(f"Tool with name {tool.name} already exists")
|
|
524
|
-
session.add(ToolModel(**vars(tool)))
|
|
525
|
-
session.commit()
|
|
526
|
-
|
|
527
485
|
@enforce_types
|
|
528
486
|
def update_agent(self, agent: AgentState):
|
|
529
487
|
with self.session_maker() as session:
|
|
@@ -556,18 +514,6 @@ class MetadataStore:
|
|
|
556
514
|
session.add(BlockModel(**vars(block)))
|
|
557
515
|
session.commit()
|
|
558
516
|
|
|
559
|
-
@enforce_types
|
|
560
|
-
def update_tool(self, tool_id: str, tool: Tool):
|
|
561
|
-
with self.session_maker() as session:
|
|
562
|
-
session.query(ToolModel).filter(ToolModel.id == tool_id).update(vars(tool))
|
|
563
|
-
session.commit()
|
|
564
|
-
|
|
565
|
-
@enforce_types
|
|
566
|
-
def delete_tool(self, tool_id: str):
|
|
567
|
-
with self.session_maker() as session:
|
|
568
|
-
session.query(ToolModel).filter(ToolModel.id == tool_id).delete()
|
|
569
|
-
session.commit()
|
|
570
|
-
|
|
571
517
|
@enforce_types
|
|
572
518
|
def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]):
|
|
573
519
|
with self.session_maker() as session:
|
|
@@ -612,23 +558,6 @@ class MetadataStore:
|
|
|
612
558
|
|
|
613
559
|
session.commit()
|
|
614
560
|
|
|
615
|
-
@enforce_types
|
|
616
|
-
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]:
|
|
617
|
-
with self.session_maker() as session:
|
|
618
|
-
# Query for public tools or user-specific tools
|
|
619
|
-
query = session.query(ToolModel).filter(or_(ToolModel.user_id == None, ToolModel.user_id == user_id))
|
|
620
|
-
|
|
621
|
-
# Apply cursor if provided (assuming cursor is an ID)
|
|
622
|
-
if cursor:
|
|
623
|
-
query = query.filter(ToolModel.id > cursor)
|
|
624
|
-
|
|
625
|
-
# Order by ID and apply limit
|
|
626
|
-
results = query.order_by(asc(ToolModel.id)).limit(limit).all()
|
|
627
|
-
|
|
628
|
-
# Convert to records
|
|
629
|
-
res = [r.to_record() for r in results]
|
|
630
|
-
return res
|
|
631
|
-
|
|
632
561
|
@enforce_types
|
|
633
562
|
def list_agents(self, user_id: str) -> List[AgentState]:
|
|
634
563
|
with self.session_maker() as session:
|
|
@@ -672,32 +601,6 @@ class MetadataStore:
|
|
|
672
601
|
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
|
673
602
|
return results[0].to_record()
|
|
674
603
|
|
|
675
|
-
@enforce_types
|
|
676
|
-
def get_tool(
|
|
677
|
-
self, tool_name: Optional[str] = None, tool_id: Optional[str] = None, user_id: Optional[str] = None
|
|
678
|
-
) -> Optional[ToolModel]:
|
|
679
|
-
with self.session_maker() as session:
|
|
680
|
-
if tool_id:
|
|
681
|
-
results = session.query(ToolModel).filter(ToolModel.id == tool_id).all()
|
|
682
|
-
else:
|
|
683
|
-
assert tool_name is not None
|
|
684
|
-
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
|
|
685
|
-
if user_id:
|
|
686
|
-
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
|
|
687
|
-
if len(results) == 0:
|
|
688
|
-
return None
|
|
689
|
-
# assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
|
690
|
-
return results[0].to_record()
|
|
691
|
-
|
|
692
|
-
@enforce_types
|
|
693
|
-
def get_tool_with_name_and_user_id(self, tool_name: Optional[str] = None, user_id: Optional[str] = None) -> Optional[ToolModel]:
|
|
694
|
-
with self.session_maker() as session:
|
|
695
|
-
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
|
|
696
|
-
if len(results) == 0:
|
|
697
|
-
return None
|
|
698
|
-
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
|
699
|
-
return results[0].to_record()
|
|
700
|
-
|
|
701
604
|
@enforce_types
|
|
702
605
|
def get_block(self, block_id: str) -> Optional[Block]:
|
|
703
606
|
with self.session_maker() as session:
|
letta/o1_agent.py
CHANGED
|
@@ -10,7 +10,7 @@ from letta.schemas.tool import Tool
|
|
|
10
10
|
from letta.schemas.usage import LettaUsageStatistics
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
def send_thinking_message(self: Agent, message: str) -> Optional[str]:
|
|
13
|
+
def send_thinking_message(self: "Agent", message: str) -> Optional[str]:
|
|
14
14
|
"""
|
|
15
15
|
Sends a thinking message so that the model can reason out loud before responding.
|
|
16
16
|
|
|
@@ -24,7 +24,7 @@ def send_thinking_message(self: Agent, message: str) -> Optional[str]:
|
|
|
24
24
|
return None
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
def send_final_message(self: Agent, message: str) -> Optional[str]:
|
|
27
|
+
def send_final_message(self: "Agent", message: str) -> Optional[str]:
|
|
28
28
|
"""
|
|
29
29
|
Sends a final message to the human user after thinking for a while.
|
|
30
30
|
|
letta/orm/__all__.py
CHANGED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""__all__ acts as manual import management to avoid collisions and circular imports."""
|
|
2
|
+
|
|
3
|
+
# from letta.orm.agent import Agent
|
|
4
|
+
# from letta.orm.users_agents import UsersAgents
|
|
5
|
+
# from letta.orm.blocks_agents import BlocksAgents
|
|
6
|
+
# from letta.orm.token import Token
|
|
7
|
+
# from letta.orm.source import Source
|
|
8
|
+
# from letta.orm.document import Document
|
|
9
|
+
# from letta.orm.passage import Passage
|
|
10
|
+
# from letta.orm.memory_templates import MemoryTemplate, HumanMemoryTemplate, PersonaMemoryTemplate
|
|
11
|
+
# from letta.orm.sources_agents import SourcesAgents
|
|
12
|
+
# from letta.orm.tools_agents import ToolsAgents
|
|
13
|
+
# from letta.orm.job import Job
|
|
14
|
+
# from letta.orm.block import Block
|
|
15
|
+
# from letta.orm.message import Message
|
letta/orm/mixins.py
CHANGED
|
@@ -55,7 +55,6 @@ class OrganizationMixin(Base):
|
|
|
55
55
|
|
|
56
56
|
__abstract__ = True
|
|
57
57
|
|
|
58
|
-
# Changed _organization_id to store string (still a valid UUID4 string)
|
|
59
58
|
_organization_id: Mapped[str] = mapped_column(String, ForeignKey("organization._id"))
|
|
60
59
|
|
|
61
60
|
@property
|
|
@@ -65,3 +64,19 @@ class OrganizationMixin(Base):
|
|
|
65
64
|
@organization_id.setter
|
|
66
65
|
def organization_id(self, value: str) -> None:
|
|
67
66
|
_relation_setter(self, "organization", value)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class UserMixin(Base):
|
|
70
|
+
"""Mixin for models that belong to a user."""
|
|
71
|
+
|
|
72
|
+
__abstract__ = True
|
|
73
|
+
|
|
74
|
+
_user_id: Mapped[str] = mapped_column(String, ForeignKey("user._id"))
|
|
75
|
+
|
|
76
|
+
@property
|
|
77
|
+
def user_id(self) -> str:
|
|
78
|
+
return _relation_getter(self, "user")
|
|
79
|
+
|
|
80
|
+
@user_id.setter
|
|
81
|
+
def user_id(self, value: str) -> None:
|
|
82
|
+
_relation_setter(self, "user", value)
|
letta/orm/organization.py
CHANGED
|
@@ -7,6 +7,7 @@ from letta.schemas.organization import Organization as PydanticOrganization
|
|
|
7
7
|
|
|
8
8
|
if TYPE_CHECKING:
|
|
9
9
|
|
|
10
|
+
from letta.orm.tool import Tool
|
|
10
11
|
from letta.orm.user import User
|
|
11
12
|
|
|
12
13
|
|
|
@@ -19,6 +20,7 @@ class Organization(SqlalchemyBase):
|
|
|
19
20
|
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
|
|
20
21
|
|
|
21
22
|
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
|
|
23
|
+
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
|
22
24
|
|
|
23
25
|
# TODO: Map these relationships later when we actually make these models
|
|
24
26
|
# below is just a suggestion
|
letta/orm/sqlalchemy_base.py
CHANGED
|
@@ -184,21 +184,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
184
184
|
logger.warning("to_record is deprecated, use to_pydantic instead.")
|
|
185
185
|
return self.to_pydantic()
|
|
186
186
|
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
# self._organization_id = created_by._organization_id
|
|
187
|
+
def _infer_organization(self, db_session: "Session") -> None:
|
|
188
|
+
"""🪄 MAGIC ALERT! 🪄
|
|
189
|
+
Because so much of the original API is centered around user scopes,
|
|
190
|
+
this allows us to continue with that scope and then infer the org from the creating user.
|
|
191
|
+
|
|
192
|
+
IF a created_by_id is set, we will use that to infer the organization and magic set it at create time!
|
|
193
|
+
If not do nothing to the object. Mutates in place.
|
|
194
|
+
"""
|
|
195
|
+
if self.created_by_id and hasattr(self, "_organization_id"):
|
|
196
|
+
try:
|
|
197
|
+
from letta.orm.user import User # to avoid circular import
|
|
198
|
+
|
|
199
|
+
created_by = User.read(db_session, self.created_by_id)
|
|
200
|
+
except NoResultFound:
|
|
201
|
+
logger.warning(f"User {self.created_by_id} not found, unable to infer organization.")
|
|
202
|
+
return
|
|
203
|
+
self._organization_id = created_by._organization_id
|
letta/orm/tool.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, List, Optional
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import JSON, String, UniqueConstraint
|
|
4
|
+
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
5
|
+
|
|
6
|
+
# TODO everything in functions should live in this model
|
|
7
|
+
from letta.orm.enums import ToolSourceType
|
|
8
|
+
from letta.orm.mixins import OrganizationMixin, UserMixin
|
|
9
|
+
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
10
|
+
from letta.schemas.tool import Tool as PydanticTool
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
pass
|
|
14
|
+
|
|
15
|
+
from letta.orm.organization import Organization
|
|
16
|
+
from letta.orm.user import User
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Tool(SqlalchemyBase, OrganizationMixin, UserMixin):
|
|
20
|
+
"""Represents an available tool that the LLM can invoke.
|
|
21
|
+
|
|
22
|
+
NOTE: polymorphic inheritance makes more sense here as a TODO. We want a superset of tools
|
|
23
|
+
that are always available, and a subset scoped to the organization. Alternatively, we could use the apply_access_predicate to build
|
|
24
|
+
more granular permissions.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
__tablename__ = "tool"
|
|
28
|
+
__pydantic_model__ = PydanticTool
|
|
29
|
+
|
|
30
|
+
# Add unique constraint on (name, _organization_id)
|
|
31
|
+
# An organization should not have multiple tools with the same name
|
|
32
|
+
__table_args__ = (
|
|
33
|
+
UniqueConstraint("name", "_organization_id", name="uix_name_organization"),
|
|
34
|
+
UniqueConstraint("name", "_user_id", name="uix_name_user"),
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
|
|
38
|
+
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
|
|
39
|
+
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
|
|
40
|
+
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
|
|
41
|
+
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
|
|
42
|
+
json_schema: Mapped[dict] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
|
|
43
|
+
module: Mapped[Optional[str]] = mapped_column(
|
|
44
|
+
String, nullable=True, doc="the module path from which this tool was derived in the codebase."
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
# TODO: add terminal here eventually
|
|
48
|
+
# This was an intentional decision by Sarah
|
|
49
|
+
|
|
50
|
+
# relationships
|
|
51
|
+
# TODO: Possibly add in user in the future
|
|
52
|
+
# This will require some more thought and justification to add this in.
|
|
53
|
+
user: Mapped["User"] = relationship("User", back_populates="tools", lazy="selectin")
|
|
54
|
+
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
letta/orm/user.py
CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, List
|
|
2
|
+
|
|
1
3
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
2
4
|
|
|
3
5
|
from letta.orm.mixins import OrganizationMixin
|
|
4
|
-
from letta.orm.organization import Organization
|
|
5
6
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
6
7
|
from letta.schemas.user import User as PydanticUser
|
|
7
8
|
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from letta.orm.organization import Organization
|
|
11
|
+
from letta.orm.tool import Tool
|
|
12
|
+
|
|
8
13
|
|
|
9
14
|
class User(SqlalchemyBase, OrganizationMixin):
|
|
10
15
|
"""User ORM class"""
|
|
@@ -16,6 +21,7 @@ class User(SqlalchemyBase, OrganizationMixin):
|
|
|
16
21
|
|
|
17
22
|
# relationships
|
|
18
23
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
|
|
24
|
+
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="user", cascade="all, delete-orphan")
|
|
19
25
|
|
|
20
26
|
# TODO: Add this back later potentially
|
|
21
27
|
# agents: Mapped[List["Agent"]] = relationship(
|
letta/schemas/block.py
CHANGED
|
@@ -17,11 +17,14 @@ class BaseBlock(LettaBase, validate_assignment=True):
|
|
|
17
17
|
value: Optional[str] = Field(None, description="Value of the block.")
|
|
18
18
|
limit: int = Field(2000, description="Character limit of the block.")
|
|
19
19
|
|
|
20
|
-
|
|
20
|
+
# template data (optional)
|
|
21
|
+
name: Optional[str] = Field(None, description="Name of the block if it is a template.")
|
|
21
22
|
template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).")
|
|
22
|
-
label: Optional[str] = Field(None, description="Label of the block (e.g. 'human', 'persona').")
|
|
23
23
|
|
|
24
|
-
#
|
|
24
|
+
# context window label
|
|
25
|
+
label: str = Field(None, description="Label of the block (e.g. 'human', 'persona') in the context window.")
|
|
26
|
+
|
|
27
|
+
# metadata
|
|
25
28
|
description: Optional[str] = Field(None, description="Description of the block.")
|
|
26
29
|
metadata_: Optional[dict] = Field({}, description="Metadata of the block.")
|
|
27
30
|
|
|
@@ -39,12 +42,6 @@ class BaseBlock(LettaBase, validate_assignment=True):
|
|
|
39
42
|
raise e
|
|
40
43
|
return self
|
|
41
44
|
|
|
42
|
-
@model_validator(mode="after")
|
|
43
|
-
def ensure_label(self) -> Self:
|
|
44
|
-
if not self.label:
|
|
45
|
-
self.label = self.name
|
|
46
|
-
return self
|
|
47
|
-
|
|
48
45
|
def __len__(self):
|
|
49
46
|
return len(self.value)
|
|
50
47
|
|
letta/schemas/memory.py
CHANGED
|
@@ -61,7 +61,7 @@ class Memory(BaseModel, validate_assignment=True):
|
|
|
61
61
|
|
|
62
62
|
"""
|
|
63
63
|
|
|
64
|
-
# Memory.memory is a dict mapping from memory block
|
|
64
|
+
# Memory.memory is a dict mapping from memory block label to memory block.
|
|
65
65
|
memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.")
|
|
66
66
|
|
|
67
67
|
# Memory.template is a Jinja2 template for compiling memory module into a prompt string.
|
|
@@ -126,44 +126,42 @@ class Memory(BaseModel, validate_assignment=True):
|
|
|
126
126
|
}
|
|
127
127
|
|
|
128
128
|
def to_flat_dict(self):
|
|
129
|
-
"""Convert to a dictionary that maps directly from block
|
|
129
|
+
"""Convert to a dictionary that maps directly from block label to values"""
|
|
130
130
|
return {k: v.value for k, v in self.memory.items() if v is not None}
|
|
131
131
|
|
|
132
|
-
def
|
|
132
|
+
def list_block_labels(self) -> List[str]:
|
|
133
133
|
"""Return a list of the block names held inside the memory object"""
|
|
134
134
|
return list(self.memory.keys())
|
|
135
135
|
|
|
136
136
|
# TODO: these should actually be label, not name
|
|
137
|
-
def get_block(self,
|
|
137
|
+
def get_block(self, label: str) -> Block:
|
|
138
138
|
"""Correct way to index into the memory.memory field, returns a Block"""
|
|
139
|
-
if
|
|
140
|
-
raise KeyError(f"Block field {
|
|
139
|
+
if label not in self.memory:
|
|
140
|
+
raise KeyError(f"Block field {label} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
|
|
141
141
|
else:
|
|
142
|
-
return self.memory[
|
|
142
|
+
return self.memory[label]
|
|
143
143
|
|
|
144
144
|
def get_blocks(self) -> List[Block]:
|
|
145
145
|
"""Return a list of the blocks held inside the memory object"""
|
|
146
146
|
return list(self.memory.values())
|
|
147
147
|
|
|
148
|
-
def link_block(self,
|
|
148
|
+
def link_block(self, block: Block, override: Optional[bool] = False):
|
|
149
149
|
"""Link a new block to the memory object"""
|
|
150
150
|
if not isinstance(block, Block):
|
|
151
151
|
raise ValueError(f"Param block must be type Block (not {type(block)})")
|
|
152
|
-
if not
|
|
153
|
-
raise ValueError(f"
|
|
154
|
-
if not override and name in self.memory:
|
|
155
|
-
raise ValueError(f"Block with name {name} already exists")
|
|
152
|
+
if not override and block.label in self.memory:
|
|
153
|
+
raise ValueError(f"Block with label {block.label} already exists")
|
|
156
154
|
|
|
157
|
-
self.memory[
|
|
155
|
+
self.memory[block.label] = block
|
|
158
156
|
|
|
159
|
-
def update_block_value(self,
|
|
157
|
+
def update_block_value(self, label: str, value: str):
|
|
160
158
|
"""Update the value of a block"""
|
|
161
|
-
if
|
|
162
|
-
raise ValueError(f"Block with
|
|
159
|
+
if label not in self.memory:
|
|
160
|
+
raise ValueError(f"Block with label {label} does not exist")
|
|
163
161
|
if not isinstance(value, str):
|
|
164
162
|
raise ValueError(f"Provided value must be a string")
|
|
165
163
|
|
|
166
|
-
self.memory[
|
|
164
|
+
self.memory[label].value = value
|
|
167
165
|
|
|
168
166
|
|
|
169
167
|
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
|
|
@@ -192,41 +190,41 @@ class BasicBlockMemory(Memory):
|
|
|
192
190
|
# assert block.name is not None and block.name != "", "each existing chat block must have a name"
|
|
193
191
|
# self.link_block(name=block.name, block=block)
|
|
194
192
|
assert block.label is not None and block.label != "", "each existing chat block must have a name"
|
|
195
|
-
self.link_block(
|
|
193
|
+
self.link_block(block=block)
|
|
196
194
|
|
|
197
|
-
def core_memory_append(self: "Agent",
|
|
195
|
+
def core_memory_append(self: "Agent", label: str, content: str) -> Optional[str]: # type: ignore
|
|
198
196
|
"""
|
|
199
197
|
Append to the contents of core memory.
|
|
200
198
|
|
|
201
199
|
Args:
|
|
202
|
-
|
|
200
|
+
label (str): Section of the memory to be edited (persona or human).
|
|
203
201
|
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
|
204
202
|
|
|
205
203
|
Returns:
|
|
206
204
|
Optional[str]: None is always returned as this function does not produce a response.
|
|
207
205
|
"""
|
|
208
|
-
current_value = str(self.memory.get_block(
|
|
206
|
+
current_value = str(self.memory.get_block(label).value)
|
|
209
207
|
new_value = current_value + "\n" + str(content)
|
|
210
|
-
self.memory.update_block_value(
|
|
208
|
+
self.memory.update_block_value(label=label, value=new_value)
|
|
211
209
|
return None
|
|
212
210
|
|
|
213
|
-
def core_memory_replace(self: "Agent",
|
|
211
|
+
def core_memory_replace(self: "Agent", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
|
|
214
212
|
"""
|
|
215
213
|
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
|
216
214
|
|
|
217
215
|
Args:
|
|
218
|
-
|
|
216
|
+
label (str): Section of the memory to be edited (persona or human).
|
|
219
217
|
old_content (str): String to replace. Must be an exact match.
|
|
220
218
|
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
|
221
219
|
|
|
222
220
|
Returns:
|
|
223
221
|
Optional[str]: None is always returned as this function does not produce a response.
|
|
224
222
|
"""
|
|
225
|
-
current_value = str(self.memory.get_block(
|
|
223
|
+
current_value = str(self.memory.get_block(label).value)
|
|
226
224
|
if old_content not in current_value:
|
|
227
|
-
raise ValueError(f"Old content '{old_content}' not found in memory block '{
|
|
225
|
+
raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")
|
|
228
226
|
new_value = current_value.replace(str(old_content), str(new_content))
|
|
229
|
-
self.memory.update_block_value(
|
|
227
|
+
self.memory.update_block_value(label=label, value=new_value)
|
|
230
228
|
return None
|
|
231
229
|
|
|
232
230
|
|
|
@@ -245,8 +243,8 @@ class ChatMemory(BasicBlockMemory):
|
|
|
245
243
|
limit (int): The character limit for each block.
|
|
246
244
|
"""
|
|
247
245
|
super().__init__()
|
|
248
|
-
self.link_block(
|
|
249
|
-
self.link_block(
|
|
246
|
+
self.link_block(block=Block(value=persona, limit=limit, label="persona"))
|
|
247
|
+
self.link_block(block=Block(value=human, limit=limit, label="human"))
|
|
250
248
|
|
|
251
249
|
|
|
252
250
|
class UpdateMemory(BaseModel):
|