Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Stable Diffusion demo #100

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft

Conversation

akhilg-nv
Copy link
Collaborator

Cleanup for Tripy inference, torch inference pipeline, and accuracy test are still WIP. README is currently minimal as well.

tripy/examples/diffusion/model.py Outdated Show resolved Hide resolved
tripy/examples/diffusion/model.py Outdated Show resolved Hide resolved
tripy/examples/diffusion/model.py Outdated Show resolved Hide resolved
tripy/examples/diffusion/model.py Outdated Show resolved Hide resolved
tripy/examples/diffusion/model.py Outdated Show resolved Hide resolved
@akhilg-nv akhilg-nv marked this pull request as draft August 14, 2024 17:16
- Paper: https://arxiv.org/abs/1706.03762v7
"""

if is_causal: # this path is not called in demoDiffusion
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we drop this if it's not called?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have left it for now in case we want to add it as part of the API somewhere

return tp.cast(tp.softmax((qk + attn_mask) if attn_mask is not None else qk, -1), query.dtype) @ value


def sequential(input: tp.Tensor, ll: List[Callable[[tp.Tensor], tp.Tensor]]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth making this part of the API, similar to torch.nn.Sequential.

@pranavm-nvidia pranavm-nvidia added the tripy Pull request for the tripy project label Aug 14, 2024
@akhilg-nv akhilg-nv force-pushed the dev-akhilg-demo-diffusion branch 3 times, most recently from 0d3c2cd to c7c81bd Compare August 29, 2024 22:02
Signed-off-by: Akhil Goel <[email protected]>
Signed-off-by: Akhil Goel <[email protected]>
Root cause: Index for denoising timesteps were reversed while
refactoring.

Signed-off-by: Akhil Goel <[email protected]>
Signed-off-by: Akhil Goel <[email protected]>
Signed-off-by: Akhil Goel <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tripy Pull request for the tripy project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants