From 4bec27023950c078fe0d559e1720d757d0024c93 Mon Sep 17 00:00:00 2001
From: Jerry Liu
Date: Mon, 4 Dec 2023 17:01:24 -0800
Subject: [PATCH] upgrade v3 (#49)

support multi-modality
allow loading an entire directory (and all the files in the directory)
add sources in the response
---
 "1_\360\237\217\240_Home.py"                  |  43 +--
 README.md                                     |   2 +
 core/agent_builder/__init__.py                |   0
 .../base.py}                                  | 182 ++-----
 core/agent_builder/loader.py                  | 115 ++++++++
 core/agent_builder/multimodal.py              | 256 ++++++++++++++++++
 core/agent_builder/registry.py                |  78 ++++++
 core/param_cache.py                           |  62 ++++-
 core/utils.py                                 | 166 +++++++++++-
 .../2_\342\232\231\357\270\217_RAG_Config.py" |  28 +-
 ...3_\360\237\244\226_Generated_RAG_Agent.py" |  51 +++-
 pyproject.toml                                |  11 +-
 st_utils.py                                   |  53 +++-
 13 files changed, 819 insertions(+), 228 deletions(-)
 create mode 100644 core/agent_builder/__init__.py
 rename core/{agent_builder.py => agent_builder/base.py} (56%)
 create mode 100644 core/agent_builder/loader.py
 create mode 100644 core/agent_builder/multimodal.py
 create mode 100644 core/agent_builder/registry.py

diff --git "a/1_\360\237\217\240_Home.py" "b/1_\360\237\217\240_Home.py"
index 9a84660..ecb2d2f 100644
--- "a/1_\360\237\217\240_Home.py"
+++ "b/1_\360\237\217\240_Home.py"
@@ -1,8 +1,13 @@
 import streamlit as st
 from streamlit_pills import pills
-from st_utils import add_sidebar, get_current_state
+from st_utils import (
+    add_builder_config,
+    add_sidebar,
+    get_current_state,
+)
+
+current_state = get_current_state()
 
 ####################
 #### STREAMLIT #####
@@ -28,7 +33,7 @@
 st.info("**NOTE**: The ability to add web search is enabled.")
 
-current_state = get_current_state()
+add_builder_config()
 
 add_sidebar()
@@ -60,40 +65,6 @@ def add_to_message_history(role: str, content: str) -> None:
     with st.chat_message(message["role"]):
         st.write(message["content"])
 
