ViT (Vision Transformer)

The ViT (Vision Transformer) is a transformer-based neural network architecture for image classification. It divides an image into fixed-size patches, linearly embeds each patch, adds position embeddings, and processes the resulting sequence of vectors through a standard transformer encoder.
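
As a concrete illustration of that tokenization, here is a minimal sketch of the sequence-length arithmetic for the default 224×224 input with 16×16 patches (assuming the ViT-B/16-style defaults documented below):

```python
# Patch-grid arithmetic for a 224x224 image split into 16x16 patches.
img_size, patch_size, hidden_size = 224, 16, 768
num_patches = (img_size // patch_size) ** 2  # 14 * 14 = 196 patches
seq_len = num_patches + 1                    # +1 for the [CLS] token -> 197 tokens
# The transformer encoder then processes a (seq_len, hidden_size) sequence per image.
```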

The ViT model was introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" and has shown strong performance on image classification benchmarks.
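
A minimal usage sketch (assuming the `jimm.models.vit` import path documented on this page and the ViT-B/16-style defaults of `__init__`; illustrative, not an official recipe):

```python
import jax.numpy as jnp
from flax import nnx
from jimm.models.vit import VisionTransformer

# Build a model with the documented defaults (12 layers, 12 heads, hidden size 768).
model = VisionTransformer(num_classes=1000, rngs=nnx.Rngs(0))

# Inputs are channels-last images: [batch, height, width, channels].
images = jnp.ones((2, 224, 224, 3), dtype=jnp.float32)
logits = model(images)  # expected shape: (2, 1000)
```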

jimm.models.vit.VisionTransformer

Bases: Module

Vision Transformer (ViT) model for image classification.

This implements the Vision Transformer as described in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".

Source code in src/jimm/models/vit.py
class VisionTransformer(nnx.Module):
    """Vision Transformer (ViT) model for image classification.

    This implements the Vision Transformer as described in the paper
    "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
    """

    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        use_quick_gelu: bool = False,
        do_classification: bool = True,
        dtype: DTypeLike = jnp.float32,
        param_dtype: DTypeLike = jnp.float32,
        rngs: nnx.Rngs = nnx.Rngs(0),
        mesh: Mesh | None = None,
    ) -> None:
        """Initialize a Vision Transformer.

        Args:
            num_classes (int): Number of output classes. Defaults to 1000.
            in_channels (int): Number of input channels. Defaults to 3.
            img_size (int): Size of the input image (assumed square). Defaults to 224.
            patch_size (int): Size of each patch (assumed square). Defaults to 16.
            num_layers (int): Number of transformer layers. Defaults to 12.
            num_heads (int): Number of attention heads. Defaults to 12.
            mlp_dim (int): Size of the MLP dimension. Defaults to 3072.
            hidden_size (int): Size of the hidden dimension. Defaults to 768.
            dropout_rate (float): Dropout rate. Defaults to 0.1.
            use_quick_gelu (bool): Whether to use quickgelu instead of gelu. Defaults to False.
            do_classification (bool): Whether to include the final classification head. Defaults to True.
            dtype (DTypeLike): Data type for computations. Defaults to jnp.float32.
            param_dtype (DTypeLike): Data type for parameters. Defaults to jnp.float32.
            rngs (nnx.Rngs): Random number generator keys. Defaults to nnx.Rngs(0).
            mesh (Mesh|None): Optional JAX device mesh for parameter sharding. Defaults to None.
        """
        self.do_classification = do_classification
        self.encoder = VisionTransformerBase(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            dropout_rate=dropout_rate,
            use_quick_gelu=use_quick_gelu,
            use_pre_norm=False,
            use_patch_bias=True,
            layernorm_epsilon=1e-12,
            rngs=rngs,
            dtype=dtype,
            param_dtype=param_dtype,
            mesh=mesh,
        )

        if self.do_classification:
            self.classifier = nnx.Linear(
                hidden_size,
                num_classes,
                dtype=dtype,
                param_dtype=param_dtype,
                rngs=rngs,
                kernel_init=sharded_init(nnx.initializers.xavier_uniform(), P(None, "model"), mesh),
                bias_init=sharded_init(nnx.initializers.zeros_init(), P("model"), mesh),
            )

    def __call__(self, x: Float[Array, "batch height width channels"]) -> Float[Array, "batch num_classes"]:
        """Forward pass of the Vision Transformer.

        Args:
            x (Float[Array, "batch height width channels"]): Input tensor with shape [batch, height, width, channels]

        Returns:
            Float[Array, "batch num_classes"]: Output logits with shape [batch, num_classes]
        """
        x = self.encoder(x)
        if self.do_classification:
            return self.classifier(x)
        return x

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, use_pytorch: bool = False, mesh: Mesh | None = None, dtype: DTypeLike = jnp.float32) -> "VisionTransformer":
        """Load a pretrained Vision Transformer from a local path or HuggingFace Hub.

        Args:
            model_name_or_path (str): Path to local weights or HuggingFace model ID.
            use_pytorch (bool): Whether to load from PyTorch weights. Defaults to False.
            mesh (Mesh|None): Optional device mesh for parameter sharding. Defaults to None.
            dtype (DTypeLike): Data type for computations. Defaults to jnp.float32.

        Returns:
            VisionTransformer: Initialized Vision Transformer with pretrained weights
        """
        params_fstate, config_dict = load_params_and_config(model_name_or_path, use_pytorch)

        config: dict[str, Any] | None = config_dict

        hidden_size_val: int
        num_classes_val: int
        num_layers_val: int
        num_heads_val: int
        mlp_dim_val: int
        patch_size_val: int
        img_size_val: int
        use_quick_gelu_val: bool = False

        if config:
            hidden_size_val = config["hidden_size"]
            num_classes_val = len(config["id2label"]) if "id2label" in config else config.get("num_labels", 1000)
            num_layers_val = config["num_hidden_layers"]
            num_heads_val = config["num_attention_heads"]
            mlp_dim_val = config["intermediate_size"]
            patch_size_val = config["patch_size"]
            img_size_val = config["image_size"]
            if "hidden_act" in config and config["hidden_act"] == "quick_gelu":
                use_quick_gelu_val = True
            elif "hidden_act" in config and config["hidden_act"] != "gelu":
                print(f"Warning: Unexpected hidden_act '{config['hidden_act']}' in config, defaulting to standard GELU.")

        elif not use_pytorch and (os.path.exists(model_name_or_path) and os.path.isfile(model_name_or_path)):
            hidden_size_val = params_fstate["vit.embeddings.cls_token"].shape[-1]
            num_classes_val = params_fstate["classifier.bias"].shape[0]

            max_layer_idx = -1
            for k in params_fstate:
                if k.startswith("vit.encoder.layer."):
                    max_layer_idx = max(max_layer_idx, int(k.split(".")[3]))
            num_layers_val = max_layer_idx + 1

            mlp_dim_val = params_fstate["vit.encoder.layer.0.intermediate.dense.weight"].shape[0]

            assumed_head_dim = 64
            num_heads_val = hidden_size_val // assumed_head_dim

            patch_kernel_shape = params_fstate["vit.embeddings.patch_embeddings.projection.weight"].shape
            patch_size_val = patch_kernel_shape[2]

            num_patches_from_embeddings = params_fstate["vit.embeddings.position_embeddings"].shape[1] - 1
            img_size_dim = int(jnp.sqrt(num_patches_from_embeddings))
            img_size_val = img_size_dim * patch_size_val
        else:
            raise ValueError(f"Could not load or infer configuration for {model_name_or_path}")

        if not all(v is not None for v in [hidden_size_val, num_classes_val, num_layers_val, num_heads_val, mlp_dim_val, patch_size_val, img_size_val]):
            raise ValueError(f"One or more configuration parameters could not be determined for {model_name_or_path}")

        model = cls(
            num_classes=num_classes_val,
            img_size=img_size_val,
            patch_size=patch_size_val,
            num_layers=num_layers_val,
            num_heads=num_heads_val,
            mlp_dim=mlp_dim_val,
            hidden_size=hidden_size_val,
            use_quick_gelu=use_quick_gelu_val,
            mesh=mesh,
            dtype=dtype,
            param_dtype=dtype,
        )

        flax_model_params_fstate = dict(nnx.to_flat_state(nnx.state(model, nnx.Param)))

        def hf_param_name(name: str) -> str:
            return "weight" if name in ["kernel", "scale"] else name

        hidden_size_per_head = hidden_size_val // num_heads_val

        # Map Flax parameter paths (key tuples) to HuggingFace checkpoint keys;
        # per-layer entries are appended in the loop below.
        mapping_list = [
            (("encoder", "cls_token"), ("vit", "embeddings", "cls_token")),
            (("encoder", "position_embeddings"), ("vit", "embeddings", "position_embeddings")),
            (("encoder", "patch_embeddings", "kernel"), ("vit", "embeddings", "patch_embeddings", "projection", "weight")),
            (("encoder", "patch_embeddings", "bias"), ("vit", "embeddings", "patch_embeddings", "projection", "bias")),
            (("classifier", "kernel"), ("classifier", "weight")),
            (("classifier", "bias"), ("classifier", "bias")),
            (("encoder", "ln_post", "scale"), ("vit", "layernorm", "weight")),
            (("encoder", "ln_post", "bias"), ("vit", "layernorm", "bias")),
        ]

        for i in range(num_layers_val):
            flax_base = ("encoder", "transformer", "blocks", "layers", i)
            hf_base = ("vit", "encoder", "layer", str(i))
            mapping_list.extend(
                [(flax_base + ("attn", y_type, p_name), hf_base + ("attention", "attention", y_type, hf_param_name(p_name))) for p_name in ["kernel", "bias"] for y_type in ["key", "value", "query"]]
            )
            mapping_list.extend([(flax_base + ("attn", "out", p_name), hf_base + ("attention", "output", "dense", hf_param_name(p_name))) for p_name in ["kernel", "bias"]])
            mapping_list.extend(
                [
                    (flax_base + ("mlp", "layers", y1_idx, p_name), hf_base + (y2_name, "dense", hf_param_name(p_name)))
                    for p_name in ["kernel", "bias"]
                    for y1_idx, y2_name in [(0, "intermediate"), (3, "output")]
                ]
            )
            mapping_list.extend(
                [
                    (flax_base + (norm_flax, p_name), hf_base + (norm_hf, hf_param_name(p_name)))
                    for p_name in ["scale", "bias"]
                    for norm_flax, norm_hf in [("norm1", "layernorm_before"), ("norm2", "layernorm_after")]
                ]
            )
        params_name_mapping = dict(mapping_list)
        nonvisited = set(flax_model_params_fstate.keys())
        used_hf_keys: Set[str] = set()

        for flax_dst_key_tuple, hf_src_key_tuple in params_name_mapping.items():
            assert flax_dst_key_tuple in flax_model_params_fstate, flax_dst_key_tuple
            hf_src_key_as_string = ".".join(hf_src_key_tuple)
            used_hf_keys.add(hf_src_key_as_string)
            assert hf_src_key_as_string in params_fstate, f"HF key '{hf_src_key_as_string}' (from Flax key {flax_dst_key_tuple}) not found in loaded weights."
            nonvisited.remove(flax_dst_key_tuple)
            src_value: Array = params_fstate[hf_src_key_as_string]

            dst_value_obj = flax_model_params_fstate[flax_dst_key_tuple]
            original_param_sharding = dst_value_obj.value.sharding

            # Rearrange weights from HuggingFace/PyTorch layouts into the Flax conventions
            # used here: the patch-embedding Conv kernel goes from (out, in, kh, kw) to
            # (kh, kw, in, out), attention projections are transposed and split per head,
            # and remaining Linear weights are transposed from (out, in) to (in, out).
            if flax_dst_key_tuple == ("encoder", "patch_embeddings", "kernel"):
                src_value = jnp.transpose(src_value, (2, 3, 1, 0))
            elif hf_src_key_tuple[-1] == "weight" and hf_src_key_tuple[-2] in ("key", "value", "query"):
                src_value = jnp.transpose(src_value, (1, 0))
                src_value = src_value.reshape((hidden_size_val, num_heads_val, hidden_size_per_head))
            elif hf_src_key_tuple[-1] == "bias" and hf_src_key_tuple[-2] in ("key", "value", "query"):
                src_value = src_value.reshape((num_heads_val, hidden_size_per_head))
            elif hf_src_key_tuple[-4:] == ("attention", "output", "dense", "weight"):
                src_value = jnp.transpose(src_value, (1, 0))
                src_value = src_value.reshape((num_heads_val, hidden_size_per_head, hidden_size_val))
            elif hf_src_key_tuple[-1] == "weight" and src_value.ndim == 2:
                src_value = jnp.transpose(src_value, (1, 0))

            assert src_value.shape == dst_value_obj.value.shape, f"Shape mismatch for {flax_dst_key_tuple} (Flax) vs {hf_src_key_as_string} (HF): {dst_value_obj.value.shape} != {src_value.shape}"

            sharded_new_value: Array = jax.device_put(src_value, original_param_sharding)
            dst_value_obj.value = sharded_new_value

            assert jnp.allclose(dst_value_obj.value.mean(), src_value.mean()), (dst_value_obj.value.mean(), src_value.mean())

        assert len(nonvisited) == 0, f"Some Flax model parameters were not visited: {nonvisited}"

        leftover_hf_keys = set(params_fstate.keys()) - used_hf_keys
        known_unused_hf_buffer_keys = {
            "text_model.embeddings.position_ids",
            "vision_model.embeddings.position_ids",
        }
        unexpected_leftover_hf_keys = leftover_hf_keys - known_unused_hf_buffer_keys

        assert len(unexpected_leftover_hf_keys) == 0, f"Some unexpected HuggingFace checkpoint parameters were not used: {sorted(list(unexpected_leftover_hf_keys))}"
        nnx.update(model, nnx.from_flat_state(flax_model_params_fstate))

        del flax_model_params_fstate
        del params_fstate
        return model

__call__(x)

Forward pass of the Vision Transformer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Float[Array, 'batch height width channels']` | Input tensor with shape [batch, height, width, channels]. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Float[Array, 'batch num_classes']` | Output logits with shape [batch, num_classes]. |
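
A short sketch of both output modes (illustrative shapes in the comments; the headless feature shape follows from the classifier consuming a `hidden_size`-dimensional vector):

```python
import jax.numpy as jnp
from flax import nnx
from jimm.models.vit import VisionTransformer

images = jnp.ones((4, 224, 224, 3), dtype=jnp.float32)  # [batch, height, width, channels]

# With the classification head (default): per-class logits.
clf = VisionTransformer(num_classes=1000, rngs=nnx.Rngs(0))
logits = clf(images)  # expected shape: (4, 1000)

# Without the head: the encoder output is returned as-is,
# expected to be a (4, hidden_size) feature vector per image.
backbone = VisionTransformer(do_classification=False, rngs=nnx.Rngs(0))
features = backbone(images)
```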


__init__(num_classes=1000, in_channels=3, img_size=224, patch_size=16, num_layers=12, num_heads=12, mlp_dim=3072, hidden_size=768, dropout_rate=0.1, use_quick_gelu=False, do_classification=True, dtype=jnp.float32, param_dtype=jnp.float32, rngs=nnx.Rngs(0), mesh=None)

Initialize a Vision Transformer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_classes` | `int` | Number of output classes. | `1000` |
| `in_channels` | `int` | Number of input channels. | `3` |
| `img_size` | `int` | Size of the input image (assumed square). | `224` |
| `patch_size` | `int` | Size of each patch (assumed square). | `16` |
| `num_layers` | `int` | Number of transformer layers. | `12` |
| `num_heads` | `int` | Number of attention heads. | `12` |
| `mlp_dim` | `int` | Size of the MLP dimension. | `3072` |
| `hidden_size` | `int` | Size of the hidden dimension. | `768` |
| `dropout_rate` | `float` | Dropout rate. | `0.1` |
| `use_quick_gelu` | `bool` | Whether to use QuickGELU instead of GELU. | `False` |
| `do_classification` | `bool` | Whether to include the final classification head. | `True` |
| `dtype` | `DTypeLike` | Data type for computations. | `jnp.float32` |
| `param_dtype` | `DTypeLike` | Data type for parameters. | `jnp.float32` |
| `rngs` | `nnx.Rngs` | Random number generator keys. | `nnx.Rngs(0)` |
| `mesh` | `Mesh \| None` | Optional JAX device mesh for parameter sharding. | `None` |
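
A hedged sketch of a smaller, non-standard configuration (the sizes below are illustrative, not a named ViT variant); `hidden_size` should stay divisible by `num_heads`:

```python
import jax.numpy as jnp
from flax import nnx
from jimm.models.vit import VisionTransformer

small_vit = VisionTransformer(
    num_classes=10,
    img_size=64,      # 64x64 inputs
    patch_size=8,     # an 8x8 grid of patches
    num_layers=4,
    num_heads=4,      # head dim = 128 / 4 = 32
    mlp_dim=512,
    hidden_size=128,
    dropout_rate=0.0,
    rngs=nnx.Rngs(42),
)
logits = small_vit(jnp.ones((1, 64, 64, 3)))  # expected shape: (1, 10)
```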

from_pretrained(model_name_or_path, use_pytorch=False, mesh=None, dtype=jnp.float32) classmethod

Load a pretrained Vision Transformer from a local path or HuggingFace Hub.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_name_or_path` | `str` | Path to local weights or HuggingFace model ID. | *required* |
| `use_pytorch` | `bool` | Whether to load from PyTorch weights. | `False` |
| `mesh` | `Mesh \| None` | Optional device mesh for parameter sharding. | `None` |
| `dtype` | `DTypeLike` | Data type for computations. | `jnp.float32` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `VisionTransformer` | `VisionTransformer` | Initialized Vision Transformer with pretrained weights. |
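
A loading sketch; the model ID below is an assumption (any ViT checkpoint that exposes a compatible HuggingFace config should work), and `use_pytorch=True` may be needed when a checkpoint only publishes PyTorch weights:

```python
import jax.numpy as jnp
from jimm.models.vit import VisionTransformer

# Assumed HuggingFace model ID for a ViT-B/16 ImageNet classifier.
model = VisionTransformer.from_pretrained(
    "google/vit-base-patch16-224",
    dtype=jnp.float32,
)

images = jnp.ones((1, 224, 224, 3), dtype=jnp.float32)
logits = model(images)  # (1, num_classes), with num_classes taken from the checkpoint config
```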
