From 8fcb24944da9bde75c81aa9b9cea9e6ad465c8a9 Mon Sep 17 00:00:00 2001 From: enzymezoo-code Date: Fri, 22 Apr 2022 22:52:09 -0500 Subject: [PATCH 1/2] Added ViT-L/14@336px --- clip_jax/clip.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/clip_jax/clip.py b/clip_jax/clip.py index 62a02d066..6d6398845 100644 --- a/clip_jax/clip.py +++ b/clip_jax/clip.py @@ -26,6 +26,7 @@ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", } @@ -177,7 +178,10 @@ def text_jax(text): rng_key = jax.random.PRNGKey(42) transformed = hk.transform(clip_jax) - jax_params = transformed.init(rng=rng_key, image=jnp.zeros((1, 3, 224, 224)), text=jnp.zeros((1, 77), dtype=jnp.int16)) + if name == "ViT-L/14@336px": + jax_params = transformed.init(rng=rng_key, image=jnp.zeros((1, 3, 336, 336)), text=jnp.zeros((1, 77), dtype=jnp.int16)) + else: + jax_params = transformed.init(rng=rng_key, image=jnp.zeros((1, 3, 224, 224)), text=jnp.zeros((1, 77), dtype=jnp.int16)) jax_params = convert_params(state_dict, jax_params) image_fn = hk.without_apply_rng(hk.transform(vit_jax)).apply From b0c6691a0a268933d0d88c06fcc39a8807e7ddd2 Mon Sep 17 00:00:00 2001 From: enzymezoo-code Date: Sat, 23 Apr 2022 10:13:12 -0500 Subject: [PATCH 2/2] Added tests for ViT-L/14@336px --- tests/test_consistency.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_consistency.py b/tests/test_consistency.py index ad9957422..95fe1f98c 100644 --- a/tests/test_consistency.py +++ b/tests/test_consistency.py @@ -29,3 +29,4 @@ def test_model(model_name): test_model("ViT-B/32") test_model("ViT-B/16") test_model("ViT-L/14") +test_model("ViT-L/14@336px")