Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PromptGuard to safety_utils #608

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def main(
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
enable_promptguard_safety: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the small size of the mode, can we leave it as default true?

use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
**kwargs
):
Expand Down Expand Up @@ -62,6 +63,7 @@ def main(
enable_sensitive_topics,
enable_salesforce_content_safety,
enable_llamaguard_content_safety,
enable_promptguard_safety,
)

# Safety check of the user prompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def main(
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
enable_promptguard_safety: bool = False,
use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
**kwargs
):
Expand Down Expand Up @@ -62,6 +63,7 @@ def main(
enable_sensitive_topics,
enable_salesforce_content_safety,
enable_llamaguard_content_safety,
enable_promptguard_safety
)

# Safety check of the user prompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def main(
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
enable_promptguard_safety: bool = False,
use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
**kwargs
):
Expand All @@ -95,6 +96,7 @@ def main(
enable_sensitive_topics,
enable_salesforce_content_safety,
enable_llamaguard_content_safety,
enable_promptguard_safety
)

# Safety check of the user prompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def main(
enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
enable_llamaguard_content_safety: bool = False,
enable_promptguard_safety: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

specially for this call, can we set it to true?

**kwargs
):
if prompt_file is not None:
Expand Down Expand Up @@ -81,6 +82,7 @@ def main(
enable_sensitive_topics,
enable_saleforce_content_safety,
enable_llamaguard_content_safety,
enable_promptguard_safety
)
# Safety check of the user prompt
safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker]
Expand Down
7 changes: 5 additions & 2 deletions recipes/quickstart/inference/local_inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def main(
enable_sensitive_topics: bool = False, # Enable check for sensitive topics using AuditNLG APIs
enable_salesforce_content_safety: bool = True, # Enable safety check with Salesforce safety flan t5
enable_llamaguard_content_safety: bool = False,
enable_promptguard_safety: bool = False,
max_padding_length: int = None, # the max padding length to be used with tokenizer padding the prompts.
use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
share_gradio: bool = False, # Enable endpoint creation for gradio.live
Expand Down Expand Up @@ -70,8 +71,10 @@ def inference(
enable_sensitive_topics,
enable_salesforce_content_safety,
enable_llamaguard_content_safety,
enable_promptguard_safety,
)


# Safety check of the user prompt
safety_results = [check(user_prompt) for check in safety_checker]
are_safe = all([r[1] for r in safety_results])
Expand Down Expand Up @@ -119,9 +122,9 @@ def inference(

# Safety check of the model output
safety_results = [
check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt)
for check in safety_checker
check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker
]
print(safety_results)
are_safe = all([r[1] for r in safety_results])
if are_safe:
print("User input and model output deemed safe.")
Expand Down
49 changes: 46 additions & 3 deletions src/llama_recipes/inference/safety_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,55 @@ def __call__(self, output_text, **kwargs):
report = result

return "Llama Guard", is_safe, report


# Function to load the PeftModel for performance optimization
class PromptGuardSafetyChecker(object):

def __init__(self):
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
model_id = "meta-llama/Prompt-Guard-86M"

self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForSequenceClassification.from_pretrained(model_id)

def get_scores(self, text, temperature=1.0, device='cpu'):
from torch.nn.functional import softmax
inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
inputs = inputs.to(device)
if len(inputs[0]) > 512:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As max_length is 512 this condition should never be true. Can we instead follow the PromptGuard recommendation and split the text into multiple segments which we apply in parallel (batched)? Especially because of the much bigger context length of Llama 3.1.

warnings.warn(
"Input length is > 512 token. PromptGuard check result could be incorrect."
)
with torch.no_grad():
logits = self.model(**inputs).logits
scaled_logits = logits / temperature
probabilities = softmax(scaled_logits, dim=-1)

return {
'jailbreak': probabilities[0, 2].item(),
'indirect_injection': (probabilities[0, 1] + probabilities[0, 2]).item()
}

def __call__(self, text_for_check, **kwargs):
agent_type = kwargs.get('agent_type', AgentType.USER)
if agent_type == AgentType.AGENT:
return "PromptGuard", True, "PromptGuard is not used for model output so checking not carried out"
sentences = text_for_check.split(".")
running_scores = {'jailbreak':0, 'indirect_injection' :0}
for sentence in sentences:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its probably more efficient to do this batched as commented above. Lets split the prompt in blocks of 512 (with some overlap) and then feed them batched into the model which will be way more efficient than feeding the sentences one by one.

scores = self.get_scores(sentence)
running_scores['jailbreak'] = max([running_scores['jailbreak'],scores['jailbreak']])
running_scores['indirect_injection'] = max([running_scores['indirect_injection'],scores['indirect_injection']])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we comment that this is not being used for the user dialog?

is_safe = True if running_scores['jailbreak'] < 0.5 else False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set the bar at 0.5? I think 0.8 or 0.9 would be better based on talks with the team.

report = str(running_scores)
return "PromptGuard", is_safe, report


# Function to determine which safety checker to use based on the options selected
def get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics,
enable_salesforce_content_safety,
enable_llamaguard_content_safety):
enable_llamaguard_content_safety,
enable_promptguard_safety):
safety_checker = []
if enable_azure_content_safety:
safety_checker.append(AzureSaftyChecker())
Expand All @@ -218,5 +259,7 @@ def get_safety_checker(enable_azure_content_safety,
safety_checker.append(SalesforceSafetyChecker())
if enable_llamaguard_content_safety:
safety_checker.append(LlamaGuardSafetyChecker())
if enable_promptguard_safety:
safety_checker.append(PromptGuardSafetyChecker())
return safety_checker

Loading