Skip to content

Commit

Permalink
upgrade v3 (#49)
Browse files Browse the repository at this point in the history
support multi-modality
allow loading an entire directory (and all the files in the directory)
add sources in the response
  • Loading branch information
jerryjliu committed Dec 5, 2023
1 parent 76bf409 commit 4bec270
Show file tree
Hide file tree
Showing 13 changed files with 819 additions and 228 deletions.
43 changes: 7 additions & 36 deletions 1_🏠_Home.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import streamlit as st
from streamlit_pills import pills

from st_utils import add_sidebar, get_current_state
from st_utils import (
add_builder_config,
add_sidebar,
get_current_state,
)

current_state = get_current_state()

####################
#### STREAMLIT #####
Expand All @@ -28,7 +33,7 @@
st.info("**NOTE**: The ability to add web search is enabled.")


current_state = get_current_state()
add_builder_config()
add_sidebar()


Expand Down Expand Up @@ -60,40 +65,6 @@ def add_to_message_history(role: str, content: str) -> None:
with st.chat_message(message["role"]):
st.write(message["content"])


# def handle_user_input() -> None:
# """Handle user input."""
# prompt = st.session_state.user_question_st
# print(f"USER PROMPT: {prompt}")
# add_to_message_history("user", prompt)
# with st.chat_message("user"):
# st.write(prompt)
# # If last message is not from assistant, generate a new response
# if st.session_state.messages[-1]["role"] != "assistant":
# with st.chat_message("assistant"):
# with st.spinner("Thinking..."):
# response = current_state.builder_agent.chat(prompt)
# st.write(str(response))
# add_to_message_history("assistant", str(response))

# else:
# pass

# # check agent_ids again
# # if it doesn't match, add to directory and refresh
# agent_ids = current_state.agent_registry.get_agent_ids()
# # check diff between agent_ids and cur agent ids
# diff_ids = list(set(agent_ids) - set(st.session_state.cur_agent_ids))
# if len(diff_ids) > 0:
# # # clear streamlit cache, to allow you to generate a new agent
# # st.cache_resource.clear()
# st.rerun()

# handle user input
# st.chat_input(
# "Your question", key="user_question_st", on_submit=handle_user_input
# ) # Prompt for user input and save to chat history

# TODO: this is really hacky, only because st.rerun is jank
if prompt := st.chat_input(
"Your question",
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ streamlit run 1_🏠_Home.py
```

**NOTE**: If you've upgraded the version of RAGs, and you're running into issues on launch, you may need to delete the `cache` folder in your home directory (we may have introduced breaking changes in the stored data structure between versions).

## Detailed Overview

The app contains the following sections, corresponding to the steps listed above.
Expand Down
Empty file added core/agent_builder/__init__.py
Empty file.
182 changes: 27 additions & 155 deletions core/agent_builder.py → core/agent_builder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,93 +3,19 @@
from llama_index.llms import ChatMessage
from llama_index.prompts import ChatPromptTemplate
from typing import List, cast, Optional
from llama_index.tools import FunctionTool
from llama_index.agent.types import BaseAgent
from core.builder_config import BUILDER_LLM
from typing import Dict, Tuple, Any, Callable, Union
import streamlit as st
from pathlib import Path
import json
from typing import Dict, Any
import uuid
from core.constants import AGENT_CACHE_DIR
import shutil
from abc import ABC, abstractmethod

from core.param_cache import ParamCache, RAGParams
from core.utils import (
load_data,
get_tool_objects,
construct_agent,
load_meta_agent,
)


class AgentCacheRegistry:
"""Registry for agent caches, in disk.
Can register new agent caches, load agent caches, delete agent caches, etc.
"""

def __init__(self, dir: Union[str, Path]) -> None:
"""Init params."""
self._dir = dir

def _add_agent_id_to_directory(self, agent_id: str) -> None:
"""Save agent id to directory."""
full_path = Path(self._dir) / "agent_ids.json"
if not full_path.exists():
with open(full_path, "w") as f:
json.dump({"agent_ids": [agent_id]}, f)
else:
with open(full_path, "r") as f:
agent_ids = json.load(f)["agent_ids"]
if agent_id in agent_ids:
raise ValueError(f"Agent id {agent_id} already exists.")
agent_ids_set = set(agent_ids)
agent_ids_set.add(agent_id)
with open(full_path, "w") as f:
json.dump({"agent_ids": list(agent_ids_set)}, f)

def add_new_agent_cache(self, agent_id: str, cache: ParamCache) -> None:
"""Register agent."""
# save the cache to disk
agent_cache_path = f"{self._dir}/{agent_id}"
cache.save_to_disk(agent_cache_path)
# save to agent ids
self._add_agent_id_to_directory(agent_id)

def get_agent_ids(self) -> List[str]:
"""Get agent ids."""
full_path = Path(self._dir) / "agent_ids.json"
if not full_path.exists():
return []
with open(full_path, "r") as f:
agent_ids = json.load(f)["agent_ids"]

return agent_ids

def get_agent_cache(self, agent_id: str) -> ParamCache:
"""Get agent cache."""
full_path = Path(self._dir) / f"{agent_id}"
if not full_path.exists():
raise ValueError(f"Cache for agent {agent_id} does not exist.")
cache = ParamCache.load_from_disk(str(full_path))
return cache

def delete_agent_cache(self, agent_id: str) -> None:
"""Delete agent cache."""
# modify / resave agent_ids
agent_ids = self.get_agent_ids()
new_agent_ids = [id for id in agent_ids if id != agent_id]
full_path = Path(self._dir) / "agent_ids.json"
with open(full_path, "w") as f:
json.dump({"agent_ids": new_agent_ids}, f)

# remove agent cache
full_path = Path(self._dir) / f"{agent_id}"
if full_path.exists():
# recursive delete
shutil.rmtree(full_path)
from core.agent_builder.registry import AgentCacheRegistry


# System prompt tool
Expand Down Expand Up @@ -121,7 +47,21 @@ def delete_agent_cache(self, agent_id: str) -> None:
GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)


class RAGAgentBuilder:
class BaseRAGAgentBuilder(ABC):
"""Base RAG Agent builder class."""

@property
@abstractmethod
def cache(self) -> ParamCache:
"""Cache."""

@property
@abstractmethod
def agent_registry(self) -> AgentCacheRegistry:
"""Agent registry."""


class RAGAgentBuilder(BaseRAGAgentBuilder):
"""RAG Agent builder.
Contains a set of functions to construct a RAG agent, including:
Expand Down Expand Up @@ -165,25 +105,31 @@ def create_system_prompt(self, task: str) -> str:
return f"System prompt created: {response.message.content}"

def load_data(
self, file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
self,
file_names: Optional[List[str]] = None,
directory: Optional[str] = None,
urls: Optional[List[str]] = None,
) -> str:
"""Load data for a given task.
Only ONE of file_names or urls should be specified.
Only ONE of file_names or directory or urls should be specified.
Args:
file_names (Optional[List[str]]): List of file names to load.
Defaults to None.
directory (Optional[str]): Directory to load files from.
urls (Optional[List[str]]): List of urls to load.
Defaults to None.
"""
file_names = file_names or []
urls = urls or []
docs = load_data(file_names=file_names, urls=urls)
directory = directory or ""
docs = load_data(file_names=file_names, directory=directory, urls=urls)
self._cache.docs = docs
self._cache.file_names = file_names
self._cache.urls = urls
self._cache.directory = directory
return "Data loaded successfully."

def add_web_tool(self) -> str:
Expand Down Expand Up @@ -302,77 +248,3 @@ def update_agent(

# this will update the agent in the cache
self.create_agent()


####################
#### META Agent ####
####################

RAG_BUILDER_SYS_STR = """\
You are helping to construct an agent given a user-specified task.
You should generally use the tools in this rough order to build the agent.
1) Create system prompt tool: to create the system prompt for the agent.
2) Load in user-specified data (based on file paths they specify).
3) Decide whether or not to add additional tools.
4) Set parameters for the RAG pipeline.
5) Build the agent
This will be a back and forth conversation with the user. You should
continue asking users if there's anything else they want to do until
they say they're done. To help guide them on the process,
you can give suggestions on parameters they can set based on the tools they
have available (e.g. "Do you want to set the number of documents to retrieve?")
"""


### DEFINE Agent ####
# NOTE: here we define a function that is dependent on the LLM,
# please make sure to update the LLM above if you change the function below


def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
"""Get list of builder agent tools to pass to the builder agent."""
# see if metaphor api key is set, otherwise don't add web tool
# TODO: refactor this later

if "metaphor_key" in st.secrets:
fns: List[Callable] = [
agent_builder.create_system_prompt,
agent_builder.load_data,
agent_builder.add_web_tool,
agent_builder.get_rag_params,
agent_builder.set_rag_params,
agent_builder.create_agent,
]
else:
fns = [
agent_builder.create_system_prompt,
agent_builder.load_data,
agent_builder.get_rag_params,
agent_builder.set_rag_params,
agent_builder.create_agent,
]

fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
return fn_tools


# define agent
# @st.cache_resource
def load_meta_agent_and_tools(
cache: Optional[ParamCache] = None,
agent_registry: Optional[AgentCacheRegistry] = None,
) -> Tuple[BaseAgent, RAGAgentBuilder]:

# think of this as tools for the agent to use
agent_builder = RAGAgentBuilder(cache, agent_registry=agent_registry)

fn_tools = _get_builder_agent_tools(agent_builder)

builder_agent = load_meta_agent(
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
)

return builder_agent, agent_builder
Loading

0 comments on commit 4bec270

Please sign in to comment.