"""Contains long term memory management functionality."""
import asyncio
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
import nltk
from langchain.docstore.document import Document
from langchain_community.retrievers import PineconeHybridSearchRetriever
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.hybrid import hybrid_convex_scale
from pinecone_text.sparse import BM25Encoder
from bobbot.agents.llms import openai_embeddings
from bobbot.utils import get_logger, is_playwright_browser_open, on_heroku

PARAMS_PATH = "local/bm25_params.json"
Path(PARAMS_PATH).parent.mkdir(parents=True, exist_ok=True)

logger = get_logger(__name__)

def create_bm25_encoder() -> BM25Encoder:
    """Create a default BM25 model from the MS MARCO passages corpus, or restore it from the local cache."""
    if not Path(PARAMS_PATH).exists():
        bm25 = BM25Encoder().default()  # Default tf-idf values, fitted on the MS MARCO passages corpus
        bm25.dump(PARAMS_PATH)
    else:
        bm25 = BM25Encoder()
        bm25.load(PARAMS_PATH)
    return bm25

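
# The BM25 encoder tokenizes with NLTK, which needs the punkt_tab tokenizer data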
nltk.download("punkt_tab", quiet=True)

# Create index if it doesn't exist
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index_name = "bob-bot"
if index_name not in pc.list_indexes().names():
pc.create_index(
name=index_name,
dimension=1536, # dimensionality of dense model
metric="dotproduct", # sparse values supported only for dotproduct
spec=ServerlessSpec(cloud="aws", region="us-east-1"),
)

index = pc.Index(index_name)
sparse_encoder = create_bm25_encoder()
sparse_usage_count = 0  # Number of in-flight writes currently using the sparse encoder
retriever = PineconeHybridSearchRetriever(embeddings=openai_embeddings, sparse_encoder=sparse_encoder, index=index)
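
# On Heroku, the dyno's limited memory cannot comfortably hold the BM25 encoder,
# so it is dropped whenever it is idle and reloaded on demand by the write paths below.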
if on_heroku():
    # Drop both references so the encoder can be garbage collected
    sparse_encoder = None
    retriever.sparse_encoder = None


async def add_tool_memories(texts: list[str], chain_id: str, response: str) -> None:
"""Add tool usage results to Bob's long term memory.
Args:
texts: The texts representing inputs/outputs of each individual tool.
chain_id: The ID of the chain in which the tools were used.
response: Bob's final response after using the tools.
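
    Example (illustrative; the text and chain ID formats are up to the caller):
        await add_tool_memories(
            ["search('weather tokyo') -> 'Sunny, 25C'"],
            chain_id="chain-123",
            response="It's sunny in Tokyo today!",
        )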
"""
if on_heroku() and is_playwright_browser_open():
logger.warning("(Heroku) Skipping tool memory saving due to open Playwright browser.")
return
global sparse_usage_count
curr_time = datetime.now(timezone.utc).timestamp()
metadatas = []
for _ in texts:
metadatas.append(
{
"creation_time": curr_time,
"last_retrieval_time": curr_time,
"type": "tool",
"chain_id": chain_id,
"response": response,
"version": 1,
}
)
    sparse_usage_count += 1
    if retriever.sparse_encoder is None:
        retriever.sparse_encoder = create_bm25_encoder()  # Reload the encoder freed on Heroku
    await asyncio.to_thread(retriever.add_texts, texts, metadatas=metadatas)
    sparse_usage_count -= 1
    if on_heroku() and sparse_usage_count == 0:
        retriever.sparse_encoder = None  # Free up memory once no concurrent write needs the encoder
    logger.info("Saved tool memories for this chain.")


async def add_chat_memory(text: str, message_ids: list[int]) -> None:
"""Add a given message history to Bob's long term memory.
Args:
text: The message history.
message_ids: The message IDs of the messages in the history.
"""
if on_heroku() and is_playwright_browser_open():
logger.warning("(Heroku) Skipping chat memory saving due to open Playwright browser.")
return
global sparse_usage_count
curr_time = datetime.now(timezone.utc).timestamp()
metadatas = [
{
"creation_time": curr_time,
"last_retrieval_time": curr_time,
"type": "chat",
"message_ids": [str(id) for id in message_ids],
"version": 1,
}
]
    sparse_usage_count += 1
    if retriever.sparse_encoder is None:
        retriever.sparse_encoder = create_bm25_encoder()  # Reload the encoder freed on Heroku
    await asyncio.to_thread(retriever.add_texts, [text], metadatas=metadatas)
    sparse_usage_count -= 1
    if on_heroku() and sparse_usage_count == 0:
        retriever.sparse_encoder = None  # Free up memory once no concurrent write needs the encoder
    logger.info("Saved chat memory.")


async def delete_memory(memory_id: str) -> bool:
    """Delete a memory from Bob's long term memory.

    Args:
        memory_id: The ID of the memory to delete.

    Returns:
        True if the memory was deleted, False otherwise.
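
    Example (illustrative; the ID is a hypothetical Pinecone vector ID):
        deleted = await delete_memory("mem-8f2a1c")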
"""
    try:
        result = index.fetch(ids=[memory_id])
        if len(result["vectors"]) == 0:
            logger.info(f"Memory with ID {memory_id} does not exist.")
            return False
        index.delete(ids=[memory_id])
    except Exception:
        logger.exception(f"Error deleting memory with ID {memory_id}.")
        return False
    logger.info(f"Deleted memory with ID {memory_id}.")
    return True


async def get_relevant_documents(
    query: str, query_filter: Optional[dict] = None, top_k: Optional[int] = None
) -> list[Document]:
    """Get relevant documents from Bob's long term memory using hybrid search.

    Adapted from PineconeHybridSearchRetriever's _get_relevant_documents method.

    Args:
        query: The query to search for.
        query_filter: An optional Pinecone metadata filter to apply to the search.
        top_k: The number of documents to retrieve. Defaults to the retriever's top_k.

    Returns:
        The list of retrieved documents.
    """
    if top_k is None:
        top_k = retriever.top_k
    # Convert the query into a dense vector without blocking the event loop
    dense_vec = await asyncio.to_thread(retriever.embeddings.embed_query, query)
if on_heroku():
# Not enough memory to use sparse
result = await asyncio.to_thread(
retriever.index.query,
vector=dense_vec,
            top_k=top_k,
include_metadata=True,
namespace=retriever.namespace,
filter=query_filter,
)
else:
        sparse_vec = retriever.sparse_encoder.encode_queries(query)
        # Convex-scale the two vectors: dense is weighted by alpha, sparse by (1 - alpha)
        dense_vec, sparse_vec = hybrid_convex_scale(dense_vec, sparse_vec, retriever.alpha)
        sparse_vec["values"] = [float(s) for s in sparse_vec["values"]]
# Query Pinecone index
result = await asyncio.to_thread(
retriever.index.query,
vector=dense_vec,
sparse_vector=sparse_vec if sparse_vec["indices"] else None,
            top_k=top_k,
include_metadata=True,
namespace=retriever.namespace,
filter=query_filter,
)
final_result = []
for res in result["matches"]:
context = res["metadata"].pop("context")
metadata = res["metadata"]
metadata["id"] = res["id"]
metadata["score"] = res["score"]
        final_result.append(Document(page_content=context, metadata=metadata))
return final_result


async def query_memories(
query: str,
limit: int = 4,
age_limit: Optional[timedelta] = None,
ignore_recent: bool = True,
only_tools: bool = False,
) -> list[Document]:
"""Search Bob's long term memory for relevant, recent memories.
Most memories should not exceed ~5000 characters in length, but this is still a lot, and is also not guaranteed.
Truncate all memories before giving them to an LLM!
Args:
query: The query to search for.
limit: The maximum number of memories to retrieve.
age_limit: The maximum age of the memories to retrieve. If None, all memories are considered.
ignore_recent: Whether to ignore recent memories (within 1 minute old) when retrieving.
only_tools: Whether to only retrieve tool memories.
Returns:
The list of retrieved memories, sorted by relevance.
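
    Example (illustrative; assumes a populated index and a running event loop):
        memories = await query_memories("minecraft server", limit=2, age_limit=timedelta(days=7))
        for doc in memories:
            print(doc.metadata["id"], doc.page_content[:200])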
"""
query_filter = {}
# Threshold datetimes
if ignore_recent:
recent_time = (datetime.now(timezone.utc) - timedelta(minutes=1)).timestamp()
query_filter.setdefault("creation_time", {})
query_filter["creation_time"]["$lt"] = recent_time
if age_limit is not None:
old_time = (datetime.now(timezone.utc) - age_limit).timestamp()
query_filter.setdefault("creation_time", {})
query_filter["creation_time"]["$gt"] = old_time
# Tool memories only
if only_tools:
query_filter["type"] = {"$eq": "tool"}
# Empty filter
if not query_filter:
query_filter = None
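    # The resulting filter looks like, e.g., {"creation_time": {"$lt": ..., "$gt": ...}, "type": {"$eq": "tool"}}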
    # Retrieve relevant documents (top_k is passed explicitly to avoid mutating shared retriever state)
    try:
        results = await get_relevant_documents(query, query_filter=query_filter, top_k=limit)
except Exception:
logger.exception("Error querying long term memory")
return []
    truncated_ids = [doc.metadata["id"][:16] + "..." for doc in results]
    logger.debug(
        f"Long term memory query with query={query}, limit={limit}, age_limit={age_limit}, "
        f"ignore_recent={ignore_recent}, only_tools={only_tools} -> {truncated_ids}"
    )
# Update last retrieval times concurrently
curr_time = datetime.now(timezone.utc).timestamp()
tasks = [
asyncio.to_thread(
retriever.index.update, id=result.metadata["id"], set_metadata={"last_retrieval_time": curr_time}
)
for result in results
]
await asyncio.gather(*tasks)
    return results


def is_sparse_encoder_loaded() -> bool:
"""Check if the sparse encoder is loaded."""
return retriever.sparse_encoder is not None