"""Contains the TextChannelHistory class for tracking the history of a text channel."""
import pprint
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from logging import Logger
from typing import Callable
import discord
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from bobbot.utils import get_images_in, get_logger, time_elapsed_str, truncate_length
# Module-level logger, obtained from the project's get_logger helper using this module's name.
logger: Logger = get_logger(__name__)
[docs]
def get_users_in_channel(channel: discord.DMChannel | discord.TextChannel) -> list[discord.User]:
    """Return the users associated with a Discord channel.

    Args:
        channel: A DM channel or a text channel.

    Returns:
        For a DM channel, ``channel.recipients`` followed by ``channel.me``;
        for any other channel, ``channel.members``.
    """
    if not isinstance(channel, discord.DMChannel):
        return channel.members
    # DM channel: combine the recipients with the ``me`` user of the channel.
    return channel.recipients + [channel.me]
# Matches raw user mentions such as ``<@123>`` or ``<@!123>``; group 1 is the user ID.
_USER_MENTION_PATTERN = re.compile(r"<@!?(\d+)>")


def pings_to_usernames(content: str, channel: discord.TextChannel) -> str:
    """Replace raw ID mentions with username mentions in the given content.

    Every ``<@id>`` / ``<@!id>`` mention whose ID belongs to a user returned by
    ``get_users_in_channel(channel)`` becomes ``@<display name>``. Mentions of
    users that are not found are left untouched.

    Args:
        content: The raw message text.
        channel: The channel whose users are used to resolve mentions.

    Returns:
        The content with resolvable mentions replaced.
    """
    if _USER_MENTION_PATTERN.search(content) is None:
        return content  # Nothing to resolve; skip collecting the channel's users.
    all_users: list[discord.User] = get_users_in_channel(channel)

    def to_username(match: re.Match[str]) -> str:
        """Map one mention match to ``@display_name``, or keep it if the user is unknown."""
        user = discord.utils.get(all_users, id=int(match.group(1)))
        if user is None:
            return match.group(0)
        return f"@{user.display_name}"

    # A callable replacement is inserted literally, so display names containing
    # backslashes or ``\g<...>`` sequences are not interpreted as regex templates.
    return _USER_MENTION_PATTERN.sub(to_username, content)
def get_full_content(message: discord.Message) -> str:
    """Get the full content of a message, including attachment URLs and sticker names.

    The result is built from up to three space-separated parts, in this order:
    the message text (with user mentions replaced by display names), the
    comma-separated sticker names, and the comma-separated attachment URLs.
    Empty parts are omitted, so a message without text does not start with a
    stray space.

    Args:
        message: The message to render.

    Returns:
        The combined content string (empty if the message has none of the parts).
    """
    parts: list[str] = [
        pings_to_usernames(message.content, message.channel),
        ", ".join(sticker.name for sticker in message.stickers),
        ", ".join(str(attachment.url) for attachment in message.attachments),
    ]
    # NOTE: embeds are not represented in the output; an embed-only message yields "".
    return " ".join(part for part in parts if part)
