Skip to content

Commit

Permalink
Add fp16 support
Browse files: browse the repository at this point in the history
Signed-off-by: Akhil Goel <[email protected]>
  • Loading branch information
akhilg-nv committed Sep 12, 2024
1 parent 56884c9 commit 9602126
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 159 deletions.
23 changes: 12 additions & 11 deletions tripy/examples/diffusion/clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ class CLIPConfig:
num_heads: int = 12
max_seq_len: int = 77
num_hidden_layers: int = 12
dtype: tp.dtype = tp.float16

class CLIPMLP(tp.Module):
def __init__(self, config: CLIPConfig):
self.fc1 = tp.Linear(config.embedding_size, config.embedding_size * 4)
self.fc2 = tp.Linear(config.embedding_size * 4, config.embedding_size)
self.fc1 = tp.Linear(config.embedding_size, config.embedding_size * 4, dtype=config.dtype)
self.fc2 = tp.Linear(config.embedding_size * 4, config.embedding_size, dtype=config.dtype)

def __call__(self, hidden_states):
hidden_states = self.fc1(hidden_states)
Expand All @@ -47,10 +48,10 @@ def __init__(self, config: CLIPConfig):
self.embed_dim = config.embedding_size
self.num_heads = config.num_heads
self.head_dim = self.embed_dim // self.num_heads
self.k_proj = tp.Linear(self.embed_dim, self.embed_dim)
self.v_proj = tp.Linear(self.embed_dim, self.embed_dim)
self.q_proj = tp.Linear(self.embed_dim, self.embed_dim)
self.out_proj = tp.Linear(self.embed_dim, self.embed_dim)
self.k_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
self.v_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
self.q_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
self.out_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)

def __call__(self, hidden_states, causal_attention_mask):
bsz, tgt_len, embed_dim = hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[2]
Expand All @@ -73,9 +74,9 @@ def __call__(self, hidden_states, causal_attention_mask):
class CLIPEncoderLayer(tp.Module):
def __init__(self, config: CLIPConfig):
self.self_attn = CLIPAttention(config)
self.layer_norm1 = tp.LayerNorm(config.embedding_size)
self.layer_norm1 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
self.mlp = CLIPMLP(config)
self.layer_norm2 = tp.LayerNorm(config.embedding_size)
self.layer_norm2 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)

def __call__(self, hidden_states, causal_attention_mask):
residual = hidden_states
Expand Down Expand Up @@ -103,8 +104,8 @@ def __call__(self, hidden_states, causal_attention_mask):

class CLIPTextEmbeddings(tp.Module):
def __init__(self, config: CLIPConfig):
self.token_embedding = tp.Embedding(config.vocab_size, config.embedding_size)
self.position_embedding = tp.Embedding(config.max_seq_len, config.embedding_size)
self.token_embedding = tp.Embedding(config.vocab_size, config.embedding_size, dtype=config.dtype)
self.position_embedding = tp.Embedding(config.max_seq_len, config.embedding_size, dtype=config.dtype)

def __call__(self, input_ids, position_ids):
return self.token_embedding(input_ids) + self.position_embedding(position_ids)
Expand All @@ -114,7 +115,7 @@ class CLIPTextTransformer(tp.Module):
def __init__(self, config: CLIPConfig):
self.embeddings = CLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
self.final_layer_norm = tp.LayerNorm(config.embedding_size)
self.final_layer_norm = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
self.max_seq_len = config.max_seq_len

def __call__(self, input_ids):
Expand Down
30 changes: 15 additions & 15 deletions tripy/examples/diffusion/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,29 +47,29 @@ def compile_model(model, inputs, verbose=False):
return compiled_model


def compile_clip(model, verbose=False):
inputs = (tp.InputInfo((1, 77), dtype=tp.int32),)
def compile_clip(model, dtype=tp.int32, verbose=False):
inputs = (tp.InputInfo((1, 77), dtype=dtype),)
return compile_model(model, inputs, verbose=verbose)


def compile_unet(model, verbose=False):
def compile_unet(model, dtype=tp.float16, verbose=False):
unconditional_context_shape = (1, 77, 768)
conditional_context_shape = (1, 77, 768)
latent_shape = (1, 4, 64, 64)
inputs = (
tp.InputInfo(unconditional_context_shape, dtype=tp.float32),
tp.InputInfo(conditional_context_shape, dtype=tp.float32),
tp.InputInfo(latent_shape, dtype=tp.float32),
tp.InputInfo((1,), dtype=tp.float32),
tp.InputInfo((1,), dtype=tp.float32),
tp.InputInfo((1,), dtype=tp.float32),
tp.InputInfo((1,), dtype=tp.float32),
tp.InputInfo(unconditional_context_shape, dtype=dtype),
tp.InputInfo(conditional_context_shape, dtype=dtype),
tp.InputInfo(latent_shape, dtype=dtype),
tp.InputInfo((1,), dtype=dtype),
tp.InputInfo((1,), dtype=dtype),
tp.InputInfo((1,), dtype=dtype),
tp.InputInfo((1,), dtype=dtype),
)
return compile_model(model, inputs, verbose=verbose)


