diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index d063dc1..cd9684b 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -286,6 +286,13 @@
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 |                                | "memory": "meta-reference-faiss"      |                                                                      |
 |                                | }                                     |                                                                      |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-tgi-inference       | {                                     | Use TGI (local or with [Hugging Face Inference Endpoints](https://  |
+|                                | "inference": "remote::tgi",           | huggingface.co/inference-endpoints/dedicated)) for running LLM      |
+|                                | "safety": "meta-reference",           | inference. When using HF Inference Endpoints, you must provide the  |
+|                                | "agentic_system": "meta-reference",   | name of the endpoint.                                                |
+|                                | "memory": "meta-reference-faiss"      |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 
 As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference while local-ollama relies on a different provider (Ollama) for inference. Similarly, you can use Fireworks or Together.AI for running inference as well.
diff --git a/llama_toolchain/core/distribution_registry.py b/llama_toolchain/core/distribution_registry.py
index 75c43ed..a69529b 100644
--- a/llama_toolchain/core/distribution_registry.py
+++ b/llama_toolchain/core/distribution_registry.py
@@ -65,11 +65,23 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.telemetry: "console",
             },
         ),
+        DistributionSpec(
+            distribution_type="local-plus-tgi-inference",
+            description="Use TGI for running LLM inference",
+            providers={
+                Api.inference: remote_provider_type("tgi"),
+                Api.safety: "meta-reference",
+                Api.agentic_system: "meta-reference",
+                Api.memory: "meta-reference-faiss",
+            },
+        ),
     ]
 
 
 @lru_cache()
-def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(
+    distribution_type: str,
+) -> Optional[DistributionSpec]:
     for spec in available_distribution_specs():
         if spec.distribution_type == distribution_type:
             return spec
diff --git a/llama_toolchain/inference/adapters/tgi/__init__.py b/llama_toolchain/inference/adapters/tgi/__init__.py
index 4940667..7438078 100644
--- a/llama_toolchain/inference/adapters/tgi/__init__.py
+++ b/llama_toolchain/inference/adapters/tgi/__init__.py
@@ -4,12 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_toolchain.core.datatypes import RemoteProviderConfig
+from .config import TGIImplConfig
+from .tgi import InferenceEndpointAdapter, TGIAdapter
 
 
-async def get_adapter_impl(config: RemoteProviderConfig, _deps):
-    from .tgi import TGIInferenceAdapter
+async def get_adapter_impl(config: TGIImplConfig, _deps):
+    assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}"
+
+    if config.url is not None:
+        impl = TGIAdapter(config)
+    elif config.is_inference_endpoint():
+        impl = InferenceEndpointAdapter(config)
+    else:
+        raise ValueError(
+            "Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
+        )
 
-    impl = TGIInferenceAdapter(config.url)
     await impl.initialize()
     return impl
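For orientation, here is a minimal sketch (not part of the patch) of how the new distribution spec registered above can be looked up. It only assumes that `resolve_distribution_spec` is importable exactly as shown in the diff; the expected provider string comes from the docs table.

```python
# Illustrative only: resolve the new distribution and inspect its providers.
from llama_toolchain.core.distribution_registry import resolve_distribution_spec

spec = resolve_distribution_spec("local-plus-tgi-inference")
assert spec is not None
# Per the docs table above, inference should map to the remote TGI provider
# ("remote::tgi") while safety/agentic_system/memory stay on meta-reference.
print(spec.providers)
```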
diff --git a/llama_toolchain/inference/adapters/tgi/config.py b/llama_toolchain/inference/adapters/tgi/config.py
new file mode 100644
index 0000000..a0135df
--- /dev/null
+++ b/llama_toolchain/inference/adapters/tgi/config.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class TGIImplConfig(BaseModel):
+    url: Optional[str] = Field(
+        default=None,
+        description="The URL for the local TGI endpoint (e.g., http://localhost:8080)",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="The HF token for Hugging Face Inference Endpoints (will default to locally saved token if not provided)",
+    )
+    hf_endpoint_name: Optional[str] = Field(
+        default=None,
+        description="The name of the Hugging Face Inference Endpoint : can be either in the format of '{namespace}/{endpoint_name}' (namespace can be the username or organization name) or just '{endpoint_name}' if logged into the same account as the namespace",
+    )
+
+    def is_inference_endpoint(self) -> bool:
+        return self.hf_endpoint_name is not None
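To make the two configuration modes concrete, a small hypothetical sketch of `TGIImplConfig` usage follows; the endpoint name and token are placeholders, not values from this patch. With `url` set, `get_adapter_impl` (in `__init__.py` above) selects `TGIAdapter`; with `hf_endpoint_name` set, it selects `InferenceEndpointAdapter`.

```python
from llama_toolchain.inference.adapters.tgi.config import TGIImplConfig

# Mode 1: a locally running TGI server -> get_adapter_impl() returns a TGIAdapter.
local_config = TGIImplConfig(url="http://localhost:8080")
assert not local_config.is_inference_endpoint()

# Mode 2: a Hugging Face Inference Endpoint -> get_adapter_impl() returns an
# InferenceEndpointAdapter. The name may be "namespace/endpoint_name" or just
# "endpoint_name" (the namespace is then resolved via HfApi().whoami()).
endpoint_config = TGIImplConfig(
    hf_endpoint_name="my-org/my-llama-endpoint",  # hypothetical endpoint name
    api_token="hf_xxx",  # placeholder; defaults to the locally saved HF token if omitted
)
assert endpoint_config.is_inference_endpoint()
```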
diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_toolchain/inference/adapters/tgi/tgi.py
index 7eb36ac..bb7b99d 100644
--- a/llama_toolchain/inference/adapters/tgi/tgi.py
+++ b/llama_toolchain/inference/adapters/tgi/tgi.py
@@ -4,63 +4,68 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List
-import httpx
+from typing import Any, AsyncGenerator, Dict
 
-from llama_models.llama3.api.chat_format import ChatFormat
+import requests
 
-from llama_models.llama3.api.datatypes import Message, StopReason
+from huggingface_hub import HfApi, InferenceClient
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-
-from text_generation import Client
-
 from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.inference.prepare_messages import prepare_messages
 
+from .config import TGIImplConfig
 
-SUPPORTED_MODELS = {
+HF_SUPPORTED_MODELS = {
     "Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
     "Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
     "Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
 }
 
 
-class TGIInferenceAdapter(Inference):
-    def __init__(self, url: str) -> None:
-        self.url = url.rstrip("/")
+class TGIAdapter(Inference):
+    def __init__(self, config: TGIImplConfig) -> None:
+        self.config = config
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
-        self.model = None
-        self.max_tokens = None
 
-    async def initialize(self) -> None:
-        hf_models = {v: k for k, v in SUPPORTED_MODELS.items()}
+    @property
+    def client(self) -> InferenceClient:
+        return InferenceClient(model=self.config.url, token=self.config.api_token)
 
-        try:
-            print(f"Connecting to TGI server at: {self.url}")
-            async with httpx.AsyncClient() as client:
-                response = await client.get(f"{self.url}/info")
-                response.raise_for_status()
-                info = response.json()
-                if "model_id" not in info:
-                    raise RuntimeError("Missing model_id in model info")
-                if "max_total_tokens" not in info:
-                    raise RuntimeError("Missing max_total_tokens in model info")
-                self.max_tokens = info["max_total_tokens"]
-
-                model_id = info["model_id"]
-                if model_id not in hf_models:
-                    raise RuntimeError(
-                        f"TGI is serving model: {model_id}, use one of the supported models: {','.join(hf_models.keys())}"
-                    )
+    def _get_endpoint_info(self) -> Dict[str, Any]:
+        return {
+            **self.client.get_endpoint_info(),
+            "inference_url": self.config.url,
+        }
 
-                self.model = hf_models[model_id]
+    async def initialize(self) -> None:
+        try:
+            info = self._get_endpoint_info()
+            if "model_id" not in info:
+                raise RuntimeError("Missing model_id in model info")
+            if "max_total_tokens" not in info:
+                raise RuntimeError("Missing max_total_tokens in model info")
+            self.max_tokens = info["max_total_tokens"]
+
+            model_id = info["model_id"]
+            model_name = next(
+                (name for name, id in HF_SUPPORTED_MODELS.items() if id == model_id),
+                None,
+            )
+            if model_name is None:
+                raise RuntimeError(
+                    f"TGI is serving model: {model_id}, use one of the supported models: {', '.join(HF_SUPPORTED_MODELS.values())}"
+                )
+            self.model_name = model_name
+            self.inference_url = info["inference_url"]
         except Exception as e:
             import traceback
 
             traceback.print_exc()
-            raise RuntimeError("Could not connect to TGI server") from e
+            raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e
 
     async def shutdown(self) -> None:
         pass
@@ -68,16 +73,6 @@ async def shutdown(self) -> None:
     async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def _convert_messages(self, messages: List[Message]) -> List[Message]:
-        ret = []
-        for message in messages:
-            if message.role == "ipython":
-                role = "tool"
-            else:
-                role = message.role
-            ret.append({"role": role, "content": message.content})
-        return ret
-
     def get_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -89,47 +84,47 @@ def get_chat_options(self, request: ChatCompletionRequest) -> dict:
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         messages = prepare_messages(request)
-
         model_input = self.formatter.encode_dialog_prompt(messages)
         prompt = self.tokenizer.decode(model_input.tokens)
+
+        input_tokens = len(model_input.tokens)
         max_new_tokens = min(
-            request.sampling_params.max_tokens or self.max_tokens,
-            self.max_tokens - len(model_input.tokens) - 1,
+            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
+            self.max_tokens - input_tokens - 1,
        )
 
-        if request.model != self.model:
-            raise ValueError(
-                f"Model mismatch, expected: {self.model}, got: {request.model}"
-            )
+        print(f"Calculated max_new_tokens: {max_new_tokens}")
 
-        options = self.get_chat_options(request)
+        assert (
+            request.model == self.model_name
+        ), f"Model mismatch, expected {self.model_name}, got {request.model}"
 
-        client = Client(base_url=self.url)
+        options = self.get_chat_options(request)
 
         if not request.stream:
-            r = client.generate(
-                prompt,
+            response = self.client.text_generation(
+                prompt=prompt,
+                stream=False,
+                details=True,
                 max_new_tokens=max_new_tokens,
                 stop_sequences=["<|eom_id|>", "<|eot_id|>"],
                 **options,
             )
-
-            if r.details.finish_reason:
-                if r.details.finish_reason == "stop":
+            stop_reason = None
+            if response.details.finish_reason:
+                if response.details.finish_reason == "stop":
                     stop_reason = StopReason.end_of_turn
-                elif r.details.finish_reason == "length":
+                elif response.details.finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
-                else:
-                    stop_reason = StopReason.end_of_message
-            else:
-                stop_reason = StopReason.out_of_tokens
 
             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.generated_text, stop_reason
+                response.generated_text,
+                stop_reason,
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
                 logprobs=None,
             )
+
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -137,14 +132,15 @@ async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
                     delta="",
                 )
             )
-
             buffer = ""
             ipython = False
             stop_reason = None
             tokens = []
 
-            for response in client.generate_stream(
-                prompt,
+            for response in self.client.text_generation(
+                prompt=prompt,
+                stream=True,
+                details=True,
                 max_new_tokens=max_new_tokens,
                 stop_sequences=["<|eom_id|>", "<|eot_id|>"],
                 **options,
@@ -231,3 +227,48 @@ async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
                         stop_reason=stop_reason,
                     )
                 )
+
+
+class InferenceEndpointAdapter(TGIAdapter):
+    def __init__(self, config: TGIImplConfig) -> None:
+        super().__init__(config)
+        self.config.url = self._construct_endpoint_url()
+
+    def _construct_endpoint_url(self) -> str:
+        hf_endpoint_name = self.config.hf_endpoint_name
+        assert hf_endpoint_name.count("/") <= 1, (
+            "Endpoint name must be in the format of 'namespace/endpoint_name' "
+            "or 'endpoint_name'"
+        )
+        if "/" not in hf_endpoint_name:
+            hf_namespace: str = self.get_namespace()
+            endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
+        else:
+            endpoint_path = hf_endpoint_name
+        return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
+
+    def get_namespace(self) -> str:
+        return HfApi().whoami()["name"]
+
+    @property
+    def client(self) -> InferenceClient:
+        return InferenceClient(model=self.inference_url, token=self.config.api_token)
+
+    def _get_endpoint_info(self) -> Dict[str, Any]:
+        headers = {
+            "accept": "application/json",
+            "authorization": f"Bearer {self.config.api_token}",
+        }
+        response = requests.get(self.config.url, headers=headers)
+        response.raise_for_status()
+        endpoint_info = response.json()
+        return {
+            "inference_url": endpoint_info["status"]["url"],
+            "model_id": endpoint_info["model"]["repository"],
+            "max_total_tokens": int(
+                endpoint_info["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
+            ),
+        }
+
+    async def initialize(self) -> None:
+        await super().initialize()
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index 6e80ecb..31c6b8d 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -39,8 +39,9 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_id="tgi",
-                pip_packages=["text-generation"],
+                pip_packages=["huggingface_hub"],
                 module="llama_toolchain.inference.adapters.tgi",
+                config_class="llama_toolchain.inference.adapters.tgi.TGIImplConfig",
            ),
        ),
        remote_provider_spec(
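As a worked illustration of the token budgeting done in `TGIAdapter.chat_completion` above (the numbers below are hypothetical, not taken from the patch): generation is capped to the context remaining after the prompt, minus one token.

```python
# Mirrors the max_new_tokens calculation in chat_completion(), with made-up numbers.
max_total_tokens = 4096  # "max_total_tokens" reported by the server/endpoint metadata
input_tokens = 1000      # len(model_input.tokens) for the encoded dialog prompt
requested_max = None     # request.sampling_params.max_tokens (None -> use whatever is left)

max_new_tokens = min(
    requested_max or (max_total_tokens - input_tokens),
    max_total_tokens - input_tokens - 1,
)
assert max_new_tokens == 3095  # always leaves at least one token of headroom
```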