[docs]
@dataclass
class ParsedMessage:
    """A Discord message wrapped with string-oriented accessors for building histories."""

    message: discord.Message
    """The raw message."""
    is_deleted: bool = False
    """Whether the message was deleted."""

    def __init__(self, message: discord.Message, is_deleted: bool = False):
        """Wrap a message.

        Args:
            message: The message to parse.
            is_deleted: Whether the message was deleted.
        """
        self.message = message
        self.is_deleted = is_deleted

    @property
    def id(self) -> int:
        """ID of the underlying message."""
        return self.message.id

    @property
    def is_edited(self) -> bool:
        """True when the underlying message has an edit timestamp."""
        return self.message.edited_at is not None

    @property
    def author(self) -> str:
        """Display name of the message author."""
        return self.message.author.display_name

    @property
    def content(self) -> str:
        """Raw text of the message."""
        return self.message.content

    @property
    def full_content(self) -> str:
        """Message content as produced by ``get_full_content`` (text, stickers, attachment URLs)."""
        return get_full_content(self.message)

    @property
    def reactions(self) -> str:
        """Comma-separated ``"<count> <emoji>"`` pairs, or an empty string if there are none."""
        return ", ".join(f"{reaction.count} {reaction.emoji}" for reaction in self.message.reactions)

    @property
    def timestamp(self) -> str:
        """Creation time of the message, formatted by ``time_elapsed_str``."""
        return time_elapsed_str(self.message.created_at)

    @property
    def context(self) -> str:
        """Edit/delete status plus reply or command-response info, comma-separated.

        Returns an empty string when none of these apply.
        """
        # Deletion takes precedence over the edited flag.
        if self.is_deleted:
            status = "Deleted"
        elif self.is_edited:
            status = "Edited"
        else:
            status = ""

        reference = self.message.reference
        if reference and isinstance(reference.resolved, discord.Message):
            # Reply to another message that could be resolved
            relation = f"replying to {reference.resolved.author.display_name}"
        elif self.message.interaction_metadata:
            # Response to a slash command
            relation = "triggered by a command"
        else:
            relation = ""

        return ", ".join(part for part in (status, relation) if part)

    def as_string(
        self,
        with_author: bool = True,
        with_context: bool = True,
        with_reactions: bool = True,
        with_timestamp: bool = False,
    ) -> str:
        """Render the message as a single string.

        The layout is ``[timestamp] author (context): content | Reactions: ...``,
        with each optional piece included only when enabled and non-empty.

        Args:
            with_author: Whether to include the author.
            with_context: Whether to include context info.
            with_reactions: Whether to include reactions.
            with_timestamp: Whether to include a timestamp.

        Returns:
            The formatted message.
        """
        pieces: list[str] = []
        if with_timestamp:
            pieces.append(f"[{self.timestamp}] ")

        context_info = self.context if with_context else ""
        if with_author:
            if context_info:
                pieces.append(f"{self.author} ({context_info}): ")
            else:
                pieces.append(f"{self.author}: ")
        elif context_info:
            pieces.append(f"({context_info}) ")

        pieces.append(self.full_content)

        reaction_info = self.reactions if with_reactions else ""
        if reaction_info:
            pieces.append(f" | Reactions: {reaction_info}")
        return "".join(pieces)

    def __str__(self) -> str:
        """Render the message with every piece of info, including the timestamp."""
        return self.as_string(with_timestamp=True)
