
Commit

Fix NaN bug in normalization for fp16
akhilg-nv committed Sep 12, 2024
1 parent 9602126 commit 5142a42
Showing 7 changed files with 84 additions and 77 deletions.
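
The changes in these files follow one pattern: normalization layers (tp.LayerNorm, tp.GroupNorm) are always constructed in float32, even when the rest of the model runs in float16, and activations are cast up to the norm's dtype on the way in and back to their original dtype on the way out, so the mean/variance reduction never runs in half precision where it can overflow and turn into NaN. A minimal sketch of that pattern, using only the tripy APIs that appear in this diff (the FP32LayerNorm wrapper and its model_dtype argument are illustrative; the commit inlines the casts at each call site instead):

```python
import tripy as tp

class FP32LayerNorm(tp.Module):
    # Illustrative wrapper, not part of this commit: normalize in float32,
    # then return the result in the model's compute dtype.
    def __init__(self, dim, model_dtype=tp.float16):
        self.norm = tp.LayerNorm(dim, dtype=tp.float32)  # statistics always in fp32
        self.model_dtype = model_dtype

    def __call__(self, x):
        # Same shape as tp.cast(self.layer_norm1(tp.cast(hidden_states, ...)), hidden_states.dtype)
        return tp.cast(self.norm(tp.cast(x, self.norm.dtype)), self.model_dtype)
```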
17 changes: 9 additions & 8 deletions tripy/examples/diffusion/clip_model.py
@@ -29,7 +29,7 @@ class CLIPConfig:
num_heads: int = 12
max_seq_len: int = 77
num_hidden_layers: int = 12
dtype: tp.dtype = tp.float16
dtype: tp.dtype = tp.float32

class CLIPMLP(tp.Module):
def __init__(self, config: CLIPConfig):
@@ -52,6 +52,7 @@ def __init__(self, config: CLIPConfig):
self.v_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
self.q_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
self.out_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
self.dtype = config.dtype

def __call__(self, hidden_states, causal_attention_mask):
bsz, tgt_len, embed_dim = hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[2]
@@ -65,7 +66,7 @@ def __call__(self, hidden_states, causal_attention_mask):
for x in (q, k, v)
]
attn_output = scaled_dot_product_attention(
q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask
q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask, dtype=self.dtype,
)
out = self.out_proj(tp.reshape(tp.transpose(attn_output, 1, 2), (bsz, tgt_len, embed_dim)))
return out
@@ -74,18 +75,18 @@ def __call__(self, hidden_states, causal_attention_mask):
class CLIPEncoderLayer(tp.Module):
def __init__(self, config: CLIPConfig):
self.self_attn = CLIPAttention(config)
self.layer_norm1 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
self.layer_norm1 = tp.LayerNorm(config.embedding_size, dtype=tp.float32)
self.mlp = CLIPMLP(config)
self.layer_norm2 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
self.layer_norm2 = tp.LayerNorm(config.embedding_size, dtype=tp.float32)

def __call__(self, hidden_states, causal_attention_mask):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = tp.cast(self.layer_norm1(tp.cast(hidden_states, self.layer_norm1.dtype)), hidden_states.dtype)
hidden_states = self.self_attn(hidden_states, causal_attention_mask)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = tp.cast(self.layer_norm2(tp.cast(hidden_states, self.layer_norm2.dtype)), hidden_states.dtype)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

@@ -115,10 +116,10 @@ class CLIPTextTransformer(tp.Module):
def __init__(self, config: CLIPConfig):
self.embeddings = CLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
self.final_layer_norm = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
self.final_layer_norm = tp.LayerNorm(config.embedding_size, dtype=tp.float32)
self.max_seq_len = config.max_seq_len

def __call__(self, input_ids):
x = self.embeddings(input_ids, tp.reshape(tp.iota((input_ids.shape[1],), dtype=tp.int32), (1, -1)))
x = self.encoder(x, tp.triu(tp.full((1, 1, self.max_seq_len, self.max_seq_len), float("-inf")), 1))
return self.final_layer_norm(x)
return tp.cast(self.final_layer_norm(tp.cast(x, self.final_layer_norm.dtype)), x.dtype)
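
In the CLIP encoder above, layer_norm1, layer_norm2, and final_layer_norm move to float32 with casts around each call, while the attention and MLP stay in config.dtype. For context, a standalone PyTorch illustration of the kind of overflow this avoids (not part of the commit; the value 300.0 is arbitrary but well inside float16 range):

```python
import torch

# Squared activations of this size exceed float16's largest finite value (65504),
# so the sum-of-squares feeding half-precision normalization statistics blows up.
x = torch.tensor([300.0], dtype=torch.float16)
print(x * x)           # tensor([inf], dtype=torch.float16): 300^2 = 90000 > 65504
print(x.float() ** 2)  # tensor([90000.]) in float32
```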
58 changes: 32 additions & 26 deletions tripy/examples/diffusion/example.py
@@ -52,7 +52,7 @@ def compile_clip(model, dtype=tp.int32, verbose=False):
return compile_model(model, inputs, verbose=verbose)


def compile_unet(model, dtype=tp.float16, verbose=False):
def compile_unet(model, dtype, verbose=False):
unconditional_context_shape = (1, 77, 768)
conditional_context_shape = (1, 77, 768)
latent_shape = (1, 4, 64, 64)
@@ -68,16 +68,16 @@ def compile_unet(model, dtype=tp.float16, verbose=False):
return compile_model(model, inputs, verbose=verbose)


def compile_vae(model, dtype=tp.float16, verbose=False):
def compile_vae(model, dtype, verbose=False):
inputs = (tp.InputInfo((1, 4, 64, 64), dtype=dtype),)
return compile_model(model, inputs, verbose=verbose)


def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance):
def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance, dtype):
timesteps = list(range(1, 1000, 1000 // steps))
print(f"[I] Running diffusion for {timesteps} timesteps...")
alphas = get_alphas_cumprod()[tp.Tensor(timesteps)]
alphas_prev = tp.concatenate([tp.Tensor([1.0]), alphas[:-1]], dim=0)
print(f"[I] Running diffusion for {steps} timesteps...")
alphas = get_alphas_cumprod(dtype=dtype)[tp.Tensor(timesteps)]
alphas_prev = tp.concatenate([tp.Tensor([1.0], dtype=dtype), alphas[:-1]], dim=0)

for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
t.set_description("idx: %1d, timestep: %3d" % (index, timestep))
@@ -86,32 +86,34 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
unconditional_context,
context,
latent,
tp.cast(tp.Tensor([timestep]), tp.float32),
tp.Tensor([timestep], dtype=dtype),
alphas[tid],
alphas_prev[tid],
tp.Tensor([guidance]),
tp.Tensor([guidance], dtype=dtype),
)
return latent


def tripy_diffusion(args):
run_start_time = time.perf_counter()

if os.path.isdir("engines"):
dtype, torch_dtype = (tp.float16, torch.float16) if args.fp16 else (tp.float32, torch.float32)

if os.path.isdir(args.engine_dir):
print("[I] Loading cached engines from disk...")
clip_compiled = tp.Executable.load(os.path.join("engines", "clip_executable.json"))
unet_compiled = tp.Executable.load(os.path.join("engines", "unet_executable.json"))
vae_compiled = tp.Executable.load(os.path.join("engines", "vae_executable.json"))
else:
model = StableDiffusion(StableDiffusionConfig(dtype=tp.float16))
model = StableDiffusion(StableDiffusionConfig(dtype=dtype))
print("[I] Loading model weights...", flush=True)
load_from_diffusers(model, tp.float16, debug=True)
load_from_diffusers(model, dtype, args.hf_token, debug=True)
clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
unet_compiled = compile_unet(model, verbose=True)
vae_compiled = compile_vae(model.decode, verbose=True)
unet_compiled = compile_unet(model, dtype, verbose=True)
vae_compiled = compile_vae(model.decode, dtype, verbose=True)

os.mkdir("engines")
print("[I] Saving engines to disk...")
os.mkdir(args.engine_dir)
print(f"[I] Saving engines to {args.engine_dir}...")
clip_compiled.save(os.path.join("engines", "clip_executable.json"))
unet_compiled.save(os.path.join("engines", "unet_executable.json"))
vae_compiled.save(os.path.join("engines", "vae_executable.json"))
@@ -135,11 +137,11 @@ def tripy_diffusion(args):
# Backbone of diffusion - the UNet
if args.seed is not None:
torch.manual_seed(args.seed)
torch_latent = torch.randn((1, 4, 64, 64)).to("cuda")
torch_latent = torch.randn((1, 4, 64, 64), dtype=torch_dtype).to("cuda")
latent = tp.Tensor(torch_latent)

diffusion_run_start = time.perf_counter()
latent = run_diffusion_loop(unet_compiled, unconditional_context, context, latent, args.steps, args.guidance)
latent = run_diffusion_loop(unet_compiled, unconditional_context, context, latent, args.steps, args.guidance, dtype)
diffusion_run_end = time.perf_counter()
print(f"[I] Finished diffusion denoising. Inference took {diffusion_run_end - diffusion_run_start} seconds.")

@@ -173,15 +175,17 @@ def hf_diffusion(args):

run_start_time = time.perf_counter()

dtype = torch.float16 if args.fp16 else torch.float32
model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if args.fp16 else {}

# Initialize models
model_id = "CompVis/stable-diffusion-v1-4" #"benjamin-paine/stable-diffusion-v1-5" #"runwayml/stable-diffusion-v1-5"
clip_id = "openai/clip-vit-large-patch14"
model_id = "KiwiXR/stable-diffusion-v1-5"

print("[I] Loading models...")
hf_tokenizer = CLIPTokenizer.from_pretrained(clip_id)
hf_encoder = CLIPTextModel.from_pretrained(clip_id).to("cuda")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to("cuda")
hf_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
hf_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to("cuda")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", use_auth_token=args.hf_token, **model_opts).to("cuda")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_auth_token=args.hf_token, **model_opts).to("cuda")
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# Run through CLIP to get context from prompt
@@ -192,19 +196,20 @@ def hf_diffusion(args):
uncond_input = hf_tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt").to("cuda")
text_embeddings = hf_encoder(text_input.input_ids, output_hidden_states=True)[0]
uncond_embeddings = hf_encoder(uncond_input.input_ids)[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
clip_run_end = time.perf_counter()
print(f"took {clip_run_end - clip_run_start} seconds.")

# Backbone of diffusion - the UNet
if args.seed is not None:
torch.manual_seed(args.seed)
torch_latent = torch.randn((1, 4, 64, 64)).to("cuda")
torch_latent = torch.randn((1, 4, 64, 64), dtype=dtype).to("cuda")
torch_latent *= scheduler.init_noise_sigma

scheduler.set_timesteps(args.steps)

diffusion_run_start = time.perf_counter()
print(f"[I] Running diffusion for {args.steps} timesteps...")
for t in tqdm(scheduler.timesteps):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([torch_latent] * 2)
@@ -267,7 +272,6 @@ def print_summary(denoising_steps, times):


# TODO: Add torch compilation modes
# TODO: Add fp16 support
# TODO: Add Timing context
def main():
default_prompt = "a horse sized cat eating a bagel"
@@ -282,6 +286,8 @@ def main():
parser.add_argument("--seed", type=int, help="Set the random latent seed")
parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength")
parser.add_argument('--torch-inference', action='store_true', help="Run inference with PyTorch (eager mode) instead of TensorRT.")
parser.add_argument('--hf-token', type=str, default='', help="HuggingFace API access token for downloading model checkpoints")
parser.add_argument('--engine-dir', type=str, default='engines', help="Output directory for TensorRT engines")
args = parser.parse_args()

if args.torch_inference:
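
The caching branch in tripy_diffusion() amounts to a load-or-compile step: reuse serialized executables if they exist on disk, otherwise compile them with the selected dtype and save them. A compact sketch of that logic (the load_or_compile helper is illustrative and not part of the commit; tp.Executable.load/save and the *_executable.json filenames come from the diff above):

```python
import os
import tripy as tp

def load_or_compile(engine_dir, name, compile_fn):
    # Illustrative helper: reuse a cached engine if one exists,
    # otherwise compile and save it, mirroring the branch above.
    path = os.path.join(engine_dir, f"{name}_executable.json")
    if os.path.exists(path):
        return tp.Executable.load(path)
    executable = compile_fn()
    os.makedirs(engine_dir, exist_ok=True)
    executable.save(path)
    return executable

# e.g. clip_compiled = load_or_compile(args.engine_dir, "clip",
#     lambda: compile_clip(model.cond_stage_model.transformer.text_model, verbose=True))
```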
2 changes: 1 addition & 1 deletion tripy/examples/diffusion/helper.py
@@ -12,7 +12,7 @@ def scaled_dot_product_attention(
embedding_dim: Optional[int] = None,
attn_mask: Optional[tp.Tensor] = None,
is_causal: bool = False,
dtype: tp.dtype = tp.float16
dtype: tp.dtype = tp.float32
) -> tp.Tensor:
"""
Computes scaled dot-product attention.
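
The helper's default dtype flips to float32 here, while the callers in clip_model.py and unet_model.py now pass the model dtype explicitly. Per its docstring, the helper computes standard scaled dot-product attention; for reference, with embedding_dim in the role of d_k and attn_mask as the additive mask M (matching the -inf upper-triangular mask built in clip_model.py), that is:

```latex
% Standard scaled dot-product attention, as described by the helper's docstring.
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V
```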
6 changes: 3 additions & 3 deletions tripy/examples/diffusion/model.py
@@ -33,7 +33,7 @@

@dataclass
class StableDiffusionConfig:
dtype: tp.dtype = tp.float16
dtype: tp.dtype = tp.float32
clip_config: Optional[CLIPConfig] = field(default=None, init=False)
unet_config: Optional[UNetConfig] = field(default=None, init=False)
vae_config: Optional[VAEConfig] = field(default=None, init=False)
@@ -44,11 +44,11 @@ def __post_init__(self):
self.vae_config = VAEConfig(dtype=self.dtype)

# equivalent to LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000, dtype=tp.float32):
betas = np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
return tp.Tensor(alphas_cumprod)
return tp.cast(tp.Tensor(alphas_cumprod), dtype)


class StableDiffusion(tp.Module):
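
get_alphas_cumprod still builds the beta schedule and its cumulative product in float32 NumPy; only the finished 1000-entry table is cast to the requested dtype so it matches the rest of an fp16 pipeline. A quick standalone check of the values involved (not part of the commit; it only reuses the constants visible above):

```python
import numpy as np

# Same schedule as get_alphas_cumprod, computed in float32 before any cast.
beta_start, beta_end, n = 0.00085, 0.0120, 1000
betas = np.linspace(beta_start**0.5, beta_end**0.5, n, dtype=np.float32) ** 2
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
print(alphas_cumprod[0], alphas_cumprod[-1])  # roughly 0.99915 ... 0.0047
```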
38 changes: 21 additions & 17 deletions tripy/examples/diffusion/unet_model.py
@@ -18,6 +18,7 @@
import math
from typing import List, Tuple

import torch
import tripy as tp
from dataclasses import dataclass

@@ -33,28 +34,30 @@ class UNetConfig:
num_heads: int = 8
context_dim: int = 768
emb_channels: int = 1280
dtype: tp.dtype = tp.float16
dtype: tp.dtype = tp.float32


# Used for UNet, not to be confused with ResnetBlock, called ResnetBlock2D in HF diffusers
class ResBlock(tp.Module):
def __init__(self, config: UNetConfig, channels, emb_channels, out_channels):
self.norm1 = tp.GroupNorm(32, channels, dtype=config.dtype)
self.norm1 = tp.GroupNorm(32, channels, dtype=tp.float32)
self.conv1 = tp.Conv(channels, out_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype)
self.time_emb_proj = tp.Linear(emb_channels, out_channels, dtype=config.dtype)
self.norm2 = tp.GroupNorm(32, out_channels, dtype=config.dtype)
self.norm2 = tp.GroupNorm(32, out_channels, dtype=tp.float32)
self.conv2 = tp.Conv(out_channels, out_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype)
self.nonlinearity = tp.silu
self.conv_shortcut = tp.Conv(channels, out_channels, (1, 1), dtype=config.dtype) if channels != out_channels else lambda x: x

def __call__(self, x, emb):
h = self.conv1(self.nonlinearity(self.norm1(x)))
h = tp.cast(self.norm1(tp.cast(x, self.norm1.dtype)), x.dtype)
h = self.conv1(self.nonlinearity(h))
emb_out = self.time_emb_proj(self.nonlinearity(emb))
target_shape = emb_out.shape + (1, 1)
# TODO: #228: WAR to prevent computing output rank in infer_rank for reshape
target_shape.trace_tensor.shape = (emb_out.rank + 2,)
h = h + tp.reshape(emb_out, target_shape)
h = self.conv2(self.nonlinearity(self.norm2(h)))
h = tp.cast(self.norm2(tp.cast(h, self.norm2.dtype)), h.dtype)
h = self.conv2(self.nonlinearity(h))
ret = self.conv_shortcut(x) + h
return ret

@@ -67,14 +70,15 @@ def __init__(self, config: UNetConfig, query_dim, context_dim, n_heads, d_head):
self.num_heads = n_heads
self.head_size = d_head
self.to_out = [tp.Linear(n_heads * d_head, query_dim, dtype=config.dtype)]
self.dtype = config.dtype

def __call__(self, x, context=None):
context = x if context is None else context
q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
q, k, v = [
tp.transpose(tp.reshape(y, (x.shape[0], -1, self.num_heads, self.head_size)), 1, 2) for y in (q, k, v)
]
attention = tp.transpose(scaled_dot_product_attention(q, k, v, embedding_dim=self.head_size), 1, 2)
attention = tp.transpose(scaled_dot_product_attention(q, k, v, embedding_dim=self.head_size, dtype=self.dtype), 1, 2)
h_ = tp.reshape(attention, (x.shape[0], -1, self.num_heads * self.head_size))
out = sequential(h_, self.to_out)
return out
@@ -116,20 +120,20 @@ def __init__(self, config, dim, context_dim, n_heads, d_head):
self.attn1 = CrossAttention(config, dim, dim, n_heads, d_head)
self.ff = FeedForward(config, dim)
self.attn2 = CrossAttention(config, dim, context_dim, n_heads, d_head)
self.norm1 = tp.LayerNorm(dim, dtype=config.dtype)
self.norm2 = tp.LayerNorm(dim, dtype=config.dtype)
self.norm3 = tp.LayerNorm(dim, dtype=config.dtype)
self.norm1 = tp.LayerNorm(dim, dtype=tp.float32)
self.norm2 = tp.LayerNorm(dim, dtype=tp.float32)
self.norm3 = tp.LayerNorm(dim, dtype=tp.float32)

def __call__(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
x = self.attn1(tp.cast(self.norm1(tp.cast(x, self.norm1.dtype)), x.dtype)) + x
x = self.attn2(tp.cast(self.norm2(tp.cast(x, self.norm2.dtype)), x.dtype), context=context) + x
x = self.ff(tp.cast(self.norm3(tp.cast(x, self.norm3.dtype)), x.dtype)) + x
return x


class SpatialTransformer(tp.Module): # Transformer2dModel in HF diffusers
def __init__(self, config: UNetConfig, channels, context_dim, n_heads, d_head):
self.norm = tp.GroupNorm(32, channels, dtype=config.dtype)
self.norm = tp.GroupNorm(32, channels, dtype=tp.float32)
assert channels == n_heads * d_head
self.proj_in = tp.Conv(channels, n_heads * d_head, (1, 1), dtype=config.dtype)
self.transformer_blocks = [BasicTransformerBlock(config, channels, context_dim, n_heads, d_head)]
Expand All @@ -138,7 +142,7 @@ def __init__(self, config: UNetConfig, channels, context_dim, n_heads, d_head):
def __call__(self, x, context=None):
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = tp.cast(self.norm(tp.cast(x, self.norm.dtype)), x.dtype)
x = self.proj_in(x)
x = tp.permute(tp.reshape(x, (b, c, h * w)), (0, 2, 1))
for block in self.transformer_blocks:
@@ -272,15 +276,14 @@ def __init__(self, config: UNetConfig):
CrossAttnUpBlock2D(config, up_channels[2:5], down_channels[2]),
CrossAttnUpBlock2D(config, up_channels[4:7], down_channels[1], use_upsampler=False),
]
self.conv_norm_out = tp.GroupNorm(32, config.model_channels, dtype=config.dtype)
self.conv_norm_out = tp.GroupNorm(32, config.model_channels, dtype=tp.float32)
self.conv_act = tp.silu
self.conv_out = tp.Conv(config.model_channels, config.io_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype)

def __call__(self, x, timesteps=None, context=None):
# TODO: real time embedding
t_emb = timestep_embedding(timesteps, self.config.model_channels, self.config.dtype)
emb = self.time_embedding(t_emb)

x = self.conv_in(x)
saved_inputs = [x]

Expand All @@ -301,6 +304,7 @@ def __call__(self, x, timesteps=None, context=None):
else:
x = block(x, emb, context, partial_inputs)

act = self.conv_out(self.conv_act(self.conv_norm_out(x)))
act = tp.cast(self.conv_norm_out(tp.cast(x, self.conv_norm_out.dtype)), x.dtype)
act = self.conv_out(self.conv_act(act))
return act
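
The UNet changes repeat the same recipe: every tp.GroupNorm and tp.LayerNorm is pinned to float32 with casts at its boundaries, while the convolutions, linears, and attention keep running in config.dtype (CrossAttention now also forwards that dtype to scaled_dot_product_attention). Condensed into one illustrative block (this class is a simplification for reading the diff, not code from the commit):

```python
import tripy as tp

class ResBlockNormSketch(tp.Module):
    # Illustrative only: GroupNorm statistics in float32, convolution in the model dtype.
    def __init__(self, channels, out_channels, dtype=tp.float16):
        self.norm = tp.GroupNorm(32, channels, dtype=tp.float32)
        self.conv = tp.Conv(channels, out_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=dtype)

    def __call__(self, x):
        h = tp.cast(self.norm(tp.cast(x, self.norm.dtype)), x.dtype)  # normalize in fp32
        return self.conv(tp.silu(h))                                  # nonlinearity and conv in the model dtype
```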

