diff --git a/recipes/quickstart/inference/code_llama/code_completion_example.py b/recipes/quickstart/inference/code_llama/code_completion_example.py index 201f8df8b..c7bae4bb4 100644 --- a/recipes/quickstart/inference/code_llama/code_completion_example.py +++ b/recipes/quickstart/inference/code_llama/code_completion_example.py @@ -34,6 +34,7 @@ def main( enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5 enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard + enable_promptguard_safety: bool = False, use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels **kwargs ): @@ -62,6 +63,7 @@ def main( enable_sensitive_topics, enable_salesforce_content_safety, enable_llamaguard_content_safety, + enable_promptguard_safety, ) # Safety check of the user prompt diff --git a/recipes/quickstart/inference/code_llama/code_infilling_example.py b/recipes/quickstart/inference/code_llama/code_infilling_example.py index a955eb5ce..40d96018d 100644 --- a/recipes/quickstart/inference/code_llama/code_infilling_example.py +++ b/recipes/quickstart/inference/code_llama/code_infilling_example.py @@ -33,6 +33,7 @@ def main( enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5 enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard + enable_promptguard_safety: bool = False, use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels **kwargs ): @@ -62,6 +63,7 @@ def main( enable_sensitive_topics, enable_salesforce_content_safety, enable_llamaguard_content_safety, + enable_promptguard_safety ) # Safety check of the user prompt diff --git a/recipes/quickstart/inference/code_llama/code_instruct_example.py b/recipes/quickstart/inference/code_llama/code_instruct_example.py index d7b98f088..ae3a02d45 100644 --- a/recipes/quickstart/inference/code_llama/code_instruct_example.py +++ b/recipes/quickstart/inference/code_llama/code_instruct_example.py @@ -71,6 +71,7 @@ def main( enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5 enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard + enable_promptguard_safety: bool = False, use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels **kwargs ): @@ -95,6 +96,7 @@ def main( enable_sensitive_topics, enable_salesforce_content_safety, enable_llamaguard_content_safety, + enable_promptguard_safety ) # Safety check of the user prompt diff --git a/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py b/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py index ebfa2d663..60dbb3ee3 100644 --- a/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py +++ b/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py @@ -37,6 +37,7 @@ def main( enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5 use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels enable_llamaguard_content_safety: bool = False, + enable_promptguard_safety: bool = False, **kwargs ): if prompt_file is not None: @@ -81,6 +82,7 @@ def main( enable_sensitive_topics, enable_saleforce_content_safety, enable_llamaguard_content_safety, + enable_promptguard_safety ) # Safety check of the user prompt safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker] diff --git a/recipes/quickstart/inference/local_inference/inference.py b/recipes/quickstart/inference/local_inference/inference.py index bf2f824a9..581899c12 100644 --- a/recipes/quickstart/inference/local_inference/inference.py +++ b/recipes/quickstart/inference/local_inference/inference.py @@ -36,6 +36,7 @@ def main( enable_sensitive_topics: bool = False, # Enable check for sensitive topics using AuditNLG APIs enable_salesforce_content_safety: bool = True, # Enable safety check with Salesforce safety flan t5 enable_llamaguard_content_safety: bool = False, + enable_promptguard_safety: bool = False, max_padding_length: int = None, # the max padding length to be used with tokenizer padding the prompts. use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels share_gradio: bool = False, # Enable endpoint creation for gradio.live @@ -70,8 +71,10 @@ def inference( enable_sensitive_topics, enable_salesforce_content_safety, enable_llamaguard_content_safety, + enable_promptguard_safety, ) + # Safety check of the user prompt safety_results = [check(user_prompt) for check in safety_checker] are_safe = all([r[1] for r in safety_results]) @@ -119,9 +122,9 @@ def inference( # Safety check of the model output safety_results = [ - check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) - for check in safety_checker + check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker ] + print(safety_results) are_safe = all([r[1] for r in safety_results]) if are_safe: print("User input and model output deemed safe.") diff --git a/src/llama_recipes/inference/safety_utils.py b/src/llama_recipes/inference/safety_utils.py index f81a05a3a..bdebf9a1e 100644 --- a/src/llama_recipes/inference/safety_utils.py +++ b/src/llama_recipes/inference/safety_utils.py @@ -201,14 +201,55 @@ def __call__(self, output_text, **kwargs): report = result return "Llama Guard", is_safe, report - -# Function to load the PeftModel for performance optimization +class PromptGuardSafetyChecker(object): + + def __init__(self): + from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig + model_id = "meta-llama/Prompt-Guard-86M" + + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + self.model = AutoModelForSequenceClassification.from_pretrained(model_id) + + def get_scores(self, text, temperature=1.0, device='cpu'): + from torch.nn.functional import softmax + inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) + inputs = inputs.to(device) + if len(inputs[0]) > 512: + warnings.warn( + "Input length is > 512 token. PromptGuard check result could be incorrect." + ) + with torch.no_grad(): + logits = self.model(**inputs).logits + scaled_logits = logits / temperature + probabilities = softmax(scaled_logits, dim=-1) + + return { + 'jailbreak': probabilities[0, 2].item(), + 'indirect_injection': (probabilities[0, 1] + probabilities[0, 2]).item() + } + + def __call__(self, text_for_check, **kwargs): + agent_type = kwargs.get('agent_type', AgentType.USER) + if agent_type == AgentType.AGENT: + return "PromptGuard", True, "PromptGuard is not used for model output so checking not carried out" + sentences = text_for_check.split(".") + running_scores = {'jailbreak':0, 'indirect_injection' :0} + for sentence in sentences: + scores = self.get_scores(sentence) + running_scores['jailbreak'] = max([running_scores['jailbreak'],scores['jailbreak']]) + running_scores['indirect_injection'] = max([running_scores['indirect_injection'],scores['indirect_injection']]) + is_safe = True if running_scores['jailbreak'] < 0.5 else False + report = str(running_scores) + return "PromptGuard", is_safe, report + + # Function to determine which safety checker to use based on the options selected def get_safety_checker(enable_azure_content_safety, enable_sensitive_topics, enable_salesforce_content_safety, - enable_llamaguard_content_safety): + enable_llamaguard_content_safety, + enable_promptguard_safety): safety_checker = [] if enable_azure_content_safety: safety_checker.append(AzureSaftyChecker()) @@ -218,5 +259,7 @@ def get_safety_checker(enable_azure_content_safety, safety_checker.append(SalesforceSafetyChecker()) if enable_llamaguard_content_safety: safety_checker.append(LlamaGuardSafetyChecker()) + if enable_promptguard_safety: + safety_checker.append(PromptGuardSafetyChecker()) return safety_checker