#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/minicpmv4_6/modular_minicpmv4_6.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_minicpmv4_6.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 OpenBMB and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections.abc import Callable
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN, gelu_pytorch_tanh
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    BaseModelOutputWithPooling,
    CausalLMOutputWithPast,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import can_return_tuple, is_flash_attention_requested, merge_with_config_defaults
from ...utils.import_utils import torch_compilable_check
from ...utils.output_capturing import capture_outputs
from ..auto import AutoModel
from .configuration_minicpmv4_6 import MiniCPMV4_6Config, MiniCPMV4_6VisionConfig


class MiniCPMV4_6VisionEmbeddings(nn.Module):
    """
    Variable-resolution patch embeddings, adapted from `siglip.modeling_siglip.SiglipVisionEmbeddings`.

    Following [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304),
    images keep their native aspect ratio instead of being resized to a common square size. The learned position
    table covers a `num_patches_per_side x num_patches_per_side` grid; for an image whose patch grid is `(h, w)`,
    each patch's fractional row/column coordinate is bucketized onto that grid to pick its position embedding.
    """

    def __init__(self, config: MiniCPMV4_6VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patchify: kernel == stride == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        target_sizes: torch.IntTensor | None = None,
    ) -> torch.Tensor:
        """
        Embed packed pixel patches and add bucketized position embeddings.

        Args:
            pixel_values: input to the patchifying `nn.Conv2d`.
            target_sizes: iterable of `(h, w)` patch-grid sizes, one per packed image; the sum of `h * w`
                must match the number of patches produced by the convolution.

        Returns:
            Patch embeddings plus position embeddings, `(batch, num_patches, embed_dim)`.
        """
        # (batch, embed_dim, gh, gw) -> (batch, gh * gw, embed_dim)
        embeddings = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        grid = self.num_patches_per_side
        step = 1 / grid
        # Interior bucket edges 1/grid, 2/grid, ..., (grid-1)/grid of the learned position grid.
        boundaries = torch.arange(step, 1.0, step)
        table_device = self.position_embedding.weight.device

        per_image = []
        for size in target_sizes:
            rows, cols = size[0], size[1]

            # Fractional coordinates 0, 1/rows, 2/rows, ... (the epsilon keeps 1.0 itself excluded).
            row_buckets = torch.bucketize(torch.arange(0, 1 - 1e-6, 1 / rows), boundaries, right=True)
            col_buckets = torch.bucketize(torch.arange(0, 1 - 1e-6, 1 / cols), boundaries, right=True)

            # Row-major flat index into the grid x grid position table.
            flat_ids = (row_buckets[:, None] * grid + col_buckets).flatten().to(table_device)
            per_image.append(self.position_embedding(flat_ids))

        return embeddings + torch.concat(per_image, dim=0).unsqueeze(0)


class MiniCPMV4_6VisionMLP(nn.Module):
    """Position-wise feed-forward block: `fc1` -> activation (`config.hidden_act`) -> `fc2`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        # hidden_size -> intermediate_size -> hidden_size
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`: the tensor goes from
    `(batch, num_key_value_heads, seqlen, head_dim)` to `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    With `n_rep == 1` the input tensor object is returned as-is.
    """
    # Unpacking first also enforces a rank-4 input, even when no repetition is needed.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Broadcast a new repeat axis next to the head axis (no copy), then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Reference (non-fused) scaled dot-product attention.

    Key/value heads are repeated `module.num_key_value_groups` times. Scores are computed as
    `query @ key^T * scaling`; an additive `attention_mask` is applied if given. Softmax runs in float32 and the
    result is cast back to `query.dtype`. Dropout is active only while `module.training` is true.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has its dims 1 and 2 swapped relative to
        `attn_weights @ value` and is made contiguous.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs


class MiniCPMV4_6VisionAttention(nn.Module):
    """
    Bidirectional multi-head self-attention over packed variable-length sequences.

    Several sequences are packed along the sequence axis and `cu_seqlens` holds their cumulative boundaries.
    When flash attention is requested, the boundaries are handed to the attention function
    (`cu_seq_lens_q/k`, `max_length_q/k`). Any other backend splits the packed tensors at those boundaries and
    attends within each chunk separately, so tokens never attend across sequences.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.is_causal = False
        if self.head_dim * self.num_heads != self.dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.k_proj = nn.Linear(self.dim, self.dim)
        self.v_proj = nn.Linear(self.dim, self.dim)
        self.q_proj = nn.Linear(self.dim, self.dim)
        self.out_proj = nn.Linear(self.dim, self.dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Input shape: Batch x Time x Channel.

        Args:
            hidden_states: packed tokens, `(batch, seq_len, dim)`.
            cu_seqlens: cumulative sequence boundaries within the packed `seq_len` axis.
            max_seqlen: length of the longest packed sequence (only used by flash attention).
            attention_mask: accepted for interface compatibility; the attention call always receives `None`.

        Returns:
            `(attn_output, None)`, with `attn_output` shaped like `hidden_states`.
        """
        leading_dims = hidden_states.shape[:-1]
        per_head_shape = (*leading_dims, -1, self.head_dim)

        # (batch, seq, dim) -> (batch, heads, seq, head_dim), computed in q, k, v order.
        query_states, key_states, value_states = (
            proj(hidden_states).view(per_head_shape).transpose(1, 2)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        # Arguments common to both code paths below.
        common_args = {
            "attention_mask": None,
            "scaling": self.scaling,
            "dropout": self.attention_dropout if self.training else 0.0,
            "is_causal": False,
        }

        if is_flash_attention_requested(self.config):
            # Flash Attention: Use cu_seqlens for variable length attention
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                **common_args,
                **kwargs,
            )
        else:
            # Other implementations: split the sequence axis (dim 2 after the head transpose) and
            # attend inside every chunk independently.
            chunk_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            q_chunks, k_chunks, v_chunks = (
                torch.split(states, chunk_lengths, dim=2) for states in (query_states, key_states, value_states)
            )
            chunk_outputs = []
            for q, k, v in zip(q_chunks, k_chunks, v_chunks):
                chunk_outputs.append(attention_interface(self, q, k, v, **common_args, **kwargs)[0])
            # Attention outputs are (batch, seq, heads, head_dim), so chunks are rejoined on dim 1.
            attn_output = torch.cat(chunk_outputs, dim=1)

        merged_heads = attn_output.reshape(*leading_dims, -1).contiguous()
        return self.out_proj(merged_heads), None