def compile_vae(model, verbose=False):
inputs = (tp.InputInfo((1, 4, 64, 64), dtype=tp.float32),)
def compile_vae(model, dtype=tp.float16, verbose=False):
inputs = (tp.InputInfo((1, 4, 64, 64), dtype=dtype),)
return compile_model(model, inputs, verbose=verbose)


Expand Down Expand Up @@ -103,9 +103,9 @@ def tripy_diffusion(args):
unet_compiled = tp.Executable.load(os.path.join("engines", "unet_executable.json"))
vae_compiled = tp.Executable.load(os.path.join("engines", "vae_executable.json"))
else:
model = StableDiffusion(StableDiffusionConfig)
model = StableDiffusion(StableDiffusionConfig(dtype=tp.float16))
print("[I] Loading model weights...", flush=True)
load_from_diffusers(model, tp.float32, debug=True)
load_from_diffusers(model, tp.float16, debug=True)
clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
unet_compiled = compile_unet(model, verbose=True)
vae_compiled = compile_vae(model.decode, verbose=True)
Expand Down Expand Up @@ -174,7 +174,7 @@ def hf_diffusion(args):
run_start_time = time.perf_counter()

# Initialize models
model_id = "runwayml/stable-diffusion-v1-5"
model_id = "CompVis/stable-diffusion-v1-4" #"benjamin-paine/stable-diffusion-v1-5" #"runwayml/stable-diffusion-v1-5"
clip_id = "openai/clip-vit-large-patch14"

print("[I] Loading models...")
Expand Down
6 changes: 4 additions & 2 deletions tripy/examples/diffusion/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def scaled_dot_product_attention(
embedding_dim: Optional[int] = None,
attn_mask: Optional[tp.Tensor] = None,
is_causal: bool = False,
dtype: tp.dtype = tp.float16
) -> tp.Tensor:
"""
Computes scaled dot-product attention.
Expand All @@ -20,14 +21,15 @@ def scaled_dot_product_attention(
- Described: https://paperswithcode.com/method/scaled
- Paper: https://arxiv.org/abs/1706.03762v7
"""

if is_causal: # this path is not called in demoDiffusion
target_shape = query.shape[-2:-1] + key.shape[-2:-1]
# TODO: #228: Workaround (WAR) to prevent computing output rank in infer_rank for reshape
target_shape.trace_tensor.shape = (2,)
attn_mask = tp.cast(tp.tril(tp.ones(target_shape)), tp.bool)
if attn_mask is not None and attn_mask.dtype == tp.bool:
attn_mask = tp.where((attn_mask == 0), tp.ones_like(attn_mask) * -float("inf"), tp.zeros_like(attn_mask))
attn_mask = tp.where((attn_mask == 0), tp.ones_like(attn_mask, dtype=dtype) * -float("inf"), tp.zeros_like(attn_mask, dtype=dtype))
if attn_mask is not None:
attn_mask = tp.cast(attn_mask, dtype)
qk = query @ tp.transpose(key, -2, -1) / math.sqrt(embedding_dim)
return tp.cast(tp.softmax((qk + attn_mask) if attn_mask is not None else qk, -1), query.dtype) @ value

Expand Down
17 changes: 12 additions & 5 deletions tripy/examples/diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,25 @@

import numpy as np
import tripy as tp
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field

from examples.diffusion.clip_model import CLIPTextTransformer, CLIPConfig
from examples.diffusion.unet_model import UNetModel, UNet15Config
from examples.diffusion.unet_model import UNetModel, UNetConfig
from examples.diffusion.vae_model import AutoencoderKL, VAEConfig
from examples.diffusion.helper import clamp

@dataclass
class StableDiffusionConfig:
clip_config: CLIPConfig = CLIPConfig()
unet_config: UNet15Config = UNet15Config()
vae_config: VAEConfig = VAEConfig()
dtype: tp.dtype = tp.float16
clip_config: Optional[CLIPConfig] = field(default=None, init=False)
unet_config: Optional[UNetConfig] = field(default=None, init=False)
vae_config: Optional[VAEConfig] = field(default=None, init=False)

def __post_init__(self):
self.clip_config = CLIPConfig(dtype=self.dtype)
self.unet_config = UNetConfig(dtype=self.dtype)
self.vae_config = VAEConfig(dtype=self.dtype)

# equivalent to LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
Expand Down
Loading

0 comments on commit 9602126

Please sign in to comment.