diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 84742fb9..c43758df 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -89,6 +89,8 @@ struct Args {
     // Default for default_include_stop_seqs is true for now, for backwards compatibility
     #[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
     default_include_stop_seqs: bool,
+    #[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
+    add_special_tokens: bool,
 }
 
 fn main() -> ExitCode {
@@ -237,6 +239,7 @@ fn main() -> ExitCode {
         args.max_new_tokens,
         args.max_batch_size,
         args.batch_safety_margin,
+        args.add_special_tokens,
         args.shard_uds_path,
         args.cuda_process_memory_fraction,
         cuda_alloc_conf,
@@ -307,6 +310,8 @@ fn main() -> ExitCode {
         format!("{}-0", args.shard_uds_path),
         "--tokenizer-path".to_string(),
         tokenizer_path,
+        "--add-special-tokens".to_string(),
+        args.add_special_tokens.to_string(),
     ];
 
     if let Some(path) = args.tls_key_path {
@@ -541,6 +546,7 @@ fn shard_manager(
     max_new_tokens: usize,
     max_batch_size: usize,
     batch_safety_margin: usize,
+    add_special_tokens: bool,
     uds_path: String,
     cuda_process_memory_fraction: f32,
     cuda_alloc_conf: Option<&str>,
@@ -627,6 +633,9 @@ fn shard_manager(
         }
     }
 
+    // Add special tokens when tokenizing (e.g. leading <s> with llama tokenizer)
+    env.push(("ADD_SPECIAL_TOKENS".into(), add_special_tokens.to_string().into()));
+
     // Torch Distributed / DeepSpeed Env vars
     env.push(("RANK".into(), rank.to_string().into()));
     env.push(("LOCAL_RANK".into(), rank.to_string().into()));
diff --git a/router/src/main.rs b/router/src/main.rs
index cd97d8c8..108c0def 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -50,6 +50,8 @@ struct Args {
     output_special_tokens: bool,
     #[clap(long, env)]
     default_include_stop_seqs: bool,
+    #[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
+    add_special_tokens: bool,
 }
 
 fn main() -> Result<(), std::io::Error> {
@@ -149,6 +151,7 @@ fn main() -> Result<(), std::io::Error> {
         tls_client_ca_cert: args.tls_client_ca_cert_path,
         output_special_tokens: args.output_special_tokens,
         default_include_stop_seqs: args.default_include_stop_seqs,
+        add_special_tokens: args.add_special_tokens,
     })
     .await;
     Ok(())
diff --git a/router/src/server.rs b/router/src/server.rs
index 20cd5250..de5d198a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -247,6 +247,7 @@ pub struct ServerRunArgs {
     pub tls_client_ca_cert: Option,
     pub output_special_tokens: bool,
     pub default_include_stop_seqs: bool,
+    pub add_special_tokens: bool,
 }
 
 async fn metrics(prom_handle: Extension) -> String {
@@ -337,7 +338,9 @@ async fn do_run(
         args.max_sequence_length - 1
     };
 
-    let tokenizers = AsyncTokenizer::new(&args.tokenizer, args.tokenization_workers);
+    let tokenizers = AsyncTokenizer::new(
+        &args.tokenizer, args.add_special_tokens, args.tokenization_workers
+    );
 
     // Create state
     let generation_health = Arc::new(AtomicBool::new(false));
diff --git a/router/src/tokenizer.rs b/router/src/tokenizer.rs
index 57f8d115..26c7f90e 100644
--- a/router/src/tokenizer.rs
+++ b/router/src/tokenizer.rs
@@ -23,12 +23,14 @@ impl Debug for AsyncTokenizer {
 
 /// Uses pool of tokenizer threads to provide async tokenization methods
 impl AsyncTokenizer {
-    pub(crate) fn new(tokenizer: &Tokenizer, workers: usize) -> Self {
+    pub(crate) fn new(tokenizer: &Tokenizer, add_special_tokens: bool, workers: usize) -> Self {
         let (sender, receiver) = flume::unbounded();
         for _ in 0..workers {
             let tokenizer = tokenizer.clone();
             let receiver = receiver.clone();
-            tokio::task::spawn_blocking(move || tokenization_worker(tokenizer, receiver));
+            tokio::task::spawn_blocking(
+                move || tokenization_worker(tokenizer, receiver, add_special_tokens)
+            );
         }
         Self { sender }
     }
@@ -50,10 +52,12 @@ impl AsyncTokenizer {
     }
 }
 
-fn tokenization_worker(tokenizer: Tokenizer, receiver: Receiver) {
+fn tokenization_worker(
+    tokenizer: Tokenizer, receiver: Receiver, add_special_tokens: bool
+) {
     while let Ok((input, with_encoding, sender)) = receiver.recv() {
         let result = tokenizer
-            .encode(&input[..], true)
+            .encode(&input[..], add_special_tokens)
             .map(|encoding| (input, encoding.len(), with_encoding.then_some(encoding)));
         sender.send(result).unwrap_or_default();
     }
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index fb5f0dfd..c139eaa3 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -8,7 +8,7 @@
 from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Any
 
-from text_generation_server.models.model import Model, CUDA_PAD_TO_MULT_OF_8
+from text_generation_server.models.model import Model, ADD_SPECIAL_TOKENS, CUDA_PAD_TO_MULT_OF_8
 from text_generation_server.models.types import Batch, GenerateError
 from text_generation_server.pb import generate_pb2
 from text_generation_server.prompt_cache import PrefixCache
@@ -143,6 +143,7 @@ def from_pb(
             truncation=True,
             max_length=tokenize_length,
             return_token_type_ids=False,
+            add_special_tokens=ADD_SPECIAL_TOKENS,
         ).to(device)
 
         all_input_ids = tokenized_inputs["input_ids"]
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 02ac18eb..9a5698e1 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -5,14 +5,13 @@
 
 import torch
 import torch.distributed
 
-from torch.nn import functional as F
-
 from dataclasses import dataclass
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Any
 
 from text_generation_server.inference_engine import get_inference_engine_class
 from text_generation_server.models import Model
+from text_generation_server.models.model import ADD_SPECIAL_TOKENS
 from text_generation_server.models.types import Batch, GenerateError
 from text_generation_server.pb import generate_pb2
@@ -123,7 +122,11 @@ def from_pb(
         # return as lists to avoid unnecessary padding;
         # sequences will be concatenated across the batch
         batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_seqlen, return_token_type_ids=False
+            batch_inputs,
+            truncation=True,
+            max_length=max_seqlen,
+            return_token_type_ids=False,
+            add_special_tokens=ADD_SPECIAL_TOKENS,
         )["input_ids"]
 
         # Process inputs to generate the needed tensors
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index e27559ef..b7a52ab8 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -23,6 +23,7 @@
 
 CUDA_PAD_TO_MULT_OF_8 = os.getenv("CUDA_PAD_TO_MULT_OF_8", "true").lower() != "false"
 PT2_COMPILE = os.getenv("PT2_COMPILE", "false").lower() != "false"
+ADD_SPECIAL_TOKENS = os.getenv("ADD_SPECIAL_TOKENS", "true").lower() != "false"  # defaults to true
 
 if PT2_COMPILE:
     import torch._dynamo
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 6a17eff5..a43ff972 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -10,7 +10,7 @@
 
 from transformers.modeling_outputs import BaseModelOutput
 
-from text_generation_server.models.model import Model, CUDA_PAD_TO_MULT_OF_8, PT2_COMPILE
+from text_generation_server.models.model import Model, ADD_SPECIAL_TOKENS, CUDA_PAD_TO_MULT_OF_8, PT2_COMPILE
 from text_generation_server.models.types import Batch, GenerateError
 from text_generation_server.pb import generate_pb2
 from text_generation_server.prompt_cache import PrefixCache
@@ -148,6 +148,7 @@ def from_pb(
             truncation=True,
             max_length=tokenize_length,
             return_token_type_ids=False,
+            add_special_tokens=ADD_SPECIAL_TOKENS,
         ).to(device)
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
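
Note (not part of the diff): a minimal Python sketch of what the new flag controls on the serving side, assuming a llama-style tokenizer; the checkpoint name below is illustrative only.

# Reference sketch -- mirrors how ADD_SPECIAL_TOKENS is read in
# server/text_generation_server/models/model.py and forwarded to the tokenizer call.
import os
from transformers import AutoTokenizer

# Defaults to true; anything other than an explicit "false" enables special tokens.
ADD_SPECIAL_TOKENS = os.getenv("ADD_SPECIAL_TOKENS", "true").lower() != "false"

# Illustrative checkpoint with a llama-style tokenizer that prepends a BOS token.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

with_bos = tokenizer("hello world", add_special_tokens=True)["input_ids"]
without_bos = tokenizer("hello world", add_special_tokens=False)["input_ids"]

# With the default (--add-special-tokens true / ADD_SPECIAL_TOKENS=true) prompts are
# encoded like with_bos, i.e. prefixed with tokenizer.bos_token_id; passing false to
# the launcher propagates through the router and shard env and yields without_bos.
print(with_bos, without_bos)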