Commit

refactor: chat memory & chat history into core module (#1201)
himself65 committed Sep 16, 2024
1 parent b42adeb commit 423d66b
Showing 21 changed files with 311 additions and 354 deletions.
6 changes: 3 additions & 3 deletions examples/anthropic/chat_interactive.ts
@@ -1,4 +1,4 @@
-import { Anthropic, SimpleChatEngine, SimpleChatHistory } from "llamaindex";
+import { Anthropic, ChatMemoryBuffer, SimpleChatEngine } from "llamaindex";
 import { stdin as input, stdout as output } from "node:process";
 import readline from "node:readline/promises";

@@ -8,8 +8,8 @@ import readline from "node:readline/promises";
model: "claude-3-opus",
});
// chatHistory will store all the messages in the conversation
const chatHistory = new SimpleChatHistory({
messages: [
const chatHistory = new ChatMemoryBuffer({
chatHistory: [
{
content: "You want to talk in rhymes.",
role: "system",
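For reference, the migrated example now builds its history with ChatMemoryBuffer instead of SimpleChatHistory. A rough sketch of the whole file after this change (the readline loop is reconstructed from the sibling examples/chatHistory.ts; the chat() call shape and response.toString() are assumptions, not shown in this hunk):

import { Anthropic, ChatMemoryBuffer, SimpleChatEngine } from "llamaindex";
import { stdin as input, stdout as output } from "node:process";
import readline from "node:readline/promises";

const llm = new Anthropic({ model: "claude-3-opus" });
// chatHistory will store all the messages in the conversation
const chatHistory = new ChatMemoryBuffer({
  chatHistory: [{ content: "You want to talk in rhymes.", role: "system" }],
});
const chatEngine = new SimpleChatEngine({ llm });
const rl = readline.createInterface({ input, output });

while (true) {
  const query = await rl.question("Query: ");
  // Assumed call shape, mirroring examples/chatHistory.ts
  const response = await chatEngine.chat({ message: query, chatHistory });
  console.log(response.toString());
}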
4 changes: 2 additions & 2 deletions examples/chatHistory.ts
@@ -2,10 +2,10 @@ import { stdin as input, stdout as output } from "node:process";
 import readline from "node:readline/promises";

 import {
+  ChatSummaryMemoryBuffer,
   OpenAI,
   Settings,
   SimpleChatEngine,
-  SummaryChatHistory,
 } from "llamaindex";

 if (process.env.NODE_ENV === "development") {
@@ -18,7 +18,7 @@ async function main() {
   // Set maxTokens to 75% of the context window size of 4096
   // This will trigger the summarizer once the chat history reaches 25% of the context window size (1024 tokens)
   const llm = new OpenAI({ model: "gpt-3.5-turbo", maxTokens: 4096 * 0.75 });
-  const chatHistory = new SummaryChatHistory({ llm });
+  const chatHistory = new ChatSummaryMemoryBuffer({ llm });
   const chatEngine = new SimpleChatEngine({ llm });
   const rl = readline.createInterface({ input, output });

28 changes: 28 additions & 0 deletions packages/core/package.json
@@ -157,6 +157,34 @@
"types": "./dist/workflow/index.d.ts",
"default": "./dist/workflow/index.js"
}
},
"./memory": {
"require": {
"types": "./dist/memory/index.d.cts",
"default": "./dist/memory/index.cjs"
},
"import": {
"types": "./dist/memory/index.d.ts",
"default": "./dist/memory/index.js"
},
"default": {
"types": "./dist/memory/index.d.ts",
"default": "./dist/memory/index.js"
}
},
"./storage/chat-store": {
"require": {
"types": "./dist/storage/chat-store/index.d.cts",
"default": "./dist/storage/chat-store/index.cjs"
},
"import": {
"types": "./dist/storage/chat-store/index.d.ts",
"default": "./dist/storage/chat-store/index.js"
},
"default": {
"types": "./dist/storage/chat-store/index.d.ts",
"default": "./dist/storage/chat-store/index.js"
}
}
},
"files": [
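With these subpath exports in place, the new modules resolve for both CJS and ESM consumers. Assuming the core package is published as @llamaindex/core (its name field is outside this hunk), imports would look like:

// Hypothetical consumer of the new subpaths
import { ChatMemoryBuffer } from "@llamaindex/core/memory";
import { SimpleChatStore } from "@llamaindex/core/storage/chat-store";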
62 changes: 62 additions & 0 deletions packages/core/src/memory/base.ts
@@ -0,0 +1,62 @@
import { Settings } from "../global";
import type { ChatMessage, MessageContent } from "../llms";
import { type BaseChatStore, SimpleChatStore } from "../storage/chat-store";
import { extractText } from "../utils";

export const DEFAULT_TOKEN_LIMIT_RATIO = 0.75;
export const DEFAULT_CHAT_STORE_KEY = "chat_history";

/**
 * A ChatMemory is used to keep the state of back and forth chat messages
 */
export abstract class BaseMemory<
  AdditionalMessageOptions extends object = object,
> {
  abstract getMessages(
    input?: MessageContent | undefined,
  ):
    | ChatMessage<AdditionalMessageOptions>[]
    | Promise<ChatMessage<AdditionalMessageOptions>[]>;
  abstract getAllMessages():
    | ChatMessage<AdditionalMessageOptions>[]
    | Promise<ChatMessage<AdditionalMessageOptions>[]>;
  abstract put(messages: ChatMessage<AdditionalMessageOptions>): void;
  abstract reset(): void;

  protected _tokenCountForMessages(messages: ChatMessage[]): number {
    if (messages.length === 0) {
      return 0;
    }

    const tokenizer = Settings.tokenizer;
    const str = messages.map((m) => extractText(m.content)).join(" ");
    return tokenizer.encode(str).length;
  }
}

export abstract class BaseChatStoreMemory<
  AdditionalMessageOptions extends object = object,
> extends BaseMemory<AdditionalMessageOptions> {
  protected constructor(
    public chatStore: BaseChatStore<AdditionalMessageOptions> = new SimpleChatStore<AdditionalMessageOptions>(),
    public chatStoreKey: string = DEFAULT_CHAT_STORE_KEY,
  ) {
    super();
  }

  getAllMessages(): ChatMessage<AdditionalMessageOptions>[] {
    return this.chatStore.getMessages(this.chatStoreKey);
  }

  put(messages: ChatMessage<AdditionalMessageOptions>) {
    this.chatStore.addMessage(this.chatStoreKey, messages);
  }

  set(messages: ChatMessage<AdditionalMessageOptions>[]) {
    this.chatStore.setMessages(this.chatStoreKey, messages);
  }

  reset() {
    this.chatStore.deleteMessages(this.chatStoreKey);
  }
}
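BaseChatStoreMemory leaves only getMessages() abstract; the chat-store plumbing (getAllMessages, put, set, reset) is inherited. A minimal sketch of a custom memory built on it, written as if it lived inside packages/core/src/memory (the class and its last-N policy are hypothetical; note the module's index below re-exports BaseMemory but not BaseChatStoreMemory):

import type { ChatMessage, MessageContent } from "../llms";
import { BaseChatStoreMemory } from "./base";

// Hypothetical memory that windows to the most recent N messages.
export class LastNMemory extends BaseChatStoreMemory {
  constructor(private readonly maxMessages = 10) {
    super(); // defaults: SimpleChatStore with the "chat_history" key
  }

  getMessages(_input?: MessageContent | undefined): ChatMessage[] {
    return this.getAllMessages().slice(-this.maxMessages);
  }
}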
65 changes: 65 additions & 0 deletions packages/core/src/memory/chat-memory-buffer.ts
@@ -0,0 +1,65 @@
import { Settings } from "../global";
import type { ChatMessage, LLM, MessageContent } from "../llms";
import { type BaseChatStore } from "../storage/chat-store";
import { BaseChatStoreMemory, DEFAULT_TOKEN_LIMIT_RATIO } from "./base";

type ChatMemoryBufferOptions<AdditionalMessageOptions extends object = object> =
  {
    tokenLimit?: number | undefined;
    chatStore?: BaseChatStore<AdditionalMessageOptions> | undefined;
    chatStoreKey?: string | undefined;
    chatHistory?: ChatMessage<AdditionalMessageOptions>[] | undefined;
    llm?: LLM<object, AdditionalMessageOptions> | undefined;
  };

export class ChatMemoryBuffer<
  AdditionalMessageOptions extends object = object,
> extends BaseChatStoreMemory<AdditionalMessageOptions> {
  tokenLimit: number;

  constructor(
    options?: Partial<ChatMemoryBufferOptions<AdditionalMessageOptions>>,
  ) {
    super(options?.chatStore, options?.chatStoreKey);

    const llm = options?.llm ?? Settings.llm;
    const contextWindow = llm.metadata.contextWindow;
    this.tokenLimit =
      options?.tokenLimit ??
      Math.ceil(contextWindow * DEFAULT_TOKEN_LIMIT_RATIO);

    if (options?.chatHistory) {
      this.chatStore.setMessages(this.chatStoreKey, options.chatHistory);
    }
  }

  getMessages(
    input?: MessageContent | undefined,
    initialTokenCount: number = 0,
  ) {
    const messages = this.getAllMessages();

    if (initialTokenCount > this.tokenLimit) {
      throw new Error("Initial token count exceeds token limit");
    }

    let messageCount = messages.length;
    let currentMessages = messages.slice(-messageCount);
    let tokenCount = this._tokenCountForMessages(messages) + initialTokenCount;

    while (tokenCount > this.tokenLimit && messageCount > 1) {
      messageCount -= 1;
      if (messages.at(-messageCount)!.role === "assistant") {
        messageCount -= 1;
      }
      currentMessages = messages.slice(-messageCount);
      tokenCount =
        this._tokenCountForMessages(currentMessages) + initialTokenCount;
    }

    if (tokenCount > this.tokenLimit && messageCount <= 0) {
      return [];
    }
    return messages.slice(-messageCount);
  }
}
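A quick usage sketch (values illustrative; Settings.llm is still read for the context-window default, so a default LLM is assumed to be configured): getMessages() walks back from the newest message until the window fits under tokenLimit, dropping an extra message whenever the window would otherwise start on an assistant reply.

const memory = new ChatMemoryBuffer({
  tokenLimit: 1024, // illustrative; default is 75% of the LLM's context window
  chatHistory: [
    { role: "system", content: "You want to talk in rhymes." },
    { role: "user", content: "Hello!" },
  ],
});
memory.put({ role: "assistant", content: "Hello to you, shiny and new!" });
// Returns the most recent messages that fit under tokenLimit.
const window = memory.getMessages();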
3 changes: 3 additions & 0 deletions packages/core/src/memory/index.ts
@@ -0,0 +1,3 @@
export { BaseMemory } from "./base";
export { ChatMemoryBuffer } from "./chat-memory-buffer";
export { ChatSummaryMemoryBuffer } from "./summary-memory";