feat: Make tokenizer add_special_tokens option configurable
In particular so that it can be disabled for chat/instruct models where an explicit template is used that already includes these tokens.

(for example the leading <s> token added by llama and mixtral tokenizers)

Signed-off-by: Nick Hill <[email protected]>
njhill committed Mar 11, 2024
1 parent f7d3c5f commit 2afbf1c
Showing 8 changed files with 35 additions and 10 deletions.
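As the commit message notes, chat/instruct templates often spell out the special tokens explicitly, so letting the tokenizer add them again yields a duplicated leading <s>. The following is a minimal illustrative sketch of that effect, not part of the change, assuming a Hugging Face Llama-style tokenizer is available and that it recognizes special tokens written literally in the prompt:

from transformers import AutoTokenizer

# Hypothetical checkpoint; any Llama/Mixtral-style tokenizer with a BOS token behaves the same way.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# A chat template that already bakes the special tokens into the prompt text.
prompt = "<s>[INST] Hello [/INST]"

with_special = tok(prompt, add_special_tokens=True)["input_ids"]
without_special = tok(prompt, add_special_tokens=False)["input_ids"]

# With add_special_tokens=True the tokenizer prepends its own BOS on top of the
# one supplied by the template, so the sequence starts with two <s> tokens;
# with add_special_tokens=False only the template's <s> remains.
print(with_special[:3])     # e.g. [1, 1, ...]
print(without_special[:3])  # e.g. [1, ...]

With this change the behaviour can be switched off at deploy time via the new --add-special-tokens launcher/router flag or the ADD_SPECIAL_TOKENS environment variable, both of which default to true.
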
launcher/src/main.rs: 9 additions & 0 deletions
@@ -89,6 +89,8 @@ struct Args {
// Default for default_include_stop_seqs is true for now, for backwards compatibility
#[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
default_include_stop_seqs: bool,
+ #[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
+ add_special_tokens: bool,
}

fn main() -> ExitCode {
@@ -237,6 +239,7 @@ fn main() -> ExitCode {
args.max_new_tokens,
args.max_batch_size,
args.batch_safety_margin,
+ args.add_special_tokens,
args.shard_uds_path,
args.cuda_process_memory_fraction,
cuda_alloc_conf,
@@ -307,6 +310,8 @@ fn main() -> ExitCode {
format!("{}-0", args.shard_uds_path),
"--tokenizer-path".to_string(),
tokenizer_path,
"--add-special-tokens".to_string(),
args.add_special_tokens.to_string(),
];

if let Some(path) = args.tls_key_path {
@@ -541,6 +546,7 @@ fn shard_manager(
max_new_tokens: usize,
max_batch_size: usize,
batch_safety_margin: usize,
+ add_special_tokens: bool,
uds_path: String,
cuda_process_memory_fraction: f32,
cuda_alloc_conf: Option<&str>,
@@ -627,6 +633,9 @@ fn shard_manager(
}
}

+ // Add special tokens when tokenizing (e.g. leading <s> with llama tokenizer)
+ env.push(("ADD_SPECIAL_TOKENS".into(), add_special_tokens.to_string().into()));

// Torch Distributed / DeepSpeed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("LOCAL_RANK".into(), rank.to_string().into()));
router/src/main.rs: 3 additions & 0 deletions
@@ -50,6 +50,8 @@ struct Args {
output_special_tokens: bool,
#[clap(long, env)]
default_include_stop_seqs: bool,
+ #[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
+ add_special_tokens: bool,
}

fn main() -> Result<(), std::io::Error> {
@@ -149,6 +151,7 @@ fn main() -> Result<(), std::io::Error> {
tls_client_ca_cert: args.tls_client_ca_cert_path,
output_special_tokens: args.output_special_tokens,
default_include_stop_seqs: args.default_include_stop_seqs,
+ add_special_tokens: args.add_special_tokens,
})
.await;
Ok(())
router/src/server.rs: 4 additions & 1 deletion
@@ -247,6 +247,7 @@ pub struct ServerRunArgs {
pub tls_client_ca_cert: Option<String>,
pub output_special_tokens: bool,
pub default_include_stop_seqs: bool,
+ pub add_special_tokens: bool,
}

async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
@@ -337,7 +338,9 @@ async fn do_run<B: BatchType>(
args.max_sequence_length - 1
};

- let tokenizers = AsyncTokenizer::new(&args.tokenizer, args.tokenization_workers);
+ let tokenizers = AsyncTokenizer::new(
+ &args.tokenizer, args.add_special_tokens, args.tokenization_workers
+ );

// Create state
let generation_health = Arc::new(AtomicBool::new(false));
router/src/tokenizer.rs: 8 additions & 4 deletions
@@ -23,12 +23,14 @@ impl Debug for AsyncTokenizer {

/// Uses pool of tokenizer threads to provide async tokenization methods
impl AsyncTokenizer {
- pub(crate) fn new(tokenizer: &Tokenizer, workers: usize) -> Self {
+ pub(crate) fn new(tokenizer: &Tokenizer, add_special_tokens: bool, workers: usize) -> Self {
let (sender, receiver) = flume::unbounded();
for _ in 0..workers {
let tokenizer = tokenizer.clone();
let receiver = receiver.clone();
- tokio::task::spawn_blocking(move || tokenization_worker(tokenizer, receiver));
+ tokio::task::spawn_blocking(
+ move || tokenization_worker(tokenizer, receiver, add_special_tokens)
+ );
}
Self { sender }
}
@@ -50,10 +52,12 @@ impl AsyncTokenizer {
}
}

- fn tokenization_worker(tokenizer: Tokenizer, receiver: Receiver<TokenizationRequest>) {
+ fn tokenization_worker(
+ tokenizer: Tokenizer, receiver: Receiver<TokenizationRequest>, add_special_tokens: bool
+ ) {
while let Ok((input, with_encoding, sender)) = receiver.recv() {
let result = tokenizer
- .encode(&input[..], true)
+ .encode(&input[..], add_special_tokens)
.map(|encoding| (input, encoding.len(), with_encoding.then_some(encoding)));
sender.send(result).unwrap_or_default();
}
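
For reference, the boolean that was previously hard-coded as true in encode(&input[..], true) is the tokenizers crate's add_special_tokens flag. A minimal sketch of the same knob via the library's Python bindings, assuming a locally exported tokenizer.json (illustrative only):

from tokenizers import Tokenizer

# Assumes a tokenizer.json exported from a Llama-style model is present in the working directory.
tok = Tokenizer.from_file("tokenizer.json")

# Same flag the worker now threads through to encode():
enc_default = tok.encode("Hello", add_special_tokens=True)   # post-processor may prepend <s>
enc_raw = tok.encode("Hello", add_special_tokens=False)      # tokens for the text only

print(enc_default.ids)
print(enc_raw.ids)
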
server/text_generation_server/models/causal_lm.py: 2 additions & 1 deletion
@@ -8,7 +8,7 @@
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Any

- from text_generation_server.models.model import Model, CUDA_PAD_TO_MULT_OF_8
+ from text_generation_server.models.model import Model, ADD_SPECIAL_TOKENS, CUDA_PAD_TO_MULT_OF_8
from text_generation_server.models.types import Batch, GenerateError
from text_generation_server.pb import generate_pb2
from text_generation_server.prompt_cache import PrefixCache
@@ -143,6 +143,7 @@ def from_pb(
truncation=True,
max_length=tokenize_length,
return_token_type_ids=False,
+ add_special_tokens=ADD_SPECIAL_TOKENS,
).to(device)
all_input_ids = tokenized_inputs["input_ids"]

server/text_generation_server/models/flash_causal_lm.py: 6 additions & 3 deletions
@@ -5,14 +5,13 @@
import torch
import torch.distributed

- from torch.nn import functional as F
-
from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Any

from text_generation_server.inference_engine import get_inference_engine_class
from text_generation_server.models import Model
+ from text_generation_server.models.model import ADD_SPECIAL_TOKENS

from text_generation_server.models.types import Batch, GenerateError
from text_generation_server.pb import generate_pb2
@@ -123,7 +122,11 @@ def from_pb(
# return as lists to avoid unnecessary padding;
# sequences will be concatenated across the batch
batch_tokenized_inputs = tokenizer(
- batch_inputs, truncation=True, max_length=max_seqlen, return_token_type_ids=False
+ batch_inputs,
+ truncation=True,
+ max_length=max_seqlen,
+ return_token_type_ids=False,
+ add_special_tokens=ADD_SPECIAL_TOKENS,
)["input_ids"]

# Process inputs to generate the needed tensors
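
To make the shape of batch_tokenized_inputs concrete, here is an illustrative sketch of the call above outside the server, with "meta-llama/Llama-2-7b-hf" and max_length=512 standing in for the model's tokenizer and max_seqlen:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

batch_inputs = ["tell me a joke", "what is the capital of France?"]

# No padding is requested, so input_ids comes back as a ragged list of lists that
# can be concatenated across the batch; the ADD_SPECIAL_TOKENS setting decides
# whether each sequence gets a leading <s>.
batch_tokenized_inputs = tokenizer(
    batch_inputs,
    truncation=True,
    max_length=512,
    return_token_type_ids=False,
    add_special_tokens=False,
)["input_ids"]

print(batch_tokenized_inputs)
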
server/text_generation_server/models/model.py: 1 addition & 0 deletions
@@ -23,6 +23,7 @@

CUDA_PAD_TO_MULT_OF_8 = os.getenv("CUDA_PAD_TO_MULT_OF_8", "true").lower() != "false"
PT2_COMPILE = os.getenv("PT2_COMPILE", "false").lower() != "false"
+ ADD_SPECIAL_TOKENS = os.getenv("ADD_SPECIAL_TOKENS", "true").lower() != "false" # defaults to true

if PT2_COMPILE:
import torch._dynamo
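
A note on the parsing convention used for ADD_SPECIAL_TOKENS above: any value other than the literal string "false", in any casing, counts as enabled. A small illustrative sketch (env_flag is a hypothetical helper, not part of the codebase):

import os

def env_flag(name: str, default: str = "true") -> bool:
    # Same convention as ADD_SPECIAL_TOKENS above: only the literal string
    # "false" (case-insensitive) disables the flag.
    return os.getenv(name, default).lower() != "false"

os.environ["ADD_SPECIAL_TOKENS"] = "False"
assert env_flag("ADD_SPECIAL_TOKENS") is False

os.environ["ADD_SPECIAL_TOKENS"] = "0"  # note: "0" still counts as enabled under this convention
assert env_flag("ADD_SPECIAL_TOKENS") is True

del os.environ["ADD_SPECIAL_TOKENS"]
assert env_flag("ADD_SPECIAL_TOKENS") is True  # falls back to the default
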
server/text_generation_server/models/seq2seq_lm.py: 2 additions & 1 deletion
@@ -10,7 +10,7 @@

from transformers.modeling_outputs import BaseModelOutput

- from text_generation_server.models.model import Model, CUDA_PAD_TO_MULT_OF_8, PT2_COMPILE
+ from text_generation_server.models.model import Model, ADD_SPECIAL_TOKENS, CUDA_PAD_TO_MULT_OF_8, PT2_COMPILE
from text_generation_server.models.types import Batch, GenerateError
from text_generation_server.pb import generate_pb2
from text_generation_server.prompt_cache import PrefixCache
@@ -148,6 +148,7 @@ def from_pb(
truncation=True,
max_length=tokenize_length,
return_token_type_ids=False,
+ add_special_tokens=ADD_SPECIAL_TOKENS,
).to(device)
input_ids = tokenized_inputs["input_ids"]
attention_mask = tokenized_inputs["attention_mask"]