jax jitted functions cloudpickled work but include some error messages #537

Open
bionicles opened this issue Jun 28, 2024 · 4 comments

@bionicles

problems

cloudpickle works for jax.jit functions, but a visual inspection of the pickled bytes shows an error message lurking inside the buffer

challenges

not sure if this belongs in cloudpickle or jax

is this my bad? I was hoping we could just store the jaxpr as a UTF-8 string, since it's more human-readable, but I don't know how to regenerate a PjitFunction from a jaxpr

opportunities

a fix could reduce the size of cloudpickled jax.jit functions

import cloudpickle
import jax
from rich import print as rprint  # assumed: rprint is rich's print


def test_jax_cloudpickle():
    def jnp_func(x):
        return jax.numpy.sin(jax.numpy.cos(x))

    jitted1 = jax.jit(jnp_func)
    del jnp_func  # ensure jitted2 can't cheat by re-tracing jnp_func within the same session
    assert "jnp_func" not in locals(), "failed to remove jnp_func"
    jitted1_buf = cloudpickle.dumps(jitted1)
    rprint(jitted1_buf)
    jitted2 = cloudpickle.loads(jitted1_buf)
    assert jitted1(0.3) == jitted2(0.3), "weird"
    # this is the check that fails: the pickled bytes contain traceback text
    assert b"TRACEBACK" not in jitted1_buf, "error message in cloudpickle of jax.jit"


test_jax_cloudpickle()

Could JAX_TRACEBACK_FILTERING= be greppable?
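For the grep question, one option that needs only the standard library is to walk the pickle opcodes with `pickletools.genops` and collect the string arguments, which turns the visual inspection into a search. This is a minimal sketch: the helper names `pickled_strings` and `find_in_pickle` and the sample payload are hypothetical, and it only locates such text; it does not explain why the message ends up in the buffer. In the test above, the assumed call would be `find_in_pickle(jitted1_buf, "traceback")`.

```python
import pickle
import pickletools


def pickled_strings(buf: bytes) -> list[str]:
    """Return every str/bytes argument stored in a pickle, in opcode order."""
    found = []
    for _opcode, arg, _pos in pickletools.genops(buf):
        # Opcodes such as SHORT_BINUNICODE or BINBYTES carry their payload as `arg`;
        # others carry ints, floats, or None, which are skipped here.
        if isinstance(arg, str):
            found.append(arg)
        elif isinstance(arg, bytes):
            found.append(arg.decode("utf-8", "replace"))
    return found


def find_in_pickle(buf: bytes, needle: str) -> list[str]:
    """Case-insensitive search for `needle` among the pickle's string arguments."""
    return [s for s in pickled_strings(buf) if needle.lower() in s.lower()]


# Hypothetical payload standing in for a real cloudpickle buffer.
buf = pickle.dumps({"name": "jnp_func", "note": "hypothetical TRACEBACK message"})
print(find_in_pickle(buf, "traceback"))  # ['hypothetical TRACEBACK message']
```

Because cloudpickle emits a standard pickle stream, the same helpers should accept its output unchanged.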


thank you for making cloudpickle

@bionicles bionicles changed the title jax jitted functions cloudpickled work but include tracebacks in the buffer jax jitted functions cloudpickled work but include some error messages Jun 28, 2024
@bionicles
Author

bionicles commented Jun 28, 2024

update: after the round trip the methods change and a hash equality check fails, even though the function behaves the same before and after.

anybody know a good way to get this equality check passing?

        elif isinstance(value1, PjitFunction):
            if not isinstance(value2, PjitFunction):
                decision = False
            else:
                v1_hash = value1.__hash__
                v2_hash = value2.__hash__
                rprint(f"{type(value1)=}", dir(value1))
                rprint(f"{type(value2)=}", dir(value2))
                if v1_hash != v2_hash:
                    decision = False


@ogrisel
Contributor

ogrisel commented Aug 5, 2024

I don't understand the intent of the snippet above: __hash__ is a method, not a precomputed attribute, so I would have expected v1_hash = value1.__hash__() or even v1_hash = hash(value1).
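To illustrate the point: without the parentheses, `value1.__hash__` is a bound method, and on Python 3.8+ bound methods of two distinct objects never compare equal, so a `v1_hash != v2_hash` check like the one in the snippet reports a mismatch for any two different objects, whatever their hashes are. A small sketch with a hypothetical `Point` class whose hash depends only on its coordinates:

```python
class Point:
    """Hypothetical value type whose hash depends only on its coordinates."""

    def __init__(self, x: float, y: float) -> None:
        self.x, self.y = x, y

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Point) and (self.x, self.y) == (other.x, other.y)

    def __hash__(self) -> int:
        return hash((self.x, self.y))


p, q = Point(1.0, 2.0), Point(1.0, 2.0)

# Without parentheses these are bound methods; on Python 3.8+ methods bound to
# different objects compare unequal, so this comparison is always False here.
print(p.__hash__ == q.__hash__)  # False
# Calling the method, or better the builtin hash(), compares the actual values.
print(p.__hash__() == q.__hash__())  # True
print(hash(p) == hash(q))  # True
```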

@ogrisel
Contributor

ogrisel commented Aug 5, 2024

Could you please provide a full reproducing snippet for the code used in #537 (comment)?

@bionicles
Author

bionicles commented Aug 5, 2024

ah, that code's out of date; you're right, it looks like I screwed up there. good eye

I'm just trying to do a round-trip serialize/deserialize and want to be able to test that the function is the same afterward, but the hashes change here:

import cloudpickle
import jax.numpy as jnp
import jax


def test_cloudpickle_jax_jit():
    def f(x):
        return jnp.sin(jnp.cos(x))

    def g(x):
        return jnp.sin(jnp.cos(x)) + 1.0

    print(f"{f=} {type(f)=}")
    print(f"{g=} {type(g)=}")

    fjit = jax.jit(f)
    gjit = jax.jit(g)

    print(f"\n{fjit=}\n{type(fjit)=}")
    print(f"\n{gjit=}\n{type(gjit)=}")

    fjit_hash = hash(fjit)
    gjit_hash = hash(gjit)
    print(f"\n{fjit_hash=}")
    print(f"{gjit_hash=}")

    fjit_dump = cloudpickle.dumps(fjit)
    gjit_dump = cloudpickle.dumps(gjit)

    fjit_load = cloudpickle.loads(fjit_dump)
    gjit_load = cloudpickle.loads(gjit_dump)
    print(f"\n{fjit_load=}\n{type(fjit_load)=}")
    print(f"\n{gjit_load=}\n{type(gjit_load)=}")

    fjit_load_hash = hash(fjit_load)
    gjit_load_hash = hash(gjit_load)
    print(f"\n{fjit_load_hash=}")
    print(f"{gjit_load_hash=}")

    assert fjit_load_hash == fjit_hash
    assert gjit_load_hash == gjit_hash

it looks like the round trip loses this reference to the local function, which is understandable, but it also makes testing hard because we can't verify that the function is the same afterward:

of <function test_cloudpickle_jax_jit.<locals>.f at 0x7f0676f98fe0>
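Objects that rely on Python's default identity-based `__hash__` (plain functions and ordinary instances, for example) get a new hash after any round trip, because `loads` returns a new object. The changing hashes above are consistent with the jitted wrappers behaving the same way, though that is an assumption about jax's implementation. A more robust check is to compare behavior on a few sample inputs. A plain-Python sketch, where `Scaler` is a hypothetical stand-in for `fjit` and `same_behavior` is a hypothetical helper:

```python
import math
import pickle
from typing import Callable, Iterable


class Scaler:
    """Hypothetical callable that keeps Python's default, identity-based hash."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float) -> float:
        return math.sin(math.cos(x)) * self.factor


def same_behavior(f1: Callable, f2: Callable, inputs: Iterable[float],
                  rel_tol: float = 1e-6) -> bool:
    """Compare two callables by their outputs on sample inputs instead of by hash."""
    return all(math.isclose(float(f1(x)), float(f2(x)), rel_tol=rel_tol) for x in inputs)


original = Scaler(2.0)
restored = pickle.loads(pickle.dumps(original))

# Both objects are alive and distinct, so their identity-based hashes differ...
print(hash(original) == hash(restored))  # False
# ...but they compute the same values.
print(same_behavior(original, restored, [0.0, 0.3, 1.0]))  # True
```

With jax, the assumed call would be `same_behavior(fjit, fjit_load, [0.0, 0.3, 1.0])`, since `float()` accepts the 0-d arrays a jitted scalar function returns. `math.isclose` is used rather than `==` so tiny floating-point differences don't produce false mismatches.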