From 423d66b07aba293f8d4582d970bcffe4d18ba361 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Mon, 16 Sep 2024 16:09:17 -0700 Subject: [PATCH] refactor: chat memory & chat history into core module (#1201) --- examples/anthropic/chat_interactive.ts | 6 +- examples/chatHistory.ts | 4 +- packages/core/package.json | 28 ++++ packages/core/src/memory/base.ts | 62 ++++++++ .../core/src/memory/chat-memory-buffer.ts | 65 +++++++++ packages/core/src/memory/index.ts | 3 + .../src/memory/summary-memory.ts} | 137 ++++-------------- .../src/storage/chat-store/base-chat-store.ts | 19 +++ packages/core/src/storage/chat-store/index.ts | 2 + .../storage/chat-store/simple-chat-store.ts | 43 ++++++ packages/llamaindex/src/agent/base.ts | 11 +- .../chat/CondenseQuestionChatEngine.ts | 26 ++-- .../src/engines/chat/ContextChatEngine.ts | 24 +-- .../src/engines/chat/SimpleChatEngine.ts | 25 ++-- packages/llamaindex/src/engines/chat/types.ts | 4 +- packages/llamaindex/src/index.edge.ts | 2 +- .../llamaindex/src/memory/ChatMemoryBuffer.ts | 107 -------------- packages/llamaindex/src/memory/types.ts | 10 -- .../src/storage/chatStore/SimpleChatStore.ts | 65 --------- .../llamaindex/src/storage/chatStore/types.ts | 19 --- packages/llamaindex/src/storage/index.ts | 3 +- 21 files changed, 311 insertions(+), 354 deletions(-) create mode 100644 packages/core/src/memory/base.ts create mode 100644 packages/core/src/memory/chat-memory-buffer.ts create mode 100644 packages/core/src/memory/index.ts rename packages/{llamaindex/src/ChatHistory.ts => core/src/memory/summary-memory.ts} (58%) create mode 100644 packages/core/src/storage/chat-store/base-chat-store.ts create mode 100644 packages/core/src/storage/chat-store/index.ts create mode 100644 packages/core/src/storage/chat-store/simple-chat-store.ts delete mode 100644 packages/llamaindex/src/memory/ChatMemoryBuffer.ts delete mode 100644 packages/llamaindex/src/memory/types.ts delete mode 100644 packages/llamaindex/src/storage/chatStore/SimpleChatStore.ts delete mode 100644 packages/llamaindex/src/storage/chatStore/types.ts diff --git a/examples/anthropic/chat_interactive.ts b/examples/anthropic/chat_interactive.ts index 88bac3544e..4565c70e41 100644 --- a/examples/anthropic/chat_interactive.ts +++ b/examples/anthropic/chat_interactive.ts @@ -1,4 +1,4 @@ -import { Anthropic, SimpleChatEngine, SimpleChatHistory } from "llamaindex"; +import { Anthropic, ChatMemoryBuffer, SimpleChatEngine } from "llamaindex"; import { stdin as input, stdout as output } from "node:process"; import readline from "node:readline/promises"; @@ -8,8 +8,8 @@ import readline from "node:readline/promises"; model: "claude-3-opus", }); // chatHistory will store all the messages in the conversation - const chatHistory = new SimpleChatHistory({ - messages: [ + const chatHistory = new ChatMemoryBuffer({ + chatHistory: [ { content: "You want to talk in rhymes.", role: "system", diff --git a/examples/chatHistory.ts b/examples/chatHistory.ts index 388fd428cf..c55c618d69 100644 --- a/examples/chatHistory.ts +++ b/examples/chatHistory.ts @@ -2,10 +2,10 @@ import { stdin as input, stdout as output } from "node:process"; import readline from "node:readline/promises"; import { + ChatSummaryMemoryBuffer, OpenAI, Settings, SimpleChatEngine, - SummaryChatHistory, } from "llamaindex"; if (process.env.NODE_ENV === "development") { @@ -18,7 +18,7 @@ async function main() { // Set maxTokens to 75% of the context window size of 4096 // This will trigger the summarizer once the chat history reaches 25% of the context window 
size (1024 tokens) const llm = new OpenAI({ model: "gpt-3.5-turbo", maxTokens: 4096 * 0.75 }); - const chatHistory = new SummaryChatHistory({ llm }); + const chatHistory = new ChatSummaryMemoryBuffer({ llm }); const chatEngine = new SimpleChatEngine({ llm }); const rl = readline.createInterface({ input, output }); diff --git a/packages/core/package.json b/packages/core/package.json index fad7a8d06d..6493a568f4 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -157,6 +157,34 @@ "types": "./dist/workflow/index.d.ts", "default": "./dist/workflow/index.js" } + }, + "./memory": { + "require": { + "types": "./dist/memory/index.d.cts", + "default": "./dist/memory/index.cjs" + }, + "import": { + "types": "./dist/memory/index.d.ts", + "default": "./dist/memory/index.js" + }, + "default": { + "types": "./dist/memory/index.d.ts", + "default": "./dist/memory/index.js" + } + }, + "./storage/chat-store": { + "require": { + "types": "./dist/storage/chat-store/index.d.cts", + "default": "./dist/storage/chat-store/index.cjs" + }, + "import": { + "types": "./dist/storage/chat-store/index.d.ts", + "default": "./dist/storage/chat-store/index.js" + }, + "default": { + "types": "./dist/storage/chat-store/index.d.ts", + "default": "./dist/storage/chat-store/index.js" + } } }, "files": [ diff --git a/packages/core/src/memory/base.ts b/packages/core/src/memory/base.ts new file mode 100644 index 0000000000..0302d7d75f --- /dev/null +++ b/packages/core/src/memory/base.ts @@ -0,0 +1,62 @@ +import { Settings } from "../global"; +import type { ChatMessage, MessageContent } from "../llms"; +import { type BaseChatStore, SimpleChatStore } from "../storage/chat-store"; +import { extractText } from "../utils"; + +export const DEFAULT_TOKEN_LIMIT_RATIO = 0.75; +export const DEFAULT_CHAT_STORE_KEY = "chat_history"; + +/** + * A ChatMemory is used to keep the state of back and forth chat messages + */ +export abstract class BaseMemory< + AdditionalMessageOptions extends object = object, +> { + abstract getMessages( + input?: MessageContent | undefined, + ): + | ChatMessage[] + | Promise[]>; + abstract getAllMessages(): + | ChatMessage[] + | Promise[]>; + abstract put(messages: ChatMessage): void; + abstract reset(): void; + + protected _tokenCountForMessages(messages: ChatMessage[]): number { + if (messages.length === 0) { + return 0; + } + + const tokenizer = Settings.tokenizer; + const str = messages.map((m) => extractText(m.content)).join(" "); + return tokenizer.encode(str).length; + } +} + +export abstract class BaseChatStoreMemory< + AdditionalMessageOptions extends object = object, +> extends BaseMemory { + protected constructor( + public chatStore: BaseChatStore = new SimpleChatStore(), + public chatStoreKey: string = DEFAULT_CHAT_STORE_KEY, + ) { + super(); + } + + getAllMessages(): ChatMessage[] { + return this.chatStore.getMessages(this.chatStoreKey); + } + + put(messages: ChatMessage) { + this.chatStore.addMessage(this.chatStoreKey, messages); + } + + set(messages: ChatMessage[]) { + this.chatStore.setMessages(this.chatStoreKey, messages); + } + + reset() { + this.chatStore.deleteMessages(this.chatStoreKey); + } +} diff --git a/packages/core/src/memory/chat-memory-buffer.ts b/packages/core/src/memory/chat-memory-buffer.ts new file mode 100644 index 0000000000..c84a837c77 --- /dev/null +++ b/packages/core/src/memory/chat-memory-buffer.ts @@ -0,0 +1,65 @@ +import { Settings } from "../global"; +import type { ChatMessage, LLM, MessageContent } from "../llms"; +import { type BaseChatStore } 
from "../storage/chat-store"; +import { BaseChatStoreMemory, DEFAULT_TOKEN_LIMIT_RATIO } from "./base"; + +type ChatMemoryBufferOptions = + { + tokenLimit?: number | undefined; + chatStore?: BaseChatStore | undefined; + chatStoreKey?: string | undefined; + chatHistory?: ChatMessage[] | undefined; + llm?: LLM | undefined; + }; + +export class ChatMemoryBuffer< + AdditionalMessageOptions extends object = object, +> extends BaseChatStoreMemory { + tokenLimit: number; + + constructor( + options?: Partial>, + ) { + super(options?.chatStore, options?.chatStoreKey); + + const llm = options?.llm ?? Settings.llm; + const contextWindow = llm.metadata.contextWindow; + this.tokenLimit = + options?.tokenLimit ?? + Math.ceil(contextWindow * DEFAULT_TOKEN_LIMIT_RATIO); + + if (options?.chatHistory) { + this.chatStore.setMessages(this.chatStoreKey, options.chatHistory); + } + } + + getMessages( + input?: MessageContent | undefined, + initialTokenCount: number = 0, + ) { + const messages = this.getAllMessages(); + + if (initialTokenCount > this.tokenLimit) { + throw new Error("Initial token count exceeds token limit"); + } + + let messageCount = messages.length; + let currentMessages = messages.slice(-messageCount); + let tokenCount = this._tokenCountForMessages(messages) + initialTokenCount; + + while (tokenCount > this.tokenLimit && messageCount > 1) { + messageCount -= 1; + if (messages.at(-messageCount)!.role === "assistant") { + messageCount -= 1; + } + currentMessages = messages.slice(-messageCount); + tokenCount = + this._tokenCountForMessages(currentMessages) + initialTokenCount; + } + + if (tokenCount > this.tokenLimit && messageCount <= 0) { + return []; + } + return messages.slice(-messageCount); + } +} diff --git a/packages/core/src/memory/index.ts b/packages/core/src/memory/index.ts new file mode 100644 index 0000000000..bdc356e48d --- /dev/null +++ b/packages/core/src/memory/index.ts @@ -0,0 +1,3 @@ +export { BaseMemory } from "./base"; +export { ChatMemoryBuffer } from "./chat-memory-buffer"; +export { ChatSummaryMemoryBuffer } from "./summary-memory"; diff --git a/packages/llamaindex/src/ChatHistory.ts b/packages/core/src/memory/summary-memory.ts similarity index 58% rename from packages/llamaindex/src/ChatHistory.ts rename to packages/core/src/memory/summary-memory.ts index 4bbd44e4b4..87b5c11108 100644 --- a/packages/llamaindex/src/ChatHistory.ts +++ b/packages/core/src/memory/summary-memory.ts @@ -1,73 +1,11 @@ -import type { ChatMessage, LLM, MessageType } from "@llamaindex/core/llms"; -import { - defaultSummaryPrompt, - type SummaryPrompt, -} from "@llamaindex/core/prompts"; -import { extractText, messagesToHistory } from "@llamaindex/core/utils"; -import { tokenizers, type Tokenizer } from "@llamaindex/env"; -import { OpenAI } from "@llamaindex/openai"; - -/** - * A ChatHistory is used to keep the state of back and forth chat messages - */ -export abstract class ChatHistory< - AdditionalMessageOptions extends object = object, -> { - abstract get messages(): ChatMessage[]; - /** - * Adds a message to the chat history. - * @param message - */ - abstract addMessage(message: ChatMessage): void; - - /** - * Returns the messages that should be used as input to the LLM. - */ - abstract requestMessages( - transientMessages?: ChatMessage[], - ): Promise[]>; - - /** - * Resets the chat history so that it's empty. 
- */ - abstract reset(): void; - - /** - * Returns the new messages since the last call to this function (or since calling the constructor) - */ - abstract newMessages(): ChatMessage[]; -} - -export class SimpleChatHistory extends ChatHistory { - messages: ChatMessage[]; - private messagesBefore: number; - - constructor(init?: { messages?: ChatMessage[] | undefined }) { - super(); - this.messages = init?.messages ?? []; - this.messagesBefore = this.messages.length; - } - - addMessage(message: ChatMessage) { - this.messages.push(message); - } - - async requestMessages(transientMessages?: ChatMessage[]) { - return [...(transientMessages ?? []), ...this.messages]; - } - - reset() { - this.messages = []; - } - - newMessages() { - const newMessages = this.messages.slice(this.messagesBefore); - this.messagesBefore = this.messages.length; - return newMessages; - } -} - -export class SummaryChatHistory extends ChatHistory { +import { type Tokenizer, tokenizers } from "@llamaindex/env"; +import { Settings } from "../global"; +import type { ChatMessage, LLM, MessageType } from "../llms"; +import { defaultSummaryPrompt, type SummaryPrompt } from "../prompts"; +import { extractText, messagesToHistory } from "../utils"; +import { BaseMemory } from "./base"; + +export class ChatSummaryMemoryBuffer extends BaseMemory { /** * Tokenizer function that converts text to tokens, * this is used to calculate the number of tokens in a message. @@ -77,20 +15,18 @@ export class SummaryChatHistory extends ChatHistory { messages: ChatMessage[]; summaryPrompt: SummaryPrompt; llm: LLM; - private messagesBefore: number; - constructor(init?: Partial) { + constructor(options?: Partial) { super(); - this.messages = init?.messages ?? []; - this.messagesBefore = this.messages.length; - this.summaryPrompt = init?.summaryPrompt ?? defaultSummaryPrompt; - this.llm = init?.llm ?? new OpenAI(); + this.messages = options?.messages ?? []; + this.summaryPrompt = options?.summaryPrompt ?? defaultSummaryPrompt; + this.llm = options?.llm ?? Settings.llm; if (!this.llm.metadata.maxTokens) { throw new Error( "LLM maxTokens is not set. Needed so the summarizer ensures the context window size of the LLM.", ); } - this.tokenizer = init?.tokenizer ?? tokenizers.tokenizer(); + this.tokenizer = options?.tokenizer ?? tokenizers.tokenizer(); this.tokensToSummarize = this.llm.metadata.contextWindow - this.llm.metadata.maxTokens; if (this.tokensToSummarize < this.llm.metadata.contextWindow * 0.25) { @@ -128,12 +64,8 @@ export class SummaryChatHistory extends ChatHistory { return { content: response.message.content, role: "memory" }; } - addMessage(message: ChatMessage) { - this.messages.push(message); - } - // Find last summary message - private getLastSummaryIndex(): number | null { + private get lastSummaryIndex(): number | null { const reversedMessages = this.messages.slice().reverse(); const index = reversedMessages.findIndex( (message) => message.role === "memory", @@ -145,7 +77,7 @@ export class SummaryChatHistory extends ChatHistory { } public getLastSummary(): ChatMessage | null { - const lastSummaryIndex = this.getLastSummaryIndex(); + const lastSummaryIndex = this.lastSummaryIndex; return lastSummaryIndex ? this.messages[lastSummaryIndex]! : null; } @@ -165,7 +97,7 @@ export class SummaryChatHistory extends ChatHistory { * If there's a memory, uses all messages after the last summary message. 
*/ private calcConversationMessages(transformSummary?: boolean): ChatMessage[] { - const lastSummaryIndex = this.getLastSummaryIndex(); + const lastSummaryIndex = this.lastSummaryIndex; if (!lastSummaryIndex) { // there's no memory, so just use all non-system messages return this.nonSystemMessages; @@ -182,18 +114,18 @@ export class SummaryChatHistory extends ChatHistory { } } - private calcCurrentRequestMessages(transientMessages?: ChatMessage[]) { + private calcCurrentRequestMessages() { // TODO: check order: currently, we're sending: // system messages first, then transient messages and then the messages that describe the conversation so far - return [ - ...this.systemMessages, - ...(transientMessages ? transientMessages : []), - ...this.calcConversationMessages(true), - ]; + return [...this.systemMessages, ...this.calcConversationMessages(true)]; } - async requestMessages(transientMessages?: ChatMessage[]) { - const requestMessages = this.calcCurrentRequestMessages(transientMessages); + reset() { + this.messages = []; + } + + async getMessages(): Promise { + const requestMessages = this.calcCurrentRequestMessages(); // get tokens of current request messages and the transient messages const tokens = requestMessages.reduce( @@ -217,27 +149,16 @@ export class SummaryChatHistory extends ChatHistory { // TODO: we still might have too many tokens // e.g. too large system messages or transient messages // how should we deal with that? - return this.calcCurrentRequestMessages(transientMessages); + return this.calcCurrentRequestMessages(); } return requestMessages; } - reset() { - this.messages = []; - } - - newMessages() { - const newMessages = this.messages.slice(this.messagesBefore); - this.messagesBefore = this.messages.length; - return newMessages; + async getAllMessages(): Promise { + return this.getMessages(); } -} -export function getHistory( - chatHistory?: ChatMessage[] | ChatHistory, -): ChatHistory { - if (chatHistory instanceof ChatHistory) { - return chatHistory; + put(message: ChatMessage) { + this.messages.push(message); } - return new SimpleChatHistory({ messages: chatHistory }); } diff --git a/packages/core/src/storage/chat-store/base-chat-store.ts b/packages/core/src/storage/chat-store/base-chat-store.ts new file mode 100644 index 0000000000..be19928f26 --- /dev/null +++ b/packages/core/src/storage/chat-store/base-chat-store.ts @@ -0,0 +1,19 @@ +import type { ChatMessage } from "../../llms"; + +export abstract class BaseChatStore< + AdditionalMessageOptions extends object = object, +> { + abstract setMessages( + key: string, + messages: ChatMessage[], + ): void; + abstract getMessages(key: string): ChatMessage[]; + abstract addMessage( + key: string, + message: ChatMessage, + idx?: number, + ): void; + abstract deleteMessages(key: string): void; + abstract deleteMessage(key: string, idx: number): void; + abstract getKeys(): IterableIterator; +} diff --git a/packages/core/src/storage/chat-store/index.ts b/packages/core/src/storage/chat-store/index.ts new file mode 100644 index 0000000000..922555ba22 --- /dev/null +++ b/packages/core/src/storage/chat-store/index.ts @@ -0,0 +1,2 @@ +export { BaseChatStore } from "./base-chat-store"; +export { SimpleChatStore } from "./simple-chat-store"; diff --git a/packages/core/src/storage/chat-store/simple-chat-store.ts b/packages/core/src/storage/chat-store/simple-chat-store.ts new file mode 100644 index 0000000000..365be20c4d --- /dev/null +++ b/packages/core/src/storage/chat-store/simple-chat-store.ts @@ -0,0 +1,43 @@ +import type { 
ChatMessage } from "../../llms"; +import { BaseChatStore } from "./base-chat-store"; + +export class SimpleChatStore< + AdditionalMessageOptions extends object = object, +> extends BaseChatStore { + #store = new Map[]>(); + setMessages(key: string, messages: ChatMessage[]) { + this.#store.set(key, messages); + } + + getMessages(key: string) { + return this.#store.get(key) ?? []; + } + + addMessage( + key: string, + message: ChatMessage, + idx?: number, + ) { + const messages = this.#store.get(key) ?? []; + if (idx === undefined) { + messages.push(message); + } else { + messages.splice(idx, 0, message); + } + this.#store.set(key, messages); + } + + deleteMessages(key: string) { + this.#store.delete(key); + } + + deleteMessage(key: string, idx: number) { + const messages = this.#store.get(key) ?? []; + messages.splice(idx, 1); + this.#store.set(key, messages); + } + + getKeys() { + return this.#store.keys(); + } +} diff --git a/packages/llamaindex/src/agent/base.ts b/packages/llamaindex/src/agent/base.ts index b320fa9287..fb2d20f6bc 100644 --- a/packages/llamaindex/src/agent/base.ts +++ b/packages/llamaindex/src/agent/base.ts @@ -5,10 +5,10 @@ import type { MessageContent, ToolOutput, } from "@llamaindex/core/llms"; +import { BaseMemory } from "@llamaindex/core/memory"; import { EngineResponse } from "@llamaindex/core/schema"; import { wrapEventCaller } from "@llamaindex/core/utils"; import { randomUUID } from "@llamaindex/env"; -import { ChatHistory } from "../ChatHistory.js"; import { Settings } from "../Settings.js"; import { type ChatEngine, @@ -353,11 +353,12 @@ export abstract class AgentRunner< async chat( params: ChatEngineParamsNonStreaming | ChatEngineParamsStreaming, ): Promise> { - let chatHistory: ChatMessage[] | undefined = []; + let chatHistory: ChatMessage[] = []; - if (params.chatHistory instanceof ChatHistory) { - chatHistory = params.chatHistory - .messages as ChatMessage[]; + if (params.chatHistory instanceof BaseMemory) { + chatHistory = (await params.chatHistory.getMessages( + params.message, + )) as ChatMessage[]; } else { chatHistory = params.chatHistory as ChatMessage[]; diff --git a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts index f097a1985a..a20cf92f49 100644 --- a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts +++ b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts @@ -1,4 +1,5 @@ import type { ChatMessage, LLM } from "@llamaindex/core/llms"; +import { BaseMemory, ChatMemoryBuffer } from "@llamaindex/core/memory"; import { type CondenseQuestionPrompt, defaultCondenseQuestionPrompt, @@ -12,8 +13,6 @@ import { streamReducer, wrapEventCaller, } from "@llamaindex/core/utils"; -import type { ChatHistory } from "../../ChatHistory.js"; -import { getHistory } from "../../ChatHistory.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { llmFromSettingsOrContext } from "../../Settings.js"; import type { QueryEngine } from "../../types.js"; @@ -39,7 +38,7 @@ export class CondenseQuestionChatEngine implements ChatEngine { queryEngine: QueryEngine; - chatHistory: ChatHistory; + chatHistory: BaseMemory; llm: LLM; condenseMessagePrompt: CondenseQuestionPrompt; @@ -52,7 +51,9 @@ export class CondenseQuestionChatEngine super(); this.queryEngine = init.queryEngine; - this.chatHistory = getHistory(init?.chatHistory); + this.chatHistory = new ChatMemoryBuffer({ + chatHistory: init?.chatHistory, + }); this.llm = 
llmFromSettingsOrContext(init?.serviceContext); this.condenseMessagePrompt = init?.condenseMessagePrompt ?? defaultCondenseQuestionPrompt; @@ -76,9 +77,9 @@ export class CondenseQuestionChatEngine } } - private async condenseQuestion(chatHistory: ChatHistory, question: string) { + private async condenseQuestion(chatHistory: BaseMemory, question: string) { const chatHistoryStr = messagesToHistory( - await chatHistory.requestMessages(), + await chatHistory.getMessages(question), ); return this.llm.complete({ @@ -99,13 +100,18 @@ export class CondenseQuestionChatEngine ): Promise> { const { message, stream } = params; const chatHistory = params.chatHistory - ? getHistory(params.chatHistory) + ? new ChatMemoryBuffer({ + chatHistory: + params.chatHistory instanceof BaseMemory + ? await params.chatHistory.getMessages(message) + : params.chatHistory, + }) : this.chatHistory; const condensedQuestion = ( await this.condenseQuestion(chatHistory, extractText(message)) ).text; - chatHistory.addMessage({ content: message, role: "user" }); + chatHistory.put({ content: message, role: "user" }); if (stream) { const stream = await this.queryEngine.query({ @@ -118,14 +124,14 @@ export class CondenseQuestionChatEngine reducer: (accumulator, part) => (accumulator += extractText(part.message.content)), finished: (accumulator) => { - chatHistory.addMessage({ content: accumulator, role: "assistant" }); + chatHistory.put({ content: accumulator, role: "assistant" }); }, }); } const response = await this.queryEngine.query({ query: condensedQuestion, }); - chatHistory.addMessage({ + chatHistory.put({ content: response.message.content, role: "assistant", }); diff --git a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts index 41c603f1eb..3ea12161ab 100644 --- a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts +++ b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts @@ -4,6 +4,7 @@ import type { MessageContent, MessageType, } from "@llamaindex/core/llms"; +import { BaseMemory, ChatMemoryBuffer } from "@llamaindex/core/memory"; import { type ContextSystemPrompt, type ModuleRecord, @@ -17,8 +18,6 @@ import { streamReducer, wrapEventCaller, } from "@llamaindex/core/utils"; -import type { ChatHistory } from "../../ChatHistory.js"; -import { getHistory } from "../../ChatHistory.js"; import type { BaseRetriever } from "../../Retriever.js"; import { Settings } from "../../Settings.js"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; @@ -36,7 +35,7 @@ import type { */ export class ContextChatEngine extends PromptMixin implements ChatEngine { chatModel: LLM; - chatHistory: ChatHistory; + chatHistory: BaseMemory; contextGenerator: ContextGenerator & PromptMixin; systemPrompt?: string | undefined; @@ -51,7 +50,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { }) { super(); this.chatModel = init.chatModel ?? Settings.llm; - this.chatHistory = getHistory(init?.chatHistory); + this.chatHistory = new ChatMemoryBuffer({ chatHistory: init?.chatHistory }); this.contextGenerator = new DefaultContextGenerator({ retriever: init.retriever, contextSystemPrompt: init?.contextSystemPrompt, @@ -90,7 +89,12 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { ): Promise> { const { message, stream } = params; const chatHistory = params.chatHistory - ? getHistory(params.chatHistory) + ? new ChatMemoryBuffer({ + chatHistory: + params.chatHistory instanceof BaseMemory + ? 
await params.chatHistory.getMessages(message) + : params.chatHistory, + }) : this.chatHistory; const requestMessages = await this.prepareRequestMessages( message, @@ -107,7 +111,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { initialValue: "", reducer: (accumulator, part) => (accumulator += part.delta), finished: (accumulator) => { - chatHistory.addMessage({ content: accumulator, role: "assistant" }); + chatHistory.put({ content: accumulator, role: "assistant" }); }, }), (r) => EngineResponse.fromChatResponseChunk(r, requestMessages.nodes), @@ -116,7 +120,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { const response = await this.chatModel.chat({ messages: requestMessages.messages, }); - chatHistory.addMessage(response.message); + chatHistory.put(response.message); return EngineResponse.fromChatResponse(response, requestMessages.nodes); } @@ -126,16 +130,16 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { private async prepareRequestMessages( message: MessageContent, - chatHistory: ChatHistory, + chatHistory: BaseMemory, ) { - chatHistory.addMessage({ + chatHistory.put({ content: message, role: "user", }); const textOnly = extractText(message); const context = await this.contextGenerator.generate(textOnly); const systemMessage = this.prependSystemPrompt(context.message); - const messages = await chatHistory.requestMessages([systemMessage]); + const messages = await chatHistory.getMessages(systemMessage.content); return { nodes: context.nodes, messages }; } diff --git a/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts b/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts index bef71cc66d..a123881b8f 100644 --- a/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts +++ b/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts @@ -1,12 +1,11 @@ import type { LLM } from "@llamaindex/core/llms"; +import { BaseMemory, ChatMemoryBuffer } from "@llamaindex/core/memory"; import { EngineResponse } from "@llamaindex/core/schema"; import { streamConverter, streamReducer, wrapEventCaller, } from "@llamaindex/core/utils"; -import type { ChatHistory } from "../../ChatHistory.js"; -import { getHistory } from "../../ChatHistory.js"; import { Settings } from "../../Settings.js"; import type { ChatEngine, @@ -19,11 +18,11 @@ import type { */ export class SimpleChatEngine implements ChatEngine { - chatHistory: ChatHistory; + chatHistory: BaseMemory; llm: LLM; constructor(init?: Partial) { - this.chatHistory = getHistory(init?.chatHistory); + this.chatHistory = init?.chatHistory ?? new ChatMemoryBuffer(); this.llm = init?.llm ?? Settings.llm; } @@ -38,13 +37,18 @@ export class SimpleChatEngine implements ChatEngine { const { message, stream } = params; const chatHistory = params.chatHistory - ? getHistory(params.chatHistory) + ? new ChatMemoryBuffer({ + chatHistory: + params.chatHistory instanceof BaseMemory + ? 
await params.chatHistory.getMessages(message) + : params.chatHistory, + }) : this.chatHistory; - chatHistory.addMessage({ content: message, role: "user" }); + chatHistory.put({ content: message, role: "user" }); if (stream) { const stream = await this.llm.chat({ - messages: await chatHistory.requestMessages(), + messages: await chatHistory.getMessages(params.message), stream: true, }); return streamConverter( @@ -53,7 +57,7 @@ export class SimpleChatEngine implements ChatEngine { initialValue: "", reducer: (accumulator, part) => accumulator + part.delta, finished: (accumulator) => { - chatHistory.addMessage({ content: accumulator, role: "assistant" }); + chatHistory.put({ content: accumulator, role: "assistant" }); }, }), EngineResponse.fromChatResponseChunk, @@ -61,9 +65,10 @@ export class SimpleChatEngine implements ChatEngine { } const response = await this.llm.chat({ - messages: await chatHistory.requestMessages(), + stream: false, + messages: await chatHistory.getMessages(params.message), }); - chatHistory.addMessage(response.message); + chatHistory.put(response.message); return EngineResponse.fromChatResponse(response); } diff --git a/packages/llamaindex/src/engines/chat/types.ts b/packages/llamaindex/src/engines/chat/types.ts index 4f1f9802be..9b3b18c0bf 100644 --- a/packages/llamaindex/src/engines/chat/types.ts +++ b/packages/llamaindex/src/engines/chat/types.ts @@ -1,6 +1,6 @@ import type { ChatMessage, MessageContent } from "@llamaindex/core/llms"; +import type { BaseMemory } from "@llamaindex/core/memory"; import { EngineResponse, type NodeWithScore } from "@llamaindex/core/schema"; -import type { ChatHistory } from "../../ChatHistory.js"; /** * Represents the base parameters for ChatEngine. @@ -10,7 +10,7 @@ export interface ChatEngineParamsBase { /** * Optional chat history if you want to customize the chat history. */ - chatHistory?: ChatMessage[] | ChatHistory; + chatHistory?: ChatMessage[] | BaseMemory; /** * Optional flag to enable verbose mode. 
* @default false diff --git a/packages/llamaindex/src/index.edge.ts b/packages/llamaindex/src/index.edge.ts index b23086cd55..541d9f0248 100644 --- a/packages/llamaindex/src/index.edge.ts +++ b/packages/llamaindex/src/index.edge.ts @@ -46,9 +46,9 @@ declare module "@llamaindex/core/global" { } export * from "@llamaindex/core/llms"; +export * from "@llamaindex/core/memory"; export * from "@llamaindex/core/schema"; export * from "./agent/index.js"; -export * from "./ChatHistory.js"; export * from "./cloud/index.js"; export * from "./embeddings/index.js"; export * from "./engines/chat/index.js"; diff --git a/packages/llamaindex/src/memory/ChatMemoryBuffer.ts b/packages/llamaindex/src/memory/ChatMemoryBuffer.ts deleted file mode 100644 index ed4b387608..0000000000 --- a/packages/llamaindex/src/memory/ChatMemoryBuffer.ts +++ /dev/null @@ -1,107 +0,0 @@ -import type { ChatMessage, LLM } from "@llamaindex/core/llms"; -import type { ChatHistory } from "../ChatHistory.js"; -import { SimpleChatStore } from "../storage/chatStore/SimpleChatStore.js"; -import type { BaseChatStore } from "../storage/chatStore/types.js"; -import type { BaseMemory } from "./types.js"; - -const DEFAULT_TOKEN_LIMIT_RATIO = 0.75; -const DEFAULT_TOKEN_LIMIT = 3000; - -type ChatMemoryBufferParams = - { - tokenLimit?: number; - chatStore?: BaseChatStore; - chatStoreKey?: string; - chatHistory?: ChatHistory; - llm?: LLM; - }; - -export class ChatMemoryBuffer - implements BaseMemory -{ - tokenLimit: number; - - chatStore: BaseChatStore; - chatStoreKey: string; - - constructor( - init?: Partial>, - ) { - this.chatStore = - init?.chatStore ?? new SimpleChatStore(); - this.chatStoreKey = init?.chatStoreKey ?? "chat_history"; - if (init?.llm) { - const contextWindow = init.llm.metadata.contextWindow; - this.tokenLimit = - init?.tokenLimit ?? - Math.ceil(contextWindow * DEFAULT_TOKEN_LIMIT_RATIO); - } else { - this.tokenLimit = init?.tokenLimit ?? 
DEFAULT_TOKEN_LIMIT; - } - - if (init?.chatHistory) { - this.chatStore.setMessages(this.chatStoreKey, init.chatHistory.messages); - } - } - - get(initialTokenCount: number = 0) { - const chatHistory = this.getAll(); - - if (initialTokenCount > this.tokenLimit) { - throw new Error("Initial token count exceeds token limit"); - } - - let messageCount = chatHistory.length; - let tokenCount = - this._tokenCountForMessageCount(messageCount) + initialTokenCount; - - while (tokenCount > this.tokenLimit && messageCount > 1) { - messageCount -= 1; - if (chatHistory.at(-messageCount)?.role === "assistant") { - // we cannot have an assistant message at the start of the chat history - // if after removal of the first, we have an assistant message, - // we need to remove the assistant message too - messageCount -= 1; - } - - tokenCount = - this._tokenCountForMessageCount(messageCount) + initialTokenCount; - } - - // catch one message longer than token limit - if (tokenCount > this.tokenLimit || messageCount <= 0) { - return []; - } - - return chatHistory.slice(-messageCount); - } - - getAll() { - return this.chatStore.getMessages(this.chatStoreKey); - } - - put(message: ChatMessage) { - this.chatStore.addMessage(this.chatStoreKey, message); - } - - set(messages: ChatMessage[]) { - this.chatStore.setMessages(this.chatStoreKey, messages); - } - - reset() { - this.chatStore.deleteMessages(this.chatStoreKey); - } - - private _tokenCountForMessageCount(messageCount: number): number { - if (messageCount <= 0) { - return 0; - } - - const chatHistory = this.getAll(); - const msgStr = chatHistory - .slice(-messageCount) - .map((m) => m.content) - .join(" "); - return msgStr.split(" ").length; - } -} diff --git a/packages/llamaindex/src/memory/types.ts b/packages/llamaindex/src/memory/types.ts deleted file mode 100644 index a95e2efea9..0000000000 --- a/packages/llamaindex/src/memory/types.ts +++ /dev/null @@ -1,10 +0,0 @@ -import type { ChatMessage } from "@llamaindex/core/llms"; - -export interface BaseMemory { - tokenLimit: number; - get(...args: unknown[]): ChatMessage[]; - getAll(): ChatMessage[]; - put(message: ChatMessage): void; - set(messages: ChatMessage[]): void; - reset(): void; -} diff --git a/packages/llamaindex/src/storage/chatStore/SimpleChatStore.ts b/packages/llamaindex/src/storage/chatStore/SimpleChatStore.ts deleted file mode 100644 index 0f13ed6fbf..0000000000 --- a/packages/llamaindex/src/storage/chatStore/SimpleChatStore.ts +++ /dev/null @@ -1,65 +0,0 @@ -import type { ChatMessage } from "@llamaindex/core/llms"; -import type { BaseChatStore } from "./types.js"; - -/** - * fixme: User could carry object references in the messages. - * This could lead to memory leaks if the messages are not properly cleaned up. 
- */ -export class SimpleChatStore< - AdditionalMessageOptions extends object = Record, -> implements BaseChatStore -{ - store: { [key: string]: ChatMessage[] } = {}; - - public setMessages( - key: string, - messages: ChatMessage[], - ) { - this.store[key] = messages; - } - - public getMessages(key: string): ChatMessage[] { - return this.store[key] || []; - } - - public addMessage( - key: string, - message: ChatMessage, - ) { - this.store[key] = this.store[key] || []; - this.store[key].push(message); - } - - public deleteMessages(key: string) { - if (!(key in this.store)) { - return null; - } - const messages = this.store[key]!; - delete this.store[key]; - return messages; - } - - public deleteMessage(key: string, idx: number) { - if (!(key in this.store)) { - return null; - } - if (idx >= this.store[key]!.length) { - return null; - } - return this.store[key]!.splice(idx, 1)[0]!; - } - - public deleteLastMessage(key: string) { - if (!(key in this.store)) { - return null; - } - - const lastMessage = this.store[key]!.pop(); - - return lastMessage || null; - } - - public getKeys(): string[] { - return Object.keys(this.store); - } -} diff --git a/packages/llamaindex/src/storage/chatStore/types.ts b/packages/llamaindex/src/storage/chatStore/types.ts deleted file mode 100644 index 2e467276c6..0000000000 --- a/packages/llamaindex/src/storage/chatStore/types.ts +++ /dev/null @@ -1,19 +0,0 @@ -import type { ChatMessage } from "@llamaindex/core/llms"; - -export interface BaseChatStore< - AdditionalMessageOptions extends object = object, -> { - setMessages( - key: string, - messages: ChatMessage[], - ): void; - getMessages(key: string): ChatMessage[]; - addMessage(key: string, message: ChatMessage): void; - deleteMessages(key: string): ChatMessage[] | null; - deleteMessage( - key: string, - idx: number, - ): ChatMessage | null; - deleteLastMessage(key: string): ChatMessage | null; - getKeys(): string[]; -} diff --git a/packages/llamaindex/src/storage/index.ts b/packages/llamaindex/src/storage/index.ts index ecb182ebc7..bb67a8589a 100644 --- a/packages/llamaindex/src/storage/index.ts +++ b/packages/llamaindex/src/storage/index.ts @@ -1,5 +1,4 @@ -export { SimpleChatStore } from "./chatStore/SimpleChatStore.js"; -export * from "./chatStore/types.js"; +export * from "@llamaindex/core/storage/chat-store"; export { PostgresDocumentStore } from "./docStore/PostgresDocumentStore.js"; export { SimpleDocumentStore } from "./docStore/SimpleDocumentStore.js"; export * from "./docStore/types.js";
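
For readers migrating code against this patch, the renames map roughly as follows. This mapping is inferred from the removed and added files above (it is not part of the patch itself), and option shapes differ slightly — e.g. the new ChatMemoryBuffer takes a plain ChatMessage[] via its chatHistory option:

    new SimpleChatHistory({ messages })   ->  new ChatMemoryBuffer({ chatHistory: messages })
    new SummaryChatHistory({ llm })       ->  new ChatSummaryMemoryBuffer({ llm })
    getHistory(messagesOrHistory)         ->  new ChatMemoryBuffer({ chatHistory: ... })
    history.addMessage(message)           ->  memory.put(message)
    await history.requestMessages()       ->  await memory.getMessages()
    history.messages / buffer.getAll()    ->  memory.getAllMessages()
    buffer.get(initialTokenCount)         ->  memory.getMessages(input, initialTokenCount)
    store.getKeys(): string[]             ->  store.getKeys(): IterableIterator<string>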
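
A minimal usage sketch of the new ChatMemoryBuffer, assembled only from the constructor options and methods added in packages/core/src/memory above. It assumes the llamaindex package re-exports @llamaindex/core/memory and the chat store module (as index.edge.ts and storage/index.ts do in this patch); the model name, store key, and message text are illustrative placeholders.

    import { ChatMemoryBuffer, OpenAI, SimpleChatStore } from "llamaindex";

    const llm = new OpenAI({ model: "gpt-3.5-turbo" }); // illustrative model
    const store = new SimpleChatStore();

    const memory = new ChatMemoryBuffer({
      llm, // tokenLimit defaults to 75% of this model's context window
      chatStore: store, // optional; a fresh SimpleChatStore is created by default
      chatStoreKey: "session-1", // illustrative; defaults to "chat_history"
      chatHistory: [{ role: "system", content: "You want to talk in rhymes." }],
    });

    memory.put({ role: "user", content: "Hello!" });

    // getMessages() drops the oldest turns (never starting on an assistant
    // message) once the token budget is exceeded; getAllMessages() returns
    // the full history stored under the key.
    const forRequest = memory.getMessages();
    const everything = memory.getAllMessages();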
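
A sketch of the renamed summary memory wired into SimpleChatEngine, mirroring examples/chatHistory.ts above. It is not part of the patch; the prompt text is a placeholder. Note that maxTokens must be set on the LLM, otherwise the ChatSummaryMemoryBuffer constructor throws, since it summarizes once the history approaches contextWindow - maxTokens tokens.

    import { ChatSummaryMemoryBuffer, OpenAI, SimpleChatEngine } from "llamaindex";

    async function main() {
      // With maxTokens at 75% of a 4096-token context window, summarization
      // kicks in once the history approaches the remaining 1024 tokens.
      const llm = new OpenAI({ model: "gpt-3.5-turbo", maxTokens: 4096 * 0.75 });
      const chatHistory = new ChatSummaryMemoryBuffer({ llm });
      const chatEngine = new SimpleChatEngine({ llm, chatHistory });

      const response = await chatEngine.chat({
        message: "Tell me a joke, then another one.",
      });
      console.log(response.message.content);

      // Older turns get condensed into a single message with role "memory";
      // getLastSummary() returns it, or null if nothing has been summarized yet.
      console.log(chatHistory.getLastSummary());
    }

    main().catch(console.error);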
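
Finally, a sketch of the core SimpleChatStore on its own, again inferred from the new simple-chat-store.ts rather than taken from the patch. Compared with the deleted llamaindex version, getKeys() now returns an IterableIterator<string> instead of string[], deleteLastMessage() is gone, and addMessage() gained an optional insertion index. The key "user-42" is a placeholder.

    import { SimpleChatStore } from "llamaindex";

    const store = new SimpleChatStore();
    store.addMessage("user-42", { role: "user", content: "hi" }); // append
    store.addMessage("user-42", { role: "assistant", content: "hello" });
    store.addMessage("user-42", { role: "system", content: "be brief" }, 0); // insert at index 0

    for (const key of store.getKeys()) {
      console.log(key, store.getMessages(key).length); // "user-42 3"
    }

    store.deleteMessage("user-42", 0); // replaces the removed deleteLastMessage()
    store.deleteMessages("user-42");   // drop the whole conversation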