-
Notifications
You must be signed in to change notification settings - Fork 339
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: chat memory & chat history into core module (#1201)
- Loading branch information
Showing 21 changed files with 311 additions and 354 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import { Settings } from "../global"; | ||
import type { ChatMessage, MessageContent } from "../llms"; | ||
import { type BaseChatStore, SimpleChatStore } from "../storage/chat-store"; | ||
import { extractText } from "../utils"; | ||
|
||
export const DEFAULT_TOKEN_LIMIT_RATIO = 0.75; | ||
export const DEFAULT_CHAT_STORE_KEY = "chat_history"; | ||
|
||
/**
 * Base contract for chat memory: an object that keeps the state of the
 * back-and-forth messages of a conversation.
 */
export abstract class BaseMemory<
  AdditionalMessageOptions extends object = object,
> {
  /**
   * Retrieves messages from memory, either synchronously or as a promise.
   *
   * @param input - Optional message content supplied by the caller; what is
   *   done with it is up to the concrete subclass.
   */
  abstract getMessages(
    input?: MessageContent | undefined,
  ):
    | ChatMessage<AdditionalMessageOptions>[]
    | Promise<ChatMessage<AdditionalMessageOptions>[]>;

  /**
   * Retrieves the message list without taking any input, either
   * synchronously or as a promise. Presumably the full, untrimmed history;
   * the exact behavior is defined by the subclass.
   */
  abstract getAllMessages():
    | ChatMessage<AdditionalMessageOptions>[]
    | Promise<ChatMessage<AdditionalMessageOptions>[]>;

  /**
   * Stores a message in memory.
   *
   * @param messages - A single message (the parameter name is plural, but
   *   the type is one `ChatMessage`).
   */
  abstract put(messages: ChatMessage<AdditionalMessageOptions>): void;

  /** Resets the memory; the exact effect is defined by the subclass. */
  abstract reset(): void;

  /**
   * Counts the tokens used by the text of the given messages.
   *
   * The text of every message is extracted, the pieces are joined with a
   * single space, and the result is encoded with `Settings.tokenizer`.
   *
   * @param messages - Messages to measure.
   * @returns Number of encoded tokens, or 0 for an empty list (the tokenizer
   *   is not touched in that case).
   */
  protected _tokenCountForMessages(messages: ChatMessage[]): number {
    if (messages.length > 0) {
      const encoder = Settings.tokenizer;
      const joinedText = messages
        .map(({ content }) => extractText(content))
        .join(" ");
      return encoder.encode(joinedText).length;
    }

    return 0;
  }
}
|
||
/**
 * Chat memory whose messages are kept in a chat store under a single key.
 */
export abstract class BaseChatStoreMemory<
  AdditionalMessageOptions extends object = object,
> extends BaseMemory<AdditionalMessageOptions> {
  /** Store that holds the messages. */
  public chatStore: BaseChatStore<AdditionalMessageOptions>;
  /** Key under which this memory's messages are kept in `chatStore`. */
  public chatStoreKey: string;

  /**
   * @param chatStore - Backing store; a new `SimpleChatStore` when omitted.
   * @param chatStoreKey - Store key; `DEFAULT_CHAT_STORE_KEY` when omitted.
   */
  protected constructor(
    chatStore: BaseChatStore<AdditionalMessageOptions> = new SimpleChatStore<AdditionalMessageOptions>(),
    chatStoreKey: string = DEFAULT_CHAT_STORE_KEY,
  ) {
    super();
    this.chatStore = chatStore;
    this.chatStoreKey = chatStoreKey;
  }

  /** Returns whatever the store holds under this memory's key. */
  getAllMessages(): ChatMessage<AdditionalMessageOptions>[] {
    const { chatStore, chatStoreKey } = this;
    return chatStore.getMessages(chatStoreKey);
  }

  /**
   * Adds one message to the store under this memory's key.
   *
   * @param messages - The single message to add.
   */
  put(messages: ChatMessage<AdditionalMessageOptions>): void {
    const { chatStore, chatStoreKey } = this;
    chatStore.addMessage(chatStoreKey, messages);
  }

  /**
   * Hands the given list to the store's `setMessages` for this memory's key.
   *
   * @param messages - Messages to set.
   */
  set(messages: ChatMessage<AdditionalMessageOptions>[]): void {
    const { chatStore, chatStoreKey } = this;
    chatStore.setMessages(chatStoreKey, messages);
  }

  /** Deletes the messages stored under this memory's key. */
  reset(): void {
    const { chatStore, chatStoreKey } = this;
    chatStore.deleteMessages(chatStoreKey);
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import { Settings } from "../global"; | ||
import type { ChatMessage, LLM, MessageContent } from "../llms"; | ||
import { type BaseChatStore } from "../storage/chat-store"; | ||
import { BaseChatStoreMemory, DEFAULT_TOKEN_LIMIT_RATIO } from "./base"; | ||
|
||
/**
 * Construction options for `ChatMemoryBuffer`. Every field is optional.
 */
interface ChatMemoryBufferOptions<
  AdditionalMessageOptions extends object = object,
> {
  /** Token budget; when omitted it is derived from the LLM's context window. */
  tokenLimit?: number | undefined;
  /** Store that backs the memory. */
  chatStore?: BaseChatStore<AdditionalMessageOptions> | undefined;
  /** Key under which the messages are kept in the store. */
  chatStoreKey?: string | undefined;
  /** Messages written to the store when the memory is constructed. */
  chatHistory?: ChatMessage<AdditionalMessageOptions>[] | undefined;
  /** LLM whose `metadata.contextWindow` sizes the default token budget. */
  llm?: LLM<object, AdditionalMessageOptions> | undefined;
}
|
||
/**
 * Chat memory that hands back only as much of the most recent history as
 * fits within a token budget.
 */
export class ChatMemoryBuffer<
  AdditionalMessageOptions extends object = object,
> extends BaseChatStoreMemory<AdditionalMessageOptions> {
  /** Maximum number of tokens the messages returned by `getMessages` may use. */
  tokenLimit: number;

  /**
   * @param options - Optional settings. `chatStore` and `chatStoreKey` are
   *   forwarded to the base class; `tokenLimit` defaults to
   *   `ceil(contextWindow * DEFAULT_TOKEN_LIMIT_RATIO)` of `options.llm`
   *   (or `Settings.llm`); `chatHistory`, when given, is written to the store.
   */
  constructor(
    options?: Partial<ChatMemoryBufferOptions<AdditionalMessageOptions>>,
  ) {
    super(options?.chatStore, options?.chatStoreKey);

    // The LLM is only consulted when no explicit limit was supplied, so a
    // caller that passes `tokenLimit` does not need an LLM to be available.
    this.tokenLimit =
      options?.tokenLimit ??
      Math.ceil(
        (options?.llm ?? Settings.llm).metadata.contextWindow *
          DEFAULT_TOKEN_LIMIT_RATIO,
      );

    if (options?.chatHistory) {
      this.chatStore.setMessages(this.chatStoreKey, options.chatHistory);
    }
  }

  /**
   * Returns the longest suffix of the stored history whose token count, plus
   * `initialTokenCount`, does not exceed `tokenLimit`.
   *
   * @param input - Accepted for compatibility with `BaseMemory`; not used here.
   * @param initialTokenCount - Tokens already spent by the caller; they are
   *   added to the history's token count before comparing with the limit.
   * @returns The trimmed history, or an empty array when nothing fits.
   * @throws Error when `initialTokenCount` alone exceeds `tokenLimit`.
   */
  getMessages(
    input?: MessageContent | undefined,
    initialTokenCount: number = 0,
  ): ChatMessage<AdditionalMessageOptions>[] {
    const messages = this.getAllMessages();

    if (initialTokenCount > this.tokenLimit) {
      throw new Error("Initial token count exceeds token limit");
    }

    let messageCount = messages.length;
    let tokenCount = this._tokenCountForMessages(messages) + initialTokenCount;

    // Drop the oldest messages one step at a time until the window fits.
    while (tokenCount > this.tokenLimit && messageCount > 1) {
      messageCount -= 1;
      // Skip one more message if the window would otherwise begin with an
      // assistant message (presumably so it starts on a non-assistant turn).
      if (messages.at(-messageCount)?.role === "assistant") {
        messageCount -= 1;
      }
      if (messageCount <= 0) {
        // `slice(-0)` would select the WHOLE array, so an empty window has to
        // be handled explicitly instead of being re-measured.
        break;
      }
      tokenCount =
        this._tokenCountForMessages(messages.slice(-messageCount)) +
        initialTokenCount;
    }

    // Nothing fits: either the window shrank to zero messages, or the single
    // remaining message is by itself larger than the budget.
    if (messageCount <= 0 || tokenCount > this.tokenLimit) {
      return [];
    }
    return messages.slice(-messageCount);
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
export { BaseMemory } from "./base"; | ||
export { ChatMemoryBuffer } from "./chat-memory-buffer"; | ||
export { ChatSummaryMemoryBuffer } from "./summary-memory"; |
Oops, something went wrong.