[docs]
class TextChannelHistory:
    """Tracks the history of a text channel, providing concise representations of messages and other events."""

    MAX_CONTENT_LENS: list[int] = [4096] * 2 + [256] * 8 + [64] * 10
    """The maximum lengths for message content, starting with the most recent."""
    MAX_MSGS: int = len(MAX_CONTENT_LENS)
    """The maximum number of messages to track."""
    channel: discord.TextChannel
    """The channel being tracked."""
    is_typing: dict[discord.User, datetime]
    """Users currently typing in the channel, mapped to the time of their last typing event."""
    history: list[ParsedMessage]
    """Recent messages in the channel, ordered from oldest to newest."""
    message_count: int
    """Counter for the total number of messages sent in the channel."""

    def __init__(self, channel: discord.TextChannel):
        """Create a tracker for the specified channel.

        Args:
            channel: The channel to track.
        """
        self.channel = channel
        self.is_typing = {}
        self.history = []
        self.message_count = 0

    def on_typing(self, user: discord.User, when: datetime) -> None:
        """Handle a typing event by recording when the user was last seen typing.

        Args:
            user: The user who is typing.
            when: The time of the typing event.
        """
        self.is_typing[user] = when

    def clear_users_typing(self) -> None:
        """Clear all currently typing users."""
        self.is_typing = {}

    def get_users_typing(self) -> list[discord.User]:
        """Return the users currently typing, pruning stale entries first.

        Returns:
            The users whose last typing event is less than 10 seconds old.
        """
        # Update typing status, removing entries that are >= 10 seconds old
        now: datetime = datetime.now(timezone.utc)
        self.is_typing = {user: when for user, when in self.is_typing.items() if (now - when).total_seconds() < 10}
        return list(self.is_typing.keys())

    async def aupdate(self) -> None:
        """Update the history with the latest messages and events.

        This method should be called before querying the history after every idle period.

        Fetches up to ``MAX_MSGS`` messages from the channel and merges them with the
        previously tracked history: messages that are still present keep their
        ``is_deleted`` flag, previously tracked messages that are no longer returned
        are kept and marked as deleted, and new messages are counted and appended
        unless they are commands (content starting with ``"!"``).
        """
        history: list[ParsedMessage] = []  # Built newest-first; reversed at the end
        old_history: list[ParsedMessage] = self.history
        old_history_index: int = len(old_history) - 1  # Walks the old history from newest to oldest
        async for prev_msg in self.channel.history(limit=TextChannelHistory.MAX_MSGS):  # From most to least recent
            # Stop once enough entries were collected (deleted entries also count)
            if len(history) >= TextChannelHistory.MAX_MSGS:
                break
            # Check for messages that are already in history: consume every old entry
            # that is at least as recent as the fetched message
            is_old: bool = False
            while old_history_index >= 0 and old_history[old_history_index].message.created_at >= prev_msg.created_at:
                old_entry: ParsedMessage = old_history[old_history_index]
                old_history_index -= 1
                if prev_msg.id == old_entry.message.id:
                    is_old = True
                    history.append(ParsedMessage(prev_msg, is_deleted=old_entry.is_deleted))
                else:
                    # Message was deleted
                    history.append(ParsedMessage(old_entry.message, is_deleted=True))
            # NOTE(review): the handling below is applied only to messages not already
            # tracked (otherwise tracked messages would be appended twice). Commands are
            # never stored, so they appear to be re-counted on every update - confirm.
            if not is_old:
                self.message_count += 1
                if prev_msg.content.startswith(("! reset", "! mode")):
                    break  # Stop tracking history at the first command
                elif prev_msg.content.startswith("!"):
                    continue  # Skip other commands
                history.append(ParsedMessage(prev_msg))
        # Truncate history and reverse it
        history = history[: TextChannelHistory.MAX_MSGS][::-1]
        self.history = history
        logger.info("Updated text channel history.")

    def _recent(self, limit: int) -> list[ParsedMessage]:
        """Return the last ``limit`` tracked messages (oldest first); empty if ``limit`` is not positive."""
        # ``self.history[-0:]`` would return the whole list, so non-positive limits are handled explicitly.
        return self.history[-limit:] if limit > 0 else []

    def history_to_strings(self, transform: Callable[[ParsedMessage], str], limit: int = MAX_MSGS) -> list[str]:
        """Get a list of truncated strings representing the history, up to the last limit messages.

        Messages are truncated based on how recent they are, with older messages being truncated more.

        Args:
            transform: A function to convert a ParsedMessage to a string.
            limit: The maximum number of messages to include. A non-positive limit yields an empty list.

        Returns:
            A list of strings representing the history, ordered from oldest to newest.
        """
        # Walk newest-first so that index i selects the length budget for the i-th most recent message
        newest_first: list[str] = [
            truncate_length(transform(entry), TextChannelHistory.MAX_CONTENT_LENS[i])
            for i, entry in enumerate(reversed(self._recent(limit)))
        ]
        return newest_first[::-1]

    def as_parsed_messages(self, limit: int = MAX_MSGS) -> list[ParsedMessage]:
        """Get the channel's message history, up to the last limit messages.

        Args:
            limit: The maximum number of messages to include. A non-positive limit yields an empty list.

        Returns:
            The most recent tracked messages, ordered from oldest to newest.
        """
        return self._recent(limit)

    def as_string(
        self,
        limit: int = MAX_MSGS,
        with_author: bool = True,
        with_context: bool = True,
        with_reactions: bool = True,
        with_timestamp: bool = False,
    ) -> str:
        """Get a string representation of the history, up to the last limit messages.

        Args:
            limit: The maximum number of messages to include.
            with_author: Whether to include the author.
            with_context: Whether to include context info.
            with_reactions: Whether to include reactions.
            with_timestamp: Whether to include a timestamp.

        Returns:
            A string representing the history, one message per line.
        """
        result = "\n".join(
            self.history_to_strings(
                lambda e: e.as_string(
                    with_author=with_author,
                    with_context=with_context,
                    with_reactions=with_reactions,
                    with_timestamp=with_timestamp,
                ),
                limit,
            )
        )
        logger.debug(f"Text channel history as string:\n{result}")
        return result

    def as_langchain_msgs(
        self, bot_user: discord.User, limit: int = MAX_MSGS, get_image: bool = True
    ) -> list[BaseMessage]:
        """Get a list of LangChain messages representing the history, up to the last limit messages.

        Args:
            bot_user: The bot user. Used to distinguish AIMessages from HumanMessages.
            limit: The maximum number of messages to include.
            get_image: Whether to include a single image URL in the most recent message (if present).

        Returns:
            A list of LangChain message objects, with only HumanMessages containing author/context info.
        """

        def transform(entry: ParsedMessage) -> str:
            """Bot messages should not contain author or context info."""
            from_bot: bool = entry.message.author == bot_user
            return entry.as_string(
                with_author=not from_bot, with_context=not from_bot, with_reactions=True, with_timestamp=False
            )

        recent: list[ParsedMessage] = self._recent(limit)
        history_strs: list[str] = self.history_to_strings(transform, limit)
        newest: ParsedMessage | None = recent[-1] if recent else None
        result: list[BaseMessage] = []
        for entry, text in zip(recent, history_strs):
            if entry.message.author == bot_user:
                result.append(AIMessage(content=text))
                continue
            # Include image in most recent message (identity check: only the actual last entry qualifies)
            if get_image and entry is newest:
                image_urls: list[str] = get_images_in(text)
                if image_urls:
                    result.append(
                        HumanMessage(
                            content=[
                                {"type": "text", "text": text},
                                {"type": "image_url", "image_url": {"url": image_urls[-1]}},
                            ]
                        )
                    )
                    continue
            result.append(HumanMessage(content=text))
        logger.debug(f"Text channel history as langchain messages:\n{pprint.pformat(result)}")
        return result