-
-# def handle_user_input() -> None:
-#     """Handle user input."""
-#     prompt = st.session_state.user_question_st
-#     print(f"USER PROMPT: {prompt}")
-#     add_to_message_history("user", prompt)
-#     with st.chat_message("user"):
-#         st.write(prompt)
-#     # If last message is not from assistant, generate a new response
-#     if st.session_state.messages[-1]["role"] != "assistant":
-#         with st.chat_message("assistant"):
-#             with st.spinner("Thinking..."):
-#                 response = current_state.builder_agent.chat(prompt)
-#                 st.write(str(response))
-#                 add_to_message_history("assistant", str(response))
-
-#     else:
-#         pass
-
-#     # check agent_ids again
-#     # if it doesn't match, add to directory and refresh
-#     agent_ids = current_state.agent_registry.get_agent_ids()
-#     # check diff between agent_ids and cur agent ids
-#     diff_ids = list(set(agent_ids) - set(st.session_state.cur_agent_ids))
-#     if len(diff_ids) > 0:
-#         # # clear streamlit cache, to allow you to generate a new agent
-#         # st.cache_resource.clear()
-#         st.rerun()
-
-# handle user input
-# st.chat_input(
-#     "Your question", key="user_question_st", on_submit=handle_user_input
-# )
 
 # Prompt for user input and save to chat history
 # TODO: this is really hacky, only because st.rerun is jank
 if prompt := st.chat_input(
     "Your question",
diff --git a/README.md b/README.md
index 6b22ee3..23a7ba2 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,8 @@
 streamlit run 1_🏠_Home.py
 ```
 
+**NOTE**: If you've upgraded the version of RAGs, and you're running into issues on launch, you may need to delete the `cache` folder in your home directory (we may have introduced breaking changes in the stored data structure between versions).
+
 ## Detailed Overview
 
 The app contains the following sections, corresponding to the steps listed above.
diff --git a/core/agent_builder/__init__.py b/core/agent_builder/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/core/agent_builder.py b/core/agent_builder/base.py
similarity index 56%
rename from core/agent_builder.py
rename to core/agent_builder/base.py
index c35e9a9..f6c2d4b 100644
--- a/core/agent_builder.py
+++ b/core/agent_builder/base.py
@@ -3,93 +3,19 @@
 from llama_index.llms import ChatMessage
 from llama_index.prompts import ChatPromptTemplate
 from typing import List, cast, Optional
-from llama_index.tools import FunctionTool
-from llama_index.agent.types import BaseAgent
 from core.builder_config import BUILDER_LLM
-from typing import Dict, Tuple, Any, Callable, Union
-import streamlit as st
-from pathlib import Path
-import json
+from typing import Dict, Any
 import uuid
 from core.constants import AGENT_CACHE_DIR
-import shutil
+from abc import ABC, abstractmethod
 
 from core.param_cache import ParamCache, RAGParams
 from core.utils import (
     load_data,
     get_tool_objects,
     construct_agent,
-    load_meta_agent,
 )
-
-
-class AgentCacheRegistry:
-    """Registry for agent caches, in disk.
-
-    Can register new agent caches, load agent caches, delete agent caches, etc.
-
-    """
-
-    def __init__(self, dir: Union[str, Path]) -> None:
-        """Init params."""
-        self._dir = dir
-
-    def _add_agent_id_to_directory(self, agent_id: str) -> None:
-        """Save agent id to directory."""
-        full_path = Path(self._dir) / "agent_ids.json"
-        if not full_path.exists():
-            with open(full_path, "w") as f:
-                json.dump({"agent_ids": [agent_id]}, f)
-        else:
-            with open(full_path, "r") as f:
-                agent_ids = json.load(f)["agent_ids"]
-            if agent_id in agent_ids:
-                raise ValueError(f"Agent id {agent_id} already exists.")
-            agent_ids_set = set(agent_ids)
-            agent_ids_set.add(agent_id)
-            with open(full_path, "w") as f:
-                json.dump({"agent_ids": list(agent_ids_set)}, f)
-
-    def add_new_agent_cache(self, agent_id: str, cache: ParamCache) -> None:
-        """Register agent."""
-        # save the cache to disk
-        agent_cache_path = f"{self._dir}/{agent_id}"
-        cache.save_to_disk(agent_cache_path)
-        # save to agent ids
-        self._add_agent_id_to_directory(agent_id)
-
-    def get_agent_ids(self) -> List[str]:
-        """Get agent ids."""
-        full_path = Path(self._dir) / "agent_ids.json"
-        if not full_path.exists():
-            return []
-        with open(full_path, "r") as f:
-            agent_ids = json.load(f)["agent_ids"]
-
-        return agent_ids
-
-    def get_agent_cache(self, agent_id: str) -> ParamCache:
-        """Get agent cache."""
-        full_path = Path(self._dir) / f"{agent_id}"
-        if not full_path.exists():
-            raise ValueError(f"Cache for agent {agent_id} does not exist.")
-        cache = ParamCache.load_from_disk(str(full_path))
-        return cache
-
-    def delete_agent_cache(self, agent_id: str) -> None:
-        """Delete agent cache."""
-        # modify / resave agent_ids
-        agent_ids = self.get_agent_ids()
-        new_agent_ids = [id for id in agent_ids if id != agent_id]
-        full_path = Path(self._dir) / "agent_ids.json"
-        with open(full_path, "w") as f:
-            json.dump({"agent_ids": new_agent_ids}, f)
-
-        # remove agent cache
-        full_path = Path(self._dir) / f"{agent_id}"
-        if full_path.exists():
-            # recursive delete
-            shutil.rmtree(full_path)
+from core.agent_builder.registry import AgentCacheRegistry
 
 
 # System prompt tool
@@ -121,7 +47,21 @@
 GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)
 
 
-class RAGAgentBuilder:
+class BaseRAGAgentBuilder(ABC):
+    """Base RAG Agent builder class."""
+
+    @property
+    @abstractmethod
+    def cache(self) -> ParamCache:
+        """Cache."""
+
+    @property
+    @abstractmethod
+    def agent_registry(self) -> AgentCacheRegistry:
+        """Agent registry."""
+
+
+class RAGAgentBuilder(BaseRAGAgentBuilder):
     """RAG Agent builder.
 
     Contains a set of functions to construct a RAG agent, including:
@@ -165,25 +105,31 @@ def create_system_prompt(self, task: str) -> str:
         return f"System prompt created: {response.message.content}"
 
     def load_data(
-        self, file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
+        self,
+        file_names: Optional[List[str]] = None,
+        directory: Optional[str] = None,
+        urls: Optional[List[str]] = None,
     ) -> str:
         """Load data for a given task.
 
-        Only ONE of file_names or urls should be specified.
+        Only ONE of file_names or directory or urls should be specified.
 
         Args:
             file_names (Optional[List[str]]): List of file names to load.
                 Defaults to None.
+            directory (Optional[str]): Directory to load files from.
             urls (Optional[List[str]]): List of urls to load.
                 Defaults to None.
 
         """
         file_names = file_names or []
         urls = urls or []
-        docs = load_data(file_names=file_names, urls=urls)
+        directory = directory or ""
+        docs = load_data(file_names=file_names, directory=directory, urls=urls)
         self._cache.docs = docs
         self._cache.file_names = file_names
         self._cache.urls = urls
+        self._cache.directory = directory
         return "Data loaded successfully."
 
     def add_web_tool(self) -> str:
@@ -302,77 +248,3 @@ def update_agent(
 
         # this will update the agent in the cache
         self.create_agent()
-
-
-####################
-#### META Agent ####
-####################
-
-RAG_BUILDER_SYS_STR = """\
-You are helping to construct an agent given a user-specified task.
-You should generally use the tools in this rough order to build the agent.
-
-1) Create system prompt tool: to create the system prompt for the agent.
-2) Load in user-specified data (based on file paths they specify).
-3) Decide whether or not to add additional tools.
-4) Set parameters for the RAG pipeline.
-5) Build the agent
-
-This will be a back and forth conversation with the user. You should
-continue asking users if there's anything else they want to do until
-they say they're done. To help guide them on the process,
-you can give suggestions on parameters they can set based on the tools they
-have available (e.g. "Do you want to set the number of documents to retrieve?")
-
-"""
-
-
-### DEFINE Agent ####
-# NOTE: here we define a function that is dependent on the LLM,
-# please make sure to update the LLM above if you change the function below
-
-
-def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
-    """Get list of builder agent tools to pass to the builder agent."""
-    # see if metaphor api key is set, otherwise don't add web tool
-    # TODO: refactor this later
-
-    if "metaphor_key" in st.secrets:
-        fns: List[Callable] = [
-            agent_builder.create_system_prompt,
-            agent_builder.load_data,
-            agent_builder.add_web_tool,
-            agent_builder.get_rag_params,
-            agent_builder.set_rag_params,
-            agent_builder.create_agent,
-        ]
-    else:
-        fns = [
-            agent_builder.create_system_prompt,
-            agent_builder.load_data,
-            agent_builder.get_rag_params,
-            agent_builder.set_rag_params,
-            agent_builder.create_agent,
-        ]
-
-    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
-    return fn_tools
-
-
-# define agent
-# @st.cache_resource
-def load_meta_agent_and_tools(
-    cache: Optional[ParamCache] = None,
-    agent_registry: Optional[AgentCacheRegistry] = None,
-) -> Tuple[BaseAgent, RAGAgentBuilder]:
-
-    # think of this as tools for the agent to use
-    agent_builder = RAGAgentBuilder(cache, agent_registry=agent_registry)
-
-    fn_tools = _get_builder_agent_tools(agent_builder)
-
-    builder_agent = load_meta_agent(
-        fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
-    )
-
-    return builder_agent, agent_builder
diff --git a/core/agent_builder/loader.py b/core/agent_builder/loader.py
new file mode 100644
index 0000000..ba32c93
--- /dev/null
+++ b/core/agent_builder/loader.py
@@ -0,0 +1,115 @@
+"""Loader agent."""
+
+from typing import List, cast, Optional
+from llama_index.tools import FunctionTool
+from llama_index.agent.types import BaseAgent
+from core.builder_config import BUILDER_LLM
+from typing import Tuple, Callable
+import streamlit as st
+
+from core.param_cache import ParamCache
+from core.utils import (
+    load_meta_agent,
+)
+from core.agent_builder.registry import AgentCacheRegistry
+from core.agent_builder.base import RAGAgentBuilder, BaseRAGAgentBuilder
+from core.agent_builder.multimodal import MultimodalRAGAgentBuilder
+
+####################
+#### META Agent ####
+####################
+
+RAG_BUILDER_SYS_STR = """\
+You are helping to construct an agent given a user-specified task.
+You should generally use the tools in this rough order to build the agent.
+
+1) Create system prompt tool: to create the system prompt for the agent.
+2) Load in user-specified data (based on file paths they specify).
+3) Decide whether or not to add additional tools.
+4) Set parameters for the RAG pipeline.
+5) Build the agent
+
+This will be a back and forth conversation with the user. You should
+continue asking users if there's anything else they want to do until
+they say they're done. To help guide them on the process,
+you can give suggestions on parameters they can set based on the tools they
+have available (e.g. "Do you want to set the number of documents to retrieve?")
+
+"""
+
+
+### DEFINE Agent ####
+# NOTE: here we define a function that is dependent on the LLM,
+# please make sure to update the LLM above if you change the function below
+
+
+def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
+    """Get list of builder agent tools to pass to the builder agent."""
+    # see if metaphor api key is set, otherwise don't add web tool
+    # TODO: refactor this later
+
+    if "metaphor_key" in st.secrets:
+        fns: List[Callable] = [
+            agent_builder.create_system_prompt,
+            agent_builder.load_data,
+            agent_builder.add_web_tool,
+            agent_builder.get_rag_params,
+            agent_builder.set_rag_params,
+            agent_builder.create_agent,
+        ]
+    else:
+        fns = [
+            agent_builder.create_system_prompt,
+            agent_builder.load_data,
+            agent_builder.get_rag_params,
+            agent_builder.set_rag_params,
+            agent_builder.create_agent,
+        ]
+
+    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
+    return fn_tools
+
+
+def _get_mm_builder_agent_tools(
+    agent_builder: MultimodalRAGAgentBuilder,
+) -> List[FunctionTool]:
+    """Get list of builder agent tools to pass to the builder agent."""
+    fns: List[Callable] = [
+        agent_builder.create_system_prompt,
+        agent_builder.load_data,
+        agent_builder.get_rag_params,
+        agent_builder.set_rag_params,
+        agent_builder.create_agent,
+    ]
+
+    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
+    return fn_tools
+
+
+# define agent
+def load_meta_agent_and_tools(
+    cache: Optional[ParamCache] = None,
+    agent_registry: Optional[AgentCacheRegistry] = None,
+    is_multimodal: bool = False,
+) -> Tuple[BaseAgent, BaseRAGAgentBuilder]:
+    """Load meta agent and tools."""
+
+    if is_multimodal:
+        agent_builder: BaseRAGAgentBuilder = MultimodalRAGAgentBuilder(
+            cache, agent_registry=agent_registry
+        )
+        fn_tools = _get_mm_builder_agent_tools(
+            cast(MultimodalRAGAgentBuilder, agent_builder)
+        )
+        builder_agent = load_meta_agent(
+            fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
+        )
+    else:
+        # think of this as tools for the agent to use
+        agent_builder = RAGAgentBuilder(cache, agent_registry=agent_registry)
+        fn_tools = _get_builder_agent_tools(agent_builder)
+        builder_agent = load_meta_agent(
+            fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
+        )
+
+    return builder_agent, agent_builder
diff --git a/core/agent_builder/multimodal.py b/core/agent_builder/multimodal.py
new file mode 100644
index 0000000..b7972e4
--- /dev/null
+++ b/core/agent_builder/multimodal.py
@@ -0,0 +1,256 @@
+"""Multimodal agent builder."""
+
+from llama_index.llms import ChatMessage
+from typing import List, cast, Optional
+from core.builder_config import BUILDER_LLM
+from typing import Dict, Any
+import uuid
+from core.constants import AGENT_CACHE_DIR
+
+from core.param_cache import ParamCache, RAGParams
+from core.utils import (
+    load_data,
+    construct_mm_agent,
+)
+from core.agent_builder.registry import AgentCacheRegistry
+from core.agent_builder.base import GEN_SYS_PROMPT_TMPL, BaseRAGAgentBuilder
+
+from llama_index.chat_engine.types import BaseChatEngine
+
+from llama_index.callbacks import trace_method
+from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
+from llama_index.chat_engine.types import (
+    AGENT_CHAT_RESPONSE_TYPE,
+    StreamingAgentChatResponse,
+    AgentChatResponse,
+)
+from llama_index.llms.base import ChatResponse
+from typing import Generator
+
+
+class MultimodalChatEngine(BaseChatEngine):
+    """Multimodal chat engine.
+
+    This chat engine is a light wrapper around a query engine.
+    Offers no real 'chat' functionality; this is a beta feature.
+
+    """
+
+    def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
+        """Init params."""
+        self._mm_query_engine = mm_query_engine
+
+    def reset(self) -> None:
+        """Reset conversation state."""
+        pass
+
+    @trace_method("chat")
+    def chat(
+        self, message: str, chat_history: Optional[List[ChatMessage]] = None
+    ) -> AGENT_CHAT_RESPONSE_TYPE:
+        """Main chat interface."""
+        # just return the top-k results
+        response = self._mm_query_engine.query(message)
+        return AgentChatResponse(response=str(response))
+
+    @trace_method("chat")
+    def stream_chat(
+        self, message: str, chat_history: Optional[List[ChatMessage]] = None
+    ) -> StreamingAgentChatResponse:
+        """Stream chat interface."""
+        response = self._mm_query_engine.query(message)
+
+        def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
+            yield ChatResponse(message=ChatMessage(role="assistant", content=response))
+
+        chat_stream = _chat_stream(str(response))
+        return StreamingAgentChatResponse(chat_stream=chat_stream)
+
+    @trace_method("chat")
+    async def achat(
+        self, message: str, chat_history: Optional[List[ChatMessage]] = None
+    ) -> AGENT_CHAT_RESPONSE_TYPE:
+        """Async version of main chat interface."""
+        response = await self._mm_query_engine.aquery(message)
+        return AgentChatResponse(response=str(response))
+
+    @trace_method("chat")
+    async def astream_chat(
+        self, message: str, chat_history: Optional[List[ChatMessage]] = None
+    ) -> StreamingAgentChatResponse:
+        """Async version of stream chat interface."""
+        return self.stream_chat(message, chat_history)
+
+
+class MultimodalRAGAgentBuilder(BaseRAGAgentBuilder):
+    """Multimodal RAG Agent builder.
+
+    Contains a set of functions to construct a RAG agent, including:
+    - setting system prompts
+    - loading data
+    - setting parameters (e.g. top-k)
+
+    Must pass in a cache. This cache will be modified as the agent is built.
+
+    """
+
+    def __init__(
+        self,
+        cache: Optional[ParamCache] = None,
+        agent_registry: Optional[AgentCacheRegistry] = None,
+    ) -> None:
+        """Init params."""
+        self._cache = cache or ParamCache()
+        self._agent_registry = agent_registry or AgentCacheRegistry(
+            str(AGENT_CACHE_DIR)
+        )
+
+    @property
+    def cache(self) -> ParamCache:
+        """Cache."""
+        return self._cache
+
+    @property
+    def agent_registry(self) -> AgentCacheRegistry:
+        """Agent registry."""
+        return self._agent_registry
+
+    def create_system_prompt(self, task: str) -> str:
+        """Create system prompt for another agent given an input task."""
+        llm = BUILDER_LLM
+        fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
+        response = llm.chat(fmt_messages)
+        self._cache.system_prompt = response.message.content
+
+        return f"System prompt created: {response.message.content}"
+
+    def load_data(
+        self,
+        file_names: Optional[List[str]] = None,
+        directory: Optional[str] = None,
+    ) -> str:
+        """Load data for a given task.
+
+        Only ONE of file_names or directory should be specified.
+        **NOTE**: urls are not supported in the multi-modal setting.
+
+        Args:
+            file_names (Optional[List[str]]): List of file names to load.
+                Defaults to None.
+            directory (Optional[str]): Directory to load files from.
+
+        """
+        file_names = file_names or []
+        directory = directory or ""
+        docs = load_data(file_names=file_names, directory=directory)
+        self._cache.docs = docs
+        self._cache.file_names = file_names
+        self._cache.directory = directory
+        return "Data loaded successfully."
+
+    def get_rag_params(self) -> Dict:
+        """Get parameters used to configure the RAG pipeline.
+
+        Should be called before `set_rag_params` so that the agent is aware of the
+        schema.
+
+        """
+        rag_params = self._cache.rag_params
+        return rag_params.dict()
+
+    def set_rag_params(self, **rag_params: Dict) -> str:
+        """Set RAG parameters.
+
+        These parameters will then be used to actually initialize the agent.
+        Should call `get_rag_params` first to get the schema of the input dictionary.
+
+        Args:
+            **rag_params (Dict): dictionary of RAG parameters.
+
+        """
+        new_dict = self._cache.rag_params.dict()
+        new_dict.update(rag_params)
+        rag_params_obj = RAGParams(**new_dict)
+        self._cache.rag_params = rag_params_obj
+        return "RAG parameters set successfully."
+
+    def create_agent(self, agent_id: Optional[str] = None) -> str:
+        """Create an agent.
+
+        There are no parameters for this function because all the
+        functions should have already been called to set up the agent.
+
+        """
+        if self._cache.system_prompt is None:
+            raise ValueError("Must set system prompt before creating agent.")
+
+        # construct the multimodal agent from the cached docs / params
+        agent, extra_info = construct_mm_agent(
+            cast(str, self._cache.system_prompt),
+            cast(RAGParams, self._cache.rag_params),
+            self._cache.docs,
+        )
+
+        # if agent_id not specified, randomly generate one
+        agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}"
+        self._cache.builder_type = "multimodal"
+        self._cache.vector_index = extra_info["vector_index"]
+        self._cache.agent_id = agent_id
+        self._cache.agent = agent
+
+        # save the cache to disk
+        self._agent_registry.add_new_agent_cache(agent_id, self._cache)
+        return "Agent created successfully."
+
+    def update_agent(
+        self,
+        agent_id: str,
+        system_prompt: Optional[str] = None,
+        include_summarization: Optional[bool] = None,
+        top_k: Optional[int] = None,
+        chunk_size: Optional[int] = None,
+        embed_model: Optional[str] = None,
+        llm: Optional[str] = None,
+        additional_tools: Optional[List] = None,
+    ) -> None:
+        """Update agent.
+
+        Delete old agent by ID and create a new one.
+        Optionally update the system prompt and RAG parameters.
+
+        NOTE: Currently is manually called, not meant for agent use.
+
+        """
+        self._agent_registry.delete_agent_cache(self.cache.agent_id)
+
+        # set agent id
+        self.cache.agent_id = agent_id
+
+        # set system prompt
+        if system_prompt is not None:
+            self.cache.system_prompt = system_prompt
+        # get agent_builder
+        # We call set_rag_params and create_agent, which will
+        # update the cache
+        # TODO: decouple functions from tool functions exposed to the agent
+        rag_params_dict: Dict[str, Any] = {}
+        if include_summarization is not None:
+            rag_params_dict["include_summarization"] = include_summarization
+        if top_k is not None:
+            rag_params_dict["top_k"] = top_k
+        if chunk_size is not None:
+            rag_params_dict["chunk_size"] = chunk_size
+        if embed_model is not None:
+            rag_params_dict["embed_model"] = embed_model
+        if llm is not None:
+            rag_params_dict["llm"] = llm
+
+        self.set_rag_params(**rag_params_dict)
+
+        # update tools
+        if additional_tools is not None:
+            self.cache.tools = additional_tools
+
+        # this will update the agent in the cache
+        self.create_agent()
diff --git a/core/agent_builder/registry.py b/core/agent_builder/registry.py
new file mode 100644
index 0000000..c35da98
--- /dev/null
+++ b/core/agent_builder/registry.py
@@ -0,0 +1,78 @@
+"""Agent builder registry."""
+
+from typing import List
+from typing import Union
+from pathlib import Path
+import json
+import shutil
+
+from core.param_cache import ParamCache
+
+
+class AgentCacheRegistry:
+    """Registry for agent caches, on disk.
+
+    Can register new agent caches, load agent caches, delete agent caches, etc.
+
+    """
+
+    def __init__(self, dir: Union[str, Path]) -> None:
+        """Init params."""
+        self._dir = dir
+
+    def _add_agent_id_to_directory(self, agent_id: str) -> None:
+        """Save agent id to directory."""
+        full_path = Path(self._dir) / "agent_ids.json"
+        if not full_path.exists():
+            with open(full_path, "w") as f:
+                json.dump({"agent_ids": [agent_id]}, f)
+        else:
+            with open(full_path, "r") as f:
+                agent_ids = json.load(f)["agent_ids"]
+            if agent_id in agent_ids:
+                raise ValueError(f"Agent id {agent_id} already exists.")
+            agent_ids_set = set(agent_ids)
+            agent_ids_set.add(agent_id)
+            with open(full_path, "w") as f:
+                json.dump({"agent_ids": list(agent_ids_set)}, f)
+
+    def add_new_agent_cache(self, agent_id: str, cache: ParamCache) -> None:
+        """Register agent."""
+        # save the cache to disk
+        agent_cache_path = f"{self._dir}/{agent_id}"
+        cache.save_to_disk(agent_cache_path)
+        # save to agent ids
+        self._add_agent_id_to_directory(agent_id)
+
+    def get_agent_ids(self) -> List[str]:
+        """Get agent ids."""
+        full_path = Path(self._dir) / "agent_ids.json"
+        if not full_path.exists():
+            return []
+        with open(full_path, "r") as f:
+            agent_ids = json.load(f)["agent_ids"]
+
+        return agent_ids
+
+    def get_agent_cache(self, agent_id: str) -> ParamCache:
+        """Get agent cache."""
+        full_path = Path(self._dir) / f"{agent_id}"
+        if not full_path.exists():
+            raise ValueError(f"Cache for agent {agent_id} does not exist.")
+        cache = ParamCache.load_from_disk(str(full_path))
+        return cache
+
+    def delete_agent_cache(self, agent_id: str) -> None:
+        """Delete agent cache."""
+        # modify / resave agent_ids
+        agent_ids = self.get_agent_ids()
+        new_agent_ids = [id for id in agent_ids if id != agent_id]
+        full_path = Path(self._dir) / "agent_ids.json"
+        with open(full_path, "w") as f:
+            json.dump({"agent_ids": new_agent_ids}, f)
+
+        # remove agent cache
+        full_path = Path(self._dir) / f"{agent_id}"
+        if full_path.exists():
+            # recursive delete
+            shutil.rmtree(full_path)
diff --git a/core/param_cache.py b/core/param_cache.py
index a7cd6e8..aab81ae 100644
--- a/core/param_cache.py
+++ b/core/param_cache.py
@@ -11,7 +11,13 @@
 from pathlib import Path
 import json
 import uuid
-from core.utils import load_data, get_tool_objects, construct_agent, RAGParams
+from core.utils import (
+    load_data,
+    get_tool_objects,
+    construct_agent,
+    RAGParams,
+    construct_mm_agent,
+)
 
 
 class ParamCache(BaseModel):
@@ -37,6 +43,10 @@ class Config:
     urls: List[str] = Field(
         default_factory=list, description="URLs as data source (if specified)"
     )
+    directory: Optional[str] = Field(
+        default=None, description="Directory as data source (if specified)"
+    )
+
     docs: List = Field(default_factory=list, description="Documents for RAG agent.")
 
     # tools
     tools: List = Field(
@@ -48,6 +58,9 @@ class Config:
     )
 
     # agent params
+    builder_type: str = Field(
+        default="default", description="Builder type (default, multimodal)."
+    )
     vector_index: Optional[VectorStoreIndex] = Field(
         default=None, description="Vector index for RAG agent."
     )
@@ -66,9 +79,11 @@ def save_to_disk(self, save_dir: str) -> None:
             "system_prompt": self.system_prompt,
             "file_names": self.file_names,
             "urls": self.urls,
+            "directory": self.directory,
             # TODO: figure out tools
             "tools": self.tools,
             "rag_params": self.rag_params.dict(),
+            "builder_type": self.builder_type,
             "agent_id": self.agent_id,
         }
         # store the vector store within the agent
@@ -88,13 +103,22 @@ def load_from_disk(
         save_dir: str,
     ) -> "ParamCache":
         """Load cache from disk."""
+        with open(Path(save_dir) / "cache.json", "r") as f:
+            cache_dict = json.load(f)
+
         storage_context = StorageContext.from_defaults(
             persist_dir=str(Path(save_dir) / "storage")
         )
-        vector_index = cast(VectorStoreIndex, load_index_from_storage(storage_context))
+        if cache_dict["builder_type"] == "multimodal":
+            from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
 
-        with open(Path(save_dir) / "cache.json", "r") as f:
-            cache_dict = json.load(f)
+            vector_index: VectorStoreIndex = cast(
+                MultiModalVectorStoreIndex, load_index_from_storage(storage_context)
+            )
+        else:
+            vector_index = cast(
+                VectorStoreIndex, load_index_from_storage(storage_context)
+            )
 
         # replace rag params with RAGParams object
         cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])
@@ -102,18 +126,30 @@
         # add in the missing fields
         # load docs
         cache_dict["docs"] = load_data(
-            file_names=cache_dict["file_names"], urls=cache_dict["urls"]
+            file_names=cache_dict["file_names"],
+            urls=cache_dict["urls"],
+            directory=cache_dict["directory"],
         )
         # load agent from index
         additional_tools = get_tool_objects(cache_dict["tools"])
-        agent, _ = construct_agent(
-            cache_dict["system_prompt"],
-            cache_dict["rag_params"],
-            cache_dict["docs"],
-            vector_index=vector_index,
-            additional_tools=additional_tools,
-            # TODO: figure out tools
-        )
+
+        if cache_dict["builder_type"] == "multimodal":
+            vector_index = cast(MultiModalVectorStoreIndex, vector_index)
+            agent, _ = construct_mm_agent(
+                cache_dict["system_prompt"],
+                cache_dict["rag_params"],
+                cache_dict["docs"],
+                mm_vector_index=vector_index,
+            )
+        else:
+            agent, _ = construct_agent(
+                cache_dict["system_prompt"],
+                cache_dict["rag_params"],
+                cache_dict["docs"],
+                vector_index=vector_index,
+                additional_tools=additional_tools,
+                # TODO: figure out tools
+            )
 
         cache_dict["vector_index"] = vector_index
         cache_dict["agent"] = agent
diff --git a/core/utils.py b/core/utils.py
index cb50c6a..c39a02c 100644
--- a/core/utils.py
+++ b/core/utils.py
@@ -26,8 +26,25 @@
 from typing import Dict, Tuple, Any
 import streamlit as st
 
-from llama_index.callbacks import CallbackManager
+from llama_index.callbacks import CallbackManager, trace_method
 from core.callback_manager import StreamlitFunctionsCallbackHandler
+from llama_index.schema import ImageNode, NodeWithScore
+
+### BETA: Multi-modal
+from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
+from llama_index.multi_modal_llms.openai import OpenAIMultiModal
+from llama_index.indices.multi_modal.retriever import (
+    MultiModalVectorIndexRetriever,
+)
+from llama_index.llms import ChatMessage
+from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
+from llama_index.chat_engine.types import (
+    AGENT_CHAT_RESPONSE_TYPE,
+    StreamingAgentChatResponse,
+    AgentChatResponse,
+)
+from llama_index.llms.base import ChatResponse
+from typing import Generator
 
 
 class RAGParams(BaseModel):
@@ -82,18 +99,28 @@ def _resolve_llm(llm_str: str) -> LLM:
 
 
 def load_data(
-    file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
+    file_names: Optional[List[str]] = None,
+    directory: Optional[str] = None,
+    urls: Optional[List[str]] = None,
 ) -> List[Document]:
     """Load data."""
     file_names = file_names or []
+    directory = directory or ""
     urls = urls or []
-    if not file_names and not urls:
-        raise ValueError("Must specify either file_names or urls.")
-    elif file_names and urls:
-        raise ValueError("Must specify only one of file_names or urls.")
+
+    # count how many of the mutually exclusive data sources were specified
+    num_specified = sum(1 for v in [file_names, urls, directory] if v)
+
+    if num_specified == 0:
+        raise ValueError("Must specify either file_names or urls or directory.")
+    elif num_specified > 1:
+        raise ValueError("Must specify only one of file_names or urls or directory.")
     elif file_names:
         reader = SimpleDirectoryReader(input_files=file_names)
         docs = reader.load_data()
+    elif directory:
+        reader = SimpleDirectoryReader(input_dir=directory)
+        docs = reader.load_data()
     elif urls:
         from llama_hub.web.simple_web.base import SimpleWebPageReader
@@ -101,7 +128,7 @@ def load_data(
         loader = SimpleWebPageReader()
         docs = loader.load_data(urls=urls)
     else:
-        raise ValueError("Must specify either file_names or urls.")
+        raise ValueError("Must specify either file_names or urls or directory.")
 
     return docs
@@ -326,3 +353,128 @@ def get_tool_objects(tool_names: List[str]) -> List:
         raise ValueError(f"Tool {tool_name} not recognized.")
 
     return tool_objs
+
+
+class MultimodalChatEngine(BaseChatEngine):
+    """Multimodal chat engine.
+
+    This chat engine is a light wrapper around a query engine.
+    Offers no real 'chat' functionality; this is a beta feature.
+
+    """
+
+    def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
+        """Init params."""
+        self._mm_query_engine = mm_query_engine
+
+    def reset(self) -> None:
+        """Reset conversation state."""
+        pass
+
+    @property
+    def chat_history(self) -> List[ChatMessage]:
+        """Chat history (not tracked; always empty)."""
+        return []
+
+    @trace_method("chat")
+    def chat(
+        self, message: str, chat_history: Optional[List[ChatMessage]] = None
+    ) -> AGENT_CHAT_RESPONSE_TYPE:
+        """Main chat interface."""
+        # just return the top-k results
+        response = self._mm_query_engine.query(message)
+        return AgentChatResponse(
+            response=str(response), source_nodes=response.source_nodes
+        )
+
+    @trace_method("chat")
+    def stream_chat(
+        self, message: str, chat_history: Optional[List[ChatMessage]] = None
+    ) -> StreamingAgentChatResponse:
+        """Stream chat interface."""
+        response = self._mm_query_engine.query(message)
+
+        def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
+            yield ChatResponse(message=ChatMessage(role="assistant", content=response))
+
+        chat_stream = _chat_stream(str(response))
+        return StreamingAgentChatResponse(
+            chat_stream=chat_stream, source_nodes=response.source_nodes
+        )
+
+    @trace_method("chat")
+    async def achat(
+        self, message: str, chat_history: Optional[List[ChatMessage]] = None
+    ) -> AGENT_CHAT_RESPONSE_TYPE:
+        """Async version of main chat interface."""
+        response = await self._mm_query_engine.aquery(message)
+        return AgentChatResponse(
+            response=str(response), source_nodes=response.source_nodes
+        )
+
+    @trace_method("chat")
+    async def astream_chat(
+        self, message: str, chat_history: Optional[List[ChatMessage]] = None
+    ) -> StreamingAgentChatResponse:
+        """Async version of stream chat interface."""
+        return self.stream_chat(message, chat_history)
+
+
+def construct_mm_agent(
+    system_prompt: str,
+    rag_params: RAGParams,
+    docs: List[Document],
+    mm_vector_index: Optional[VectorStoreIndex] = None,
+    additional_tools: Optional[List] = None,
+) -> Tuple[BaseChatEngine, Dict]:
+    """Construct multimodal agent from docs / parameters / indices.
+
+    NOTE: system prompt isn't used right now
+
+    """
+    extra_info = {}
+    additional_tools = additional_tools or []
+
+    # first resolve llm and embedding model
+    embed_model = resolve_embed_model(rag_params.embed_model)
+    # TODO: use OpenAI for now
+    os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
+    openai_mm_llm = OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=1500)
+
+    # first let's index the data with the right parameters
+    service_context = ServiceContext.from_defaults(
+        chunk_size=rag_params.chunk_size,
+        embed_model=embed_model,
+    )
+
+    if mm_vector_index is None:
+        mm_vector_index = MultiModalVectorStoreIndex.from_documents(
+            docs, service_context=service_context
+        )
+    else:
+        pass
+
+    mm_retriever = mm_vector_index.as_retriever(similarity_top_k=rag_params.top_k)
+    mm_query_engine = SimpleMultiModalQueryEngine(
+        cast(MultiModalVectorIndexRetriever, mm_retriever),
+        multi_modal_llm=openai_mm_llm,
+    )
+
+    extra_info["vector_index"] = mm_vector_index
+
+    # wrap the multimodal query engine in a lightweight chat engine
+    agent = MultimodalChatEngine(mm_query_engine)
+
+    return agent, extra_info
+
+
+def get_image_and_text_nodes(
+    nodes: List[NodeWithScore],
+) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
+    """Split source nodes into image nodes and text nodes."""
+    image_nodes = []
+    text_nodes = []
+    for res_node in nodes:
+        if isinstance(res_node.node, ImageNode):
+            image_nodes.append(res_node)
+        else:
+            text_nodes.append(res_node)
+    return image_nodes, text_nodes
diff --git "a/pages/2_\342\232\231\357\270\217_RAG_Config.py" "b/pages/2_\342\232\231\357\270\217_RAG_Config.py"
index 61f7186..30fe5c0 100644
--- "a/pages/2_\342\232\231\357\270\217_RAG_Config.py"
+++ "b/pages/2_\342\232\231\357\270\217_RAG_Config.py"
@@ -4,7 +4,7 @@
 from core.param_cache import (
     RAGParams,
 )
-from core.agent_builder import (
+from core.agent_builder.loader import (
     RAGAgentBuilder,
     AgentCacheRegistry,
 )
@@ -93,14 +93,24 @@ def delete_agent() -> None:
     )
     rag_params = cast(RAGParams, current_state.cache.rag_params)
-    file_names = st.text_input(
-        "File names (not editable)",
-        value=",".join(current_state.cache.file_names),
-        disabled=True,
-    )
-    urls = st.text_input(
-        "URLs (not editable)", value=",".join(current_state.cache.urls), disabled=True
-    )
+
+    with st.expander("Loaded Data (Expand to view)"):
+        file_names = st.text_input(
+            "File names (not editable)",
+            value=",".join(current_state.cache.file_names),
+            disabled=True,
+        )
+        directory = st.text_input(
+            "Directory (not editable)",
+            value=current_state.cache.directory,
+            disabled=True,
+        )
+        urls = st.text_input(
+            "URLs (not editable)",
+            value=",".join(current_state.cache.urls),
+            disabled=True,
+        )
+
     include_summarization_st = st.checkbox(
         "Include Summarization (only works for GPT-4)",
         value=rag_params.include_summarization,
diff --git "a/pages/3_\360\237\244\226_Generated_RAG_Agent.py" "b/pages/3_\360\237\244\226_Generated_RAG_Agent.py"
index ed95796..90cb910 100644
--- "a/pages/3_\360\237\244\226_Generated_RAG_Agent.py"
+++ "b/pages/3_\360\237\244\226_Generated_RAG_Agent.py"
@@ -1,6 +1,11 @@
 """Streamlit page showing builder config."""
 import streamlit as st
 from st_utils import add_sidebar, get_current_state
+from core.utils import get_image_and_text_nodes
+from llama_index.schema import MetadataMode
+from llama_index.chat_engine.types import AGENT_CHAT_RESPONSE_TYPE
+from typing import Dict, Optional
+import pandas as pd
 
 
 ####################
@@ -28,8 +33,36 @@
 ]
 
 
-def add_to_message_history(role: str, content: str) -> None:
-    message = {"role": role, "content": str(content)}
+def display_sources(response: AGENT_CHAT_RESPONSE_TYPE) -> None:
+    """Display image and text source nodes for a response."""
+    image_nodes, text_nodes = get_image_and_text_nodes(response.source_nodes)
+    if len(image_nodes) > 0 or len(text_nodes) > 0:
+        with st.expander("Sources"):
+            # get image nodes
+            if len(image_nodes) > 0:
+                st.subheader("Images")
+                for image_node in image_nodes:
+                    st.image(image_node.metadata["file_path"])
+
+            if len(text_nodes) > 0:
+                st.subheader("Text")
+                sources_df_list = []
+                for text_node in text_nodes:
+                    sources_df_list.append(
+                        {
+                            "ID": text_node.id_,
+                            "Text": text_node.node.get_content(
+                                metadata_mode=MetadataMode.ALL
+                            ),
+                        }
+                    )
+                sources_df = pd.DataFrame(sources_df_list)
+                st.dataframe(sources_df)
+
+
+def add_to_message_history(
+    role: str, content: str, extra: Optional[Dict] = None
+) -> None:
+    message = {"role": role, "content": str(content), "extra": extra}
     st.session_state.agent_messages.append(message)  # Add response to message history
@@ -45,6 +78,11 @@ def display_messages() -> None:
         else:
             raise ValueError(f"Unknown message type: {msg_type}")
 
+        # display sources
+        if "extra" in message and isinstance(message["extra"], dict):
+            if "response" in message["extra"].keys():
+                display_sources(message["extra"]["response"])
+
 
 # if agent is created, then we can chat with it
 if current_state.cache is not None and current_state.cache.agent is not None:
@@ -68,6 +106,13 @@ def display_messages() -> None:
             with st.spinner("Thinking..."):
                 response = agent.chat(str(prompt))
                 st.write(str(response))
-                add_to_message_history("assistant", str(response))
+
+                # display sources
+                # Multi-modal: check if image nodes are present
+                display_sources(response)
+
+                add_to_message_history(
+                    "assistant", str(response), extra={"response": response}
+                )
 else:
     st.info("Agent not created. Please create an agent in the above section.")
diff --git a/pyproject.toml b/pyproject.toml
index ca8e30a..5e477a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "rags"
-version = "0.0.4"
+version = "0.0.5"
 description = "Build RAG with natural language."
 authors = ["Jerry Liu"]
 # New attributes
@@ -22,6 +22,7 @@
 llama-hub = "0.0.44"
 # NOTE: this is due to a trivial dependency in the web tool, will refactor
 langchain = "0.0.305"
 pypdf = "3.17.1"
+clip = { git = "https://github.com/openai/CLIP.git" }
 
 [tool.poetry.dev-dependencies]
 # pytest = "7.2.1"
@@ -65,3 +66,11 @@ exclude = [
 
 [tool.ruff.per-file-ignores]
 "base.py" = ["E402", "F811", "E501"]
+
+
+[tool.poetry.extras]
+multimodal = [
+    "torch",
+    "torchvision",
+    "clip",
+]
\ No newline at end of file
diff --git a/st_utils.py b/st_utils.py
index bf1fa58..5f7a5a0 100644
--- a/st_utils.py
+++ b/st_utils.py
@@ -1,9 +1,9 @@
 """Streamlit utils."""
-from core.agent_builder import (
+from core.agent_builder.loader import (
     load_meta_agent_and_tools,
     AgentCacheRegistry,
-    RAGAgentBuilder,
 )
+from core.agent_builder.base import BaseRAGAgentBuilder
 from core.param_cache import ParamCache
 from core.constants import (
     AGENT_CACHE_DIR,
@@ -38,6 +38,47 @@ def update_selected_agent() -> None:
     update_selected_agent_with_id(selected_id)
 
 
+def get_cached_is_multimodal() -> bool:
+    """Get whether the selected agent cache is multimodal (defaults to False)."""
+    if (
+        "selected_cache" not in st.session_state.keys()
+        or st.session_state.selected_cache is None
+    ):
+        default_val = False
+    else:
+        selected_cache = cast(ParamCache, st.session_state.selected_cache)
+        default_val = selected_cache.builder_type == "multimodal"
+    return default_val
+
+
+def get_is_multimodal() -> bool:
+    """Get the current multimodal toggle state."""
+    if "is_multimodal_st" not in st.session_state.keys():
+        st.session_state.is_multimodal_st = False
+    return st.session_state.is_multimodal_st
+
+
+def add_builder_config() -> None:
+    """Add builder config."""
+    with st.expander("Builder Config (Advanced)"):
+        # add a few options - openai api key, and
+        if (
+            "selected_cache" not in st.session_state.keys()
+            or st.session_state.selected_cache is None
+        ):
+            is_locked = False
+        else:
+            is_locked = True
+
+        st.checkbox(
+            "Enable multimodal search (beta)",
+            key="is_multimodal_st",
+            on_change=update_selected_agent,
+            value=get_cached_is_multimodal(),
+            disabled=is_locked,
+        )
+
+
 def add_sidebar() -> None:
     """Add sidebar."""
     with st.sidebar:
@@ -70,7 +111,7 @@ class Config:
     agent_registry: AgentCacheRegistry
     selected_id: Optional[str]
     selected_cache: Optional[ParamCache]
-    agent_builder: RAGAgentBuilder
+    agent_builder: BaseRAGAgentBuilder
     cache: ParamCache
     builder_agent: BaseAgent
@@ -126,11 +167,15 @@ def get_current_state() -> CurrentSessionState:
         builder_agent, agent_builder = load_meta_agent_and_tools(
             cache=st.session_state.selected_cache,
            agent_registry=st.session_state.agent_registry,
+            # NOTE: we will probably generalize this later into different
+            # builder configs
+            is_multimodal=get_cached_is_multimodal(),
         )
     else:
         # create builder agent / tools from new cache
         builder_agent, agent_builder = load_meta_agent_and_tools(
-            agent_registry=st.session_state.agent_registry
+            agent_registry=st.session_state.agent_registry,
+            is_multimodal=get_is_multimodal(),
         )
 
     st.session_state.builder_agent = builder_agent