class MiniCPMV4_6VisionEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: LayerNorm -> self-attention -> residual, then LayerNorm -> MLP -> residual."""

    def __init__(self, config: MiniCPMV4_6VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = MiniCPMV4_6VisionAttention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = MiniCPMV4_6VisionMLP(config)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Attention sub-block; extra kwargs (e.g. `cu_seqlens`, `max_seqlen`) are forwarded to the attention.
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        return hidden_states + self.mlp(self.layer_norm2(hidden_states))


class MiniCPMV4_6VisionEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` [`MiniCPMV4_6VisionEncoderLayer`] layers applied sequentially."""

    def __init__(self, config: MiniCPMV4_6VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([MiniCPMV4_6VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    # Ignore copy
    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Every layer receives the same attention mask and extra kwargs (e.g. `cu_seqlens`).
        states = inputs_embeds
        for layer in self.layers:
            states = layer(states, attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=states)


class MiniCPMV4_6ViTWindowAttentionMerger(nn.Module):
    """
    Window-attention token merger applied inside the vision encoder.

    The tokens of each image are regrouped into non-overlapping `window_kernel_size` windows and attend only
    within their window. They are then restored to raster order (with a residual connection). Finally, each
    window is collapsed into a single token: an MLP runs over the window's concatenated features, and the mean of
    the window's tokens is added as a residual. With a `(2, 2)` window this divides the token count by 4.
    """

    def __init__(self, config: MiniCPMV4_6VisionConfig):
        super().__init__()
        self.window_kernel_size = tuple(config.window_kernel_size)
        self.embed_dim = config.hidden_size

        self.self_attn = MiniCPMV4_6VisionAttention(config)

        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        self.pre_norm = nn.LayerNorm(config.window_hidden_size, eps=config.layer_norm_eps)
        self.linear_1 = nn.Linear(config.window_hidden_size, config.window_intermediate_size, bias=True)
        self.act = gelu_pytorch_tanh
        self.linear_2 = nn.Linear(config.window_intermediate_size, self.embed_dim, bias=True)

    def _init_weights(self):
        """Block-diagonal normal init: preserves the structural prior that each
        2x2 window patch is processed independently at initialization.

        Uses ``init.*`` helpers so the ``_is_hf_initialized`` guard is
        respected.  The ``linear_1`` block-diagonal init writes to *slices*
        which do not inherit the flag, so we guard the entire block manually.
        """
        for proj in (self.self_attn.q_proj, self.self_attn.k_proj, self.self_attn.v_proj, self.self_attn.out_proj):
            init.normal_(proj.weight)
            init.zeros_(proj.bias)

        for ln in (self.layer_norm1, self.pre_norm):
            init.ones_(ln.weight)
            init.zeros_(ln.bias)

        hidden_size = self.embed_dim
        # linear_1 is treated as 4 diagonal blocks, one per token of the window; off-diagonal blocks stay zero.
        intermediate_size = self.linear_1.weight.shape[0] // 4
        if not getattr(self.linear_1.weight, "_is_hf_initialized", False):
            self.linear_1.weight.data.zero_()
            for i in range(4):
                self.linear_1.weight.data[
                    i * intermediate_size : (i + 1) * intermediate_size,
                    i * hidden_size : (i + 1) * hidden_size,
                ].normal_()
            self.linear_1.weight._is_hf_initialized = True
        init.normal_(self.linear_1.bias, std=1e-6)

        init.normal_(self.linear_2.weight, std=0.25)
        init.normal_(self.linear_2.bias, std=1e-6)

    def get_window_index(self, target_sizes):
        """
        Build the window-major token permutation for all packed images.

        Args:
            target_sizes: iterable of `(height, width)` patch grids; each side must be divisible by the
                corresponding window side.

        Returns:
            `(window_index, cu_seqlens, max_seqlens)`:
            - `window_index` (CPU): for every image in order, its raster token indices (offset by the tokens of
              the preceding images), reordered so that each window's tokens are contiguous.
            - `cu_seqlens` (CPU, int32): cumulative boundaries of every window in that reordered sequence.
            - `max_seqlens`: `window_h * window_w`, the length of every window.

        Raises:
            ValueError: if a grid is not divisible by the window size.
        """
        window_h, window_w = self.window_kernel_size
        max_seqlens = window_h * window_w

        window_index_list = []
        cu_seqlens = [0]
        token_offset = 0

        for height, width in target_sizes:
            # Cast 0-d device tensors to Python ints so that the whole function
            # stays CPU-side integer arithmetic. `torch.arange` without `device=`
            # always returns on CPU; mixing with a device-bound `token_offset`
            # raises in strict PyTorch versions (2.10+).
            height, width = int(height), int(width)
            if height % window_h != 0 or width % window_w != 0:
                raise ValueError(
                    f"height={height}, width={width} must be divisible by window size ({window_h}, {window_w})"
                )
            index = torch.arange(height * width).reshape(height, width)
            num_windows_h = height // window_h
            num_windows_w = width // window_w
            num_windows = num_windows_h * num_windows_w

            # (nwh, wh, nww, ww) -> (nwh, nww, wh, ww): one row of token indices per window.
            index = index.reshape(num_windows_h, window_h, num_windows_w, window_w)
            index = index.permute(0, 2, 1, 3).reshape(num_windows, window_h * window_w)

            window_index_list.append(index.reshape(-1) + token_offset)

            cu_this = torch.arange(1, num_windows + 1) * (window_h * window_w) + cu_seqlens[-1]
            cu_seqlens.extend(cu_this.tolist())

            token_offset += height * width

        window_index = torch.cat(window_index_list)
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)

        return window_index, cu_seqlens, max_seqlens

    def forward(
        self,
        hidden_states: torch.Tensor,
        target_sizes: torch.IntTensor,
        cu_seqlens: torch.Tensor | None = None,
    ):
        """
        Window-attend and merge the packed tokens of every image.

        Args:
            hidden_states: packed tokens; image `i` occupies `hidden_states[0, cu_seqlens[i]:cu_seqlens[i + 1], :]`.
            target_sizes: `(num_images, 2)` patch grid `(height, width)` of each packed image.
            cu_seqlens: cumulative per-image token boundaries in the packed sequence. When `None` they are derived
                from `target_sizes` as `[0, h0*w0, h0*w0 + h1*w1, ...]`.

        Returns:
            Merged tokens of all images concatenated, `(1, num_merged_tokens, embed_dim)`.
        """
        if cu_seqlens is None:
            # Previously a `None` default crashed at `cu_seqlens[batch_idx]`; derive the boundaries instead.
            cu_seqlens = F.pad(
                torch.cumsum(target_sizes[:, 0] * target_sizes[:, 1], dim=0, dtype=torch.int32), (1, 0)
            )

        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        device = hidden_states.device

        # Reorder tokens window-major, attend within each window, then undo the permutation.
        window_index, window_cu_seqlens, window_max_seqlens = self.get_window_index(target_sizes)
        window_index = window_index.to(device)

        hidden_states = hidden_states[:, window_index, :]
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=window_cu_seqlens.to(device),
            max_seqlen=window_max_seqlens,
        )
        hidden_states = hidden_states[:, torch.argsort(window_index), :]
        hidden_states = residual + hidden_states

        batch_size, _ = target_sizes.shape
        window_h, window_w = self.window_kernel_size
        all_pixel_values = []
        for batch_idx in range(batch_size):
            height, width = target_sizes[batch_idx]
            patch = hidden_states[0, cu_seqlens[batch_idx] : cu_seqlens[batch_idx + 1], :].squeeze(0)

            embed_dim = patch.shape[-1]
            merged_h, merged_w = height // window_h, width // window_w
            # (merged_h, merged_w, window_h, window_w, embed_dim): tokens of each window grouped together.
            patch_5d = patch.view(merged_h, window_h, merged_w, window_w, embed_dim).permute(0, 2, 1, 3, 4)
            # MLP input concatenates a window's tokens along features; the residual averages them.
            hidden_state = patch_5d.reshape(merged_h * merged_w, window_h * window_w * embed_dim)
            residual = patch_5d.reshape(merged_h * merged_w, window_h * window_w, embed_dim).mean(dim=1)

            hidden_state = self.pre_norm(hidden_state)
            hidden_state = self.linear_1(hidden_state)
            hidden_state = self.act(hidden_state)
            hidden_state = self.linear_2(hidden_state)

            all_pixel_values.append(hidden_state + residual)

        new_hidden_states = torch.concat(all_pixel_values, dim=0).unsqueeze(0)
        return new_hidden_states


class MiniCPMV4_6VisionPreTrainedModel(PreTrainedModel):
    """Base class for the MiniCPM-V 4.6 vision tower, bound to `MiniCPMV4_6VisionConfig`."""

    config_class = MiniCPMV4_6VisionConfig
    main_input_name = "pixel_values"
    _input_embed_layer = "patch_embedding"

    # Capability flags.
    _supports_flash_attn = True
    _supports_sdpa = True
    supports_gradient_checkpointing = True

    # Output name -> module class whose outputs are recorded under that name.
    _can_record_outputs = {
        "hidden_states": MiniCPMV4_6VisionEncoderLayer,
        "attentions": MiniCPMV4_6VisionAttention,
    }


class MiniCPMV4_6VisionModel(MiniCPMV4_6VisionPreTrainedModel):
    """
    Vision encoder over NaViT-packed patches. Optionally applies the window-attention merger (`vit_merger`)
    right after encoder layer `config.insert_layer_id`.
    """

    def __init__(self, config: MiniCPMV4_6VisionConfig):
        super().__init__(config)
        embed_dim = config.hidden_size

        self.embeddings = MiniCPMV4_6VisionEmbeddings(config)
        self.encoder = MiniCPMV4_6VisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.vit_merger = MiniCPMV4_6ViTWindowAttentionMerger(config)
        self.post_init()

    def get_downsampled_inputs(
        self, target_sizes: torch.Tensor, max_seqlens: int, device: torch.device, **kwargs
    ) -> tuple[dict[str, Any], torch.Tensor, torch.Tensor]:
        """
        Recompute the attention inputs after the `vit_merger` folded each 2x2 window into one token.

        Halves each grid side, divides the longest sequence length by 4, and rebuilds the per-image cumulative
        boundaries on `device`.

        Returns:
            `(attn_kwargs, target_sizes, cu_seqlens)` describing the downsampled sequence.
        """
        # NOTE: intentionally not checking for shapes as this is expensive to call `.any()`
        target_sizes = target_sizes // 2
        max_seqlens = max_seqlens // 4

        cu_seqlens = F.pad(
            torch.cumsum(target_sizes[:, 0] * target_sizes[:, 1], dim=0, dtype=torch.int32).to(device), (1, 0)
        )

        downsampled_kwargs = {
            "attention_mask": None,
            "cu_seqlens": cu_seqlens,
            "max_seqlen": max_seqlens,
            **kwargs,
        }
        return downsampled_kwargs, target_sizes, cu_seqlens

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        pixel_values,
        target_sizes: torch.IntTensor | None = None,
        use_vit_merger: bool = True,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        target_sizes (`torch.IntTensor` of shape `(batch_size, 2)`, *optional*):
            Patch grid sizes `(h, w)` for computing position embeddings.
        use_vit_merger (`bool`, *optional*, defaults to `True`):
            Whether to apply the ViT window-attention merger after the encoder.
        """
        # Position embeddings and sequence boundaries are both derived from the grid sizes.
        if target_sizes is None:
            raise ValueError("`target_sizes` is required to compute position embeddings and sequence boundaries.")

        hidden_states = self.embeddings(pixel_values, target_sizes=target_sizes)

        # Per-image boundaries in the packed sequence: [0, h0*w0, h0*w0 + h1*w1, ...].
        cu_seqlens = F.pad(
            torch.cumsum(target_sizes[:, 0] * target_sizes[:, 1], dim=0, dtype=torch.int32).to(hidden_states.device),
            (1, 0),
        )
        max_seqlens = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

        attn_kwargs = {
            "attention_mask": None,
            "cu_seqlens": cu_seqlens,
            "max_seqlen": max_seqlens,
            **kwargs,
        }

        insert_layer_id = self.config.insert_layer_id if use_vit_merger else -1
        if use_vit_merger and insert_layer_id >= 0:
            num_layers = len(self.encoder.layers)
            # An out-of-range index would silently skip the merger while callers expect downsampled tokens.
            if insert_layer_id >= num_layers:
                raise ValueError(
                    f"`insert_layer_id` ({insert_layer_id}) must be smaller than the number of encoder layers "
                    f"({num_layers}), otherwise the ViT merger is never applied."
                )
            for layer_index, encoder_layer in enumerate(self.encoder.layers):
                hidden_states = encoder_layer(hidden_states, **attn_kwargs)
                if layer_index == insert_layer_id:
                    hidden_states = self.vit_merger(hidden_states, target_sizes, cu_seqlens)

                    # NOTE: Downsampled hidden states, and therefore other kwargs should also!
                    attn_kwargs, target_sizes, cu_seqlens = self.get_downsampled_inputs(
                        target_sizes=target_sizes, max_seqlens=max_seqlens, device=hidden_states.device, **kwargs
                    )
        else:
            encoder_outputs = self.encoder(inputs_embeds=hidden_states, **attn_kwargs)
            hidden_states = encoder_outputs.last_hidden_state

        last_hidden_state = self.post_layernorm(hidden_states)

        return BaseModelOutputWithPooling(last_hidden_state=last_hidden_state)


class MiniCPMV4_6DownsampleMLP(nn.Module):
    """
    Projection for a group of 4 concatenated tokens: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        hidden_size: width of a single token before concatenation; the expected input width is `4 * hidden_size`.
        llm_embed_dim: output width.
    """

    def __init__(self, hidden_size: int, llm_embed_dim: int):
        super().__init__()
        # The factor 4 is the width of 4 tokens concatenated along the feature axis, presumably one 2x2 merge
        # window assembled by the caller (TODO confirm against `merge_kernel_size`).
        wide = 4 * hidden_size

        self.pre_norm = nn.LayerNorm(wide, eps=1e-6)
        self.linear_1 = nn.Linear(wide, wide, bias=True)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(wide, llm_embed_dim, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalize, then flatten all leading dims into rows of width `4 * hidden_size`.
        rows = self.pre_norm(hidden_states).view(-1, self.linear_1.in_features)
        return self.linear_2(self.act(self.linear_1(rows)))


class MiniCPMV4_6Merger(nn.Module):
    """
    Projects vision tokens into the language-model embedding space.

    Each image's patch grid is merged `merger_times` times with `merge_kernel_size` windows. Every round
    concatenates the tokens of a window along the feature axis and applies one `MiniCPMV4_6DownsampleMLP`.
    The first `merger_times - 1` rounds keep the vision hidden size; the last round maps to
    `text_config.hidden_size`.
    """

    def __init__(self, config: MiniCPMV4_6Config):
        super().__init__()

        self.merge_kernel_size = tuple(config.merge_kernel_size)
        self.merger_times = config.merger_times
        hidden_size = config.vision_config.hidden_size
        llm_embed_dim = config.text_config.hidden_size
        # Downsample `self.merger_times - 1` times and finally apply projection into LLM space
        mlps = [MiniCPMV4_6DownsampleMLP(hidden_size, hidden_size) for _ in range(self.merger_times - 1)]
        mlps.append(MiniCPMV4_6DownsampleMLP(hidden_size, llm_embed_dim))
        self.mlp = nn.ModuleList(mlps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        target_sizes: torch.IntTensor,
    ) -> list[torch.Tensor]:
        """
        Merge and project the tokens of every packed image.

        Args:
            hidden_states: packed tokens; images are stored back to back in `hidden_states[0]`, image `i`
                taking `h_i * w_i` consecutive rows.
            target_sizes: `(num_images, 2)` patch grid `(h, w)` of each image.

        Returns:
            One tensor per image with `(h / mh**T) * (w / mw**T)` rows and `text_config.hidden_size`
            columns, where `(mh, mw)` is `merge_kernel_size` and `T` is `merger_times`.

        Raises:
            ValueError: if the grid being merged in some round is not divisible by `merge_kernel_size`.
        """
        merge_h, merge_w = self.merge_kernel_size

        start = 0
        processed_features = []
        for size in target_sizes:
            # Plain ints keep the arithmetic and error messages free of 0-d tensors.
            height, width = int(size[0]), int(size[1])
            num_patches = height * width

            hidden_state = hidden_states[0, start : start + num_patches, :]
            grid_h, grid_w = height, width
            for merge_round, mlp in enumerate(self.mlp):
                # Validate the grid that is about to be reshaped in this round (including the first one).
                if grid_h % merge_h != 0 or grid_w % merge_w != 0:
                    raise ValueError(
                        f"Patch grid ({grid_h}, {grid_w}) must be divisible by merge kernel size "
                        f"{self.merge_kernel_size} at merge round {merge_round}"
                    )
                grid_h, grid_w = grid_h // merge_h, grid_w // merge_w

                # (grid_h, mh, grid_w, mw, d) -> (grid_h * grid_w, mh * mw * d): concatenate each window's tokens.
                inner_dim = hidden_state.shape[-1]
                hidden_state = (
                    hidden_state.view(grid_h, merge_h, grid_w, merge_w, inner_dim)
                    .permute(0, 2, 1, 3, 4)
                    .reshape(grid_h * grid_w, merge_h * merge_w * inner_dim)
                )
                hidden_state = mlp(hidden_state)

            start += num_patches
            processed_features.append(hidden_state)

        return processed_features


class MiniCPMV4_6PreTrainedModel(PreTrainedModel):
    """Shared base class for the MiniCPM-V 4.6 multimodal models, bound to `MiniCPMV4_6Config`."""

    config_class = MiniCPMV4_6Config
    base_model_prefix = "model"
    input_modalities = ("image", "video", "text")

    # Capability flags.
    _supports_sdpa = True
    _supports_flash_attn = True
    supports_gradient_checkpointing = True

    # Module classes that should be kept whole when the model is split across devices.
    _no_split_modules = [
        "MiniCPMV4_6VisionEmbeddings",
        "MiniCPMV4_6VisionEncoderLayer",
        "MiniCPMV4_6ViTWindowAttentionMerger",
    ]


@auto_docstring(
    custom_intro="""
    The MiniCPMV4_6 model which consists of a vision backbone and a language model, without a language modeling head.
    """
)
class MiniCPMV4_6Model(MiniCPMV4_6PreTrainedModel):
    def __init__(self, config: MiniCPMV4_6Config):
        super().__init__(config)

        self.vision_tower = MiniCPMV4_6VisionModel._from_config(config.vision_config)
        self.language_model = AutoModel.from_config(config.text_config)
        self.merger = MiniCPMV4_6Merger(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring(custom_intro="Extract image features: vision encoder, insert merger, then MLP merger.")
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        target_sizes: torch.IntTensor,
        downsample_mode: str | None = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        target_sizes (`torch.IntTensor` of shape `(num_images, 2)`):
            Height and width (in patches) of each image.
        downsample_mode (`str`, *optional*):
            When set to `"4x"` the intermediate `vit_merger` is skipped so that each image keeps
            `4×` more visual tokens. Default `"16x"` mode applies the full merge pipeline.
        """
        downsample_mode = downsample_mode if downsample_mode else self.config.downsample_mode
        # The vision tower only runs its `vit_merger` when `insert_layer_id >= 0`; mirror that condition so the
        # grid sizes handed to `self.merger` match the tokens the vision tower actually returns.
        use_vit_merger = downsample_mode != "4x" and self.vision_tower.config.insert_layer_id >= 0
        pixel_values = pixel_values.to(dtype=self.vision_tower.dtype)

        vision_output = self.vision_tower(
            pixel_values,
            target_sizes=target_sizes,
            use_vit_merger=use_vit_merger,
        )

        if use_vit_merger:
            # The vit_merger halves each grid side (assumes a (2, 2) window kernel).
            target_sizes = target_sizes // 2
        # `pooler_output` holds one projected feature tensor per image.
        vision_output.pooler_output = self.merger(vision_output.last_hidden_state, target_sizes)
        return vision_output

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        features: torch.FloatTensor,
        token_id: int,
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            # Without ids, placeholders are the positions whose embedding equals the placeholder token's embedding.
            special_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_mask = special_mask.all(-1)
        else:
            special_mask = input_ids == token_id

        n_tokens = special_mask.sum()
        special_mask = special_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        n_features = features.shape[0]
        torch_compilable_check(
            inputs_embeds[special_mask].numel() == features.numel(),
            f"Multimodal features and tokens do not match, tokens: {n_tokens}, features: {n_features}",
        )
        return special_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        target_sizes: torch.IntTensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        target_sizes_videos: torch.IntTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        downsample_mode: str | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        r"""
        pixel_values (`torch.FloatTensor`, *optional*):
            Pixel value patches for images, NaViT-packed.
        target_sizes (`torch.IntTensor`, *optional*):
            Height and width (in patches) for each image.
        pixel_values_videos (`torch.FloatTensor`, *optional*):
            Pixel value patches for video frames, NaViT-packed.
        target_sizes_videos (`torch.IntTensor`, *optional*):
            Height and width (in patches) for each video frame.
        downsample_mode (`str`, *optional*):
            `"4x"` keeps 4x more visual tokens; default `"16x"` applies full merge.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must specify either `input_ids` or `inputs_embeds`.")
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None and self.config.image_token_id is not None:
            # Pixels are always `1` in first dim due to NaViT packing, and we don't
            # want to waste compute processing the same image `num_beams` times. Hack until
            # @raushan adds support for encoding images once same way as in enc-dec models
            num_beams = pixel_values.shape[0]
            vision_output = self.get_image_features(pixel_values[:1], target_sizes, downsample_mode=downsample_mode)
            image_features = (
                torch.cat(vision_output.pooler_output, dim=0)
                .to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
                .repeat(num_beams, 1)
            )
            mask = self.get_placeholder_mask(input_ids, inputs_embeds, image_features, self.config.image_token_id)
            inputs_embeds = inputs_embeds.masked_scatter(mask, image_features)

        if pixel_values_videos is not None and self.config.video_token_id is not None:
            # Same single-encoding trick as for images above.
            num_beams = pixel_values_videos.shape[0]
            vision_output = self.get_video_features(
                pixel_values_videos[:1], target_sizes_videos, downsample_mode=downsample_mode
            )
            video_features = (
                torch.cat(vision_output.pooler_output, dim=0)
                .to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
                .repeat(num_beams, 1)
            )
            mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features, self.config.video_token_id)
            inputs_embeds = inputs_embeds.masked_scatter(mask, video_features)

        output = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        return output

    @can_return_tuple
    @auto_docstring(
        custom_intro="Extract video features: repack frames into NaViT format, then vision encoder + merger."
    )
    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        target_sizes_videos: torch.IntTensor,
        downsample_mode: str | None = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        pixel_values_videos (`torch.FloatTensor` of shape `(1, channels, patch_size, seq_len)`):
            NaViT-packed pixel patches for all video frames. The video processor concatenates
            every frame's patches along the last dimension into a single sequence with dim-0 = 1,
            identical to the image packing format.
        target_sizes_videos (`torch.IntTensor` of shape `(num_patches, 2)`):
            Height and width (in patches) of each visual unit.
        downsample_mode (`str`, *optional*):
            When set to `"4x"` the intermediate `vit_merger` is skipped so that each frame keeps
            `4×` more visual tokens. Default `"16x"` mode applies the full merge pipeline.
        """
        # Concatenate the dim-0 entries along the last axis into one packed sequence with dim-0 = 1.
        num_frames = pixel_values_videos.shape[0]
        pixel_values = pixel_values_videos.permute(1, 2, 0, 3).reshape(
            1, pixel_values_videos.shape[1], pixel_values_videos.shape[2], -1
        )
        target_sizes = target_sizes_videos.repeat(num_frames, 1)
        return self.get_image_features(pixel_values, target_sizes, downsample_mode=downsample_mode)


class MiniCPMV4_6ForConditionalGeneration(MiniCPMV4_6PreTrainedModel, GenerationMixin):
    """
    `MiniCPMV4_6Model` with a linear language-modeling head on top, usable with `generate()`.
    """

    # The LM head shares its weight with the text model's input embedding table.
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    def __init__(self, config: MiniCPMV4_6Config):
        super().__init__(config)
        self.model = MiniCPMV4_6Model(config)
        self.vocab_size = config.text_config.vocab_size
        self.lm_head = nn.Linear(config.text_config.hidden_size, self.vocab_size, bias=False)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        target_sizes: torch.IntTensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        target_sizes_videos: torch.IntTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        downsample_mode: str | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        pixel_values (`torch.FloatTensor`, *optional*):
            Pixel value patches for images, NaViT-packed.
        target_sizes (`torch.IntTensor`, *optional*):
            Height and width (in patches) for each image.
        pixel_values_videos (`torch.FloatTensor`, *optional*):
            Pixel value patches for video frames, NaViT-packed.
        target_sizes_videos (`torch.IntTensor`, *optional*):
            Height and width (in patches) for each video frame.
        downsample_mode (`str`, *optional*):
            `"4x"` keeps 4x more visual tokens; default `"16x"` applies full merge.
        logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to `0`):
            If an `int`, compute logits only for the last `logits_to_keep` sequence positions
            (`0` keeps all positions). If a 1-D `torch.Tensor`, it is used directly to index the
            sequence dimension of the hidden states before the LM head.
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            target_sizes=target_sizes,
            pixel_values_videos=pixel_values_videos,
            target_sizes_videos=target_sizes_videos,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            downsample_mode=downsample_mode,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Project only the requested positions. `generate()` passes `logits_to_keep=1` once the
        # forward signature advertises this argument, which avoids materializing vocab-sized
        # logits for every prompt position during prefill. `-0` -> `slice(0, None)` keeps everything.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Forward kwargs so loss-normalization inputs such as `num_items_in_batch` (supplied by
            # `Trainer` under gradient accumulation) reach the loss function.
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @auto_docstring(custom_intro="Extract image features: vision encoder, insert merger, then MLP merger.")
    def get_image_features(self, *args, **kwargs) -> BaseModelOutputWithPooling:
        # Thin delegate to the base model so callers can use the generation wrapper directly.
        return self.model.get_image_features(*args, **kwargs)

    @auto_docstring(
        custom_intro="Extract video features: repack frames into NaViT format, then vision encoder + merger."
    )
    def get_video_features(self, *args, **kwargs) -> BaseModelOutputWithPooling:
        # Thin delegate to the base model so callers can use the generation wrapper directly.
        return self.model.get_video_features(*args, **kwargs)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        target_sizes=None,
        pixel_values_videos=None,
        target_sizes_videos=None,
        downsample_mode=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """
        Build the per-step model inputs for `generate()`.

        Visual inputs (`pixel_values*`, `target_sizes*`) are forwarded only on the first iteration,
        or on every step when caching is disabled. With a cache, later steps feed only new text tokens
        and must not re-encode the visual inputs. `downsample_mode` is forwarded on every step.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            is_first_iteration=is_first_iteration,
            downsample_mode=downsample_mode,
            **kwargs,
        )
        if is_first_iteration or not kwargs.get("use_cache", True):
            model_inputs["pixel_values"] = pixel_values
            model_inputs["target_sizes"] = target_sizes
            model_inputs["pixel_values_videos"] = pixel_values_videos
            model_inputs["target_sizes_videos"] = target_sizes_videos
        return model_inputs

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: torch.LongTensor | None = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        """
        Expand inputs for beam search / multiple return sequences, leaving `target_sizes*` unexpanded.

        Returns the expanded `input_ids` and `model_kwargs`. The `target_sizes` and
        `target_sizes_videos` tensors are restored un-expanded.
        """
        # NaViT packs all images/frames into a single sequence with dim-0 = 1.
        # We let parent repeat_interleave pixel_values / pixel_values_videos
        # along dim-0 ([1,C,P,L] -> [num_beams,C,P,L]) so forward() can
        # infer num_beams from shape[0], then encode only [:1].
        #
        # target_sizes ([K,2]) must be popped because:
        #  - forward encodes pixel_values[:1] (original K images), so
        #    target_sizes must stay [K,2] to match cu_seqlens computation.
        #  - expanded [K*num_beams, 2] would define phantom segments with
        #    no corresponding pixel data, crashing the vision encoder.
        ts_keys = ("target_sizes", "target_sizes_videos")
        saved = {k: model_kwargs.pop(k) for k in ts_keys if model_kwargs.get(k) is not None}

        input_ids, model_kwargs = super()._expand_inputs_for_generation(
            expand_size=expand_size,
            is_encoder_decoder=is_encoder_decoder,
            input_ids=input_ids,
            **model_kwargs,
        )

        # Put the untouched target sizes back for forward().
        model_kwargs.update(saved)
        return input_ids, model_kwargs


# Public classes exported by this module (the order is preserved for `from ... import *`).
__all__ = [
    "MiniCPMV4_6PreTrainedModel",
    "MiniCPMV4_6Model",
    "MiniCPMV4_6ForConditionalGeneration",
]