[docs]
class ManualHistory:
"""Generate message histories manually. Messages sent by bob should always start with "bob: "."""
def __init__(self, history: list[str] = None) -> None:
"""Initialize the message history."""
self._history = history.copy() if history else []
[docs]
def add_message(self, message: str) -> None:
"""Add a message (with any desired context) to the message history."""
self._history.append(message)
[docs]
def limit_messages(self, limit: int) -> None:
"""Limit the number of messages in the history."""
self._history = self._history[-limit:]
[docs]
def as_string(self) -> str:
"""Return the full message history."""
return "\n".join(self._history)
[docs]
def as_langchain_msgs(self) -> list[str]:
"""Return the message history as Langchain messages."""
msgs = []
for msg in self._history:
if msg.startswith("bob: "):
msgs.append(AIMessage(msg[5:]))
else:
msgs.append(HumanMessage(msg))
return msgs
# Module-level registry of TextChannelHistory trackers, keyed by channel ID (filled by get_channel_history).
channel_history: dict[int, TextChannelHistory] = {}
[docs]
def get_channel_history(channel: discord.TextChannel) -> TextChannelHistory:
    """Return the tracker registered for ``channel``, creating and registering one on first use.

    Args:
        channel: The channel whose history tracker is requested.

    Returns:
        The tracker stored in ``channel_history`` under the channel's ID.
    """
    try:
        return channel_history[channel.id]
    except KeyError:
        # First request for this channel: build a tracker and remember it.
        tracker = TextChannelHistory(channel)
        channel_history[channel.id] = tracker
        return tracker