#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_sam3_lite_text.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections.abc import Callable
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from ... import initialization as init
from ...activations import ACT2FN
from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import compile_compatible_method_lru_cache
from ...utils import auto_docstring, can_return_tuple, is_torchvision_available, logging
from ...utils.generic import TransformersKwargs, is_flash_attention_requested, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..auto import AutoModel
from .configuration_sam3_lite_text import (
    Sam3LiteTextConfig,
    Sam3LiteTextDETRDecoderConfig,
    Sam3LiteTextDETREncoderConfig,
    Sam3LiteTextGeometryEncoderConfig,
    Sam3LiteTextMaskDecoderConfig,
    Sam3LiteTextTextConfig,
)


if is_torchvision_available():
    import torchvision


logger = logging.get_logger(__name__)


# Output container returned by the text encoder. It declares no fields of its own; the docstring below
# re-documents the fields inherited from `BaseModelOutputWithPooling` for this model.
@dataclass
class Sam3LiteTextTextEncoderOutput(BaseModelOutputWithPooling):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Full sequence of hidden states from the text encoder.
    pooler_output (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
        EOT-pooled output projected to `projection_dim` via the internal CLIP-style projection.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of hidden states at each layer, returned when `output_hidden_states=True`.
    attentions (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of attention weights at each transformer layer, returned when `output_attentions=True`.
    """


class Sam3LiteTextTextPositionEmbedding(nn.Module):
    """Learned position table that is resampled bilinearly when a different sequence length is requested."""

    def __init__(self, max_position_embeddings: int, hidden_size: int):
        super().__init__()
        # Kept 4-D, (1, 1, positions, hidden), so `F.interpolate` can resample it like an image.
        # `torch.empty` leaves the values uninitialised at construction time.
        table_shape = (1, 1, max_position_embeddings, hidden_size)
        self.position_embedding = nn.Parameter(torch.empty(table_shape))

    def forward(self, seq_len: int) -> torch.Tensor:
        """
        Return the position table as a `(1, seq_len, hidden_size)` tensor.

        Args:
            seq_len: Number of positions requested. When it differs from the stored table length, the
                table is resized to `(seq_len, hidden_size)` with bilinear interpolation.
        """
        table = self.position_embedding
        stored_len, hidden_size = table.shape[2], table.shape[-1]
        if seq_len != stored_len:
            table = F.interpolate(table, size=(seq_len, hidden_size), mode="bilinear")
        return table.reshape(1, seq_len, -1)


class Sam3LiteTextMobileOneBlock(nn.Module):
    """MobileOne-style pair of branches: a batch-normalised depthwise conv plus a batch-normalised skip."""

    def __init__(self, hidden_size: int, kernel_size: int = 3):
        super().__init__()
        self.batchnorm_skip = nn.BatchNorm2d(hidden_size)
        # Depthwise (one group per channel) convolution with a (1, kernel_size) kernel; the padding of
        # kernel_size // 2 on the last axis keeps that axis' length unchanged for odd kernel sizes.
        depthwise_options = {
            "kernel_size": (1, kernel_size),
            "stride": 1,
            "padding": (0, kernel_size // 2),
            "groups": hidden_size,
            "bias": False,
        }
        self.conv = nn.Conv2d(hidden_size, hidden_size, **depthwise_options)
        self.batchnorm_conv = nn.BatchNorm2d(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `batchnorm_conv(conv(x)) + batchnorm_skip(x)` for a 4-D channels-first input `x`."""
        conv_branch = self.batchnorm_conv(self.conv(hidden_states))
        skip_branch = self.batchnorm_skip(hidden_states)
        return conv_branch + skip_branch


class Sam3LiteTextConvMLP(nn.Module):
    """Two-layer pointwise MLP built from 1x1 convolutions, so it runs directly on 4-D (B, C, H, W) maps."""

    def __init__(self, config: Sam3LiteTextTextConfig):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]
        hidden_channels, inner_channels = config.hidden_size, config.intermediate_size
        # Expand to the intermediate width, then project back to the hidden width.
        self.fc1 = nn.Conv2d(hidden_channels, inner_channels, kernel_size=1)
        self.fc2 = nn.Conv2d(inner_channels, hidden_channels, kernel_size=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the configured activation, then `fc2`."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)


class Sam3LiteTextConvolutionalFeedForward(nn.Module):
    """Feed-forward branch: batch-normalised depthwise convolution followed by a pointwise conv MLP."""

    def __init__(self, config: Sam3LiteTextTextConfig):
        super().__init__()
        channels = config.hidden_size
        kernel_width = config.repmixer_kernel_size
        # Depthwise conv with a (1, kernel_width) kernel; padding of kernel_width // 2 on the last axis.
        self.depthwise_conv = nn.Conv2d(
            channels,
            channels,
            kernel_size=(1, kernel_width),
            padding=(0, kernel_width // 2),
            groups=channels,
            bias=False,
        )
        self.depthwise_batchnorm = nn.BatchNorm2d(channels)
        self.mlp = Sam3LiteTextConvMLP(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run depthwise conv -> batch norm -> pointwise MLP on a 4-D channels-first tensor."""
        mixed = self.depthwise_conv(hidden_states)
        normalised = self.depthwise_batchnorm(mixed)
        return self.mlp(normalised)


class Sam3LiteTextLayerScaledResidual(nn.Module):
    """Base module that owns a learnable per-channel scale and applies it to a residual update."""

    def __init__(self, hidden_size: int, layer_scale_init_value: float):
        super().__init__()
        # Shape (hidden_size, 1, 1) so the scale broadcasts over the two trailing axes of a
        # channels-first feature map.
        initial_scale = layer_scale_init_value * torch.ones((hidden_size, 1, 1))
        self.layer_scale = nn.Parameter(initial_scale, requires_grad=True)

    def layer_scale_residual(self, hidden_states: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
        """Return `hidden_states + layer_scale * update`."""
        scaled_update = self.layer_scale * update
        return hidden_states + scaled_update


class Sam3LiteTextRepMixer(Sam3LiteTextLayerScaledResidual):
    """Depthwise-conv token mixer whose update is the mixer output minus a batch-normalised reference."""

    def __init__(self, config: Sam3LiteTextTextConfig):
        super().__init__(config.hidden_size, config.layer_scale_init_value)
        channels = config.hidden_size
        self.reference_batchnorm = nn.BatchNorm2d(channels)
        self.mixer = Sam3LiteTextMobileOneBlock(channels, kernel_size=config.repmixer_kernel_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `x + layer_scale * (mixer(x) - reference_batchnorm(x))`."""
        mixed = self.mixer(hidden_states)
        reference = self.reference_batchnorm(hidden_states)
        return self.layer_scale_residual(hidden_states, mixed - reference)


class Sam3LiteTextRepMixerBlock(Sam3LiteTextLayerScaledResidual):
    """RepMixer token mixing followed by a layer-scaled convolutional feed-forward residual."""

    def __init__(self, config: Sam3LiteTextTextConfig):
        super().__init__(config.hidden_size, config.layer_scale_init_value)
        self.token_mixer = Sam3LiteTextRepMixer(config)
        self.conv_feed_forward = Sam3LiteTextConvolutionalFeedForward(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Apply the block to a 3-D input and return a tensor with the same layout.

        `attention_mask` and `**kwargs` are accepted but not used by this block.
        """
        # Swap the last two axes and insert a singleton axis so the 2-D conv modules can consume the
        # sequence as a 4-D map whose channel axis is the hidden dimension.
        feature_map = hidden_states.transpose(1, 2).unsqueeze(2)
        feature_map = self.token_mixer(feature_map)
        feed_forward_update = self.conv_feed_forward(feature_map)
        feature_map = self.layer_scale_residual(feature_map, feed_forward_update)
        # Undo the layout change.
        return feature_map.squeeze(2).transpose(1, 2)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class Sam3LiteTextTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel"""

        input_shape = hidden_states.shape[:-1]

        hidden_shape = (*input_shape, -1, self.head_dim)
        queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights


class Sam3LiteTextTextMLP(nn.Module):
    """Two-layer feed-forward network: linear -> configured activation -> linear."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        hidden_size, inner_size = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(hidden_size, inner_size)
        self.fc2 = nn.Linear(inner_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the activation, then `fc2` to the last axis of `hidden_states`."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)


class Sam3LiteTextTextEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Sam3LiteTextTextConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps
        # Pre-norm layout: each sub-layer is preceded by its own LayerNorm.
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.self_attn = Sam3LiteTextTextAttention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.mlp = Sam3LiteTextTextMLP(config)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Self-attention sub-layer with residual connection; the attention weights are discarded here.
        attention_input = self.layer_norm1(hidden_states)
        attention_output, _ = self.self_attn(
            hidden_states=attention_input,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attention_output

        # Feed-forward sub-layer with residual connection.
        mlp_output = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_output


class Sam3LiteTextTextEmbeddings(nn.Module):
    """Sum of a token-embedding lookup and a length-adaptive learned position embedding."""

    def __init__(self, config: Sam3LiteTextTextConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.token_embedding = nn.Embedding(config.vocab_size, hidden_size)
        self.position_embedding = Sam3LiteTextTextPositionEmbedding(config.max_position_embeddings, hidden_size)

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Embed `input_ids` and add the position embedding for `input_ids.shape[1]` positions."""
        token_embeds = self.token_embedding(input_ids)
        seq_len = input_ids.shape[1]
        # Cast the position table to the token-embedding dtype before adding.
        position_embeds = self.position_embedding(seq_len).to(token_embeds.dtype)
        return token_embeds + position_embeds


@auto_docstring
class Sam3LiteTextPreTrainedModel(PreTrainedModel):
    # Class-level metadata consumed by the `PreTrainedModel` base class.
    config_class = Sam3LiteTextConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    input_modalities = ["image", "text"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    supports_gradient_checkpointing = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the parent initialisation, then apply the text-encoder specific overrides."""
        super()._init_weights(module)

        if isinstance(module, Sam3LiteTextTextPositionEmbedding):
            # Normal init with std = hidden_size ** -0.5 (hidden size read from the table's last axis).
            position_table = module.position_embedding
            init.normal_(position_table, std=position_table.shape[-1] ** -0.5)
            return

        if isinstance(module, Sam3LiteTextTextModel):
            # Same std rule for the output projection of the text model.
            hidden_size = module.config.hidden_size
            init.normal_(module.projection.weight, std=hidden_size**-0.5)


@auto_docstring(
    custom_intro="""
    MobileCLIP MCT text encoder used in EfficientSAM3 LiteText.

    When `config.use_repmixer_blocks` is `True`, the first and last layers are
    `Sam3LiteTextRepMixerBlock` modules; the rest are standard `Sam3LiteTextTextEncoderLayer` layers.
"""
)
class Sam3LiteTextTextModel(Sam3LiteTextPreTrainedModel):
    config_class = Sam3LiteTextTextConfig
    config: Sam3LiteTextTextConfig
    _can_record_outputs = {
        "hidden_states": Sam3LiteTextTextEncoderLayer,
        "attentions": Sam3LiteTextTextAttention,
    }

    def __init__(self, config: Sam3LiteTextTextConfig):
        super().__init__(config)
        self.embeddings = Sam3LiteTextTextEmbeddings(config)
        # Layer indices that use a RepMixer block instead of a transformer encoder layer.
        repmixer_positions = {0, config.num_hidden_layers - 1} if config.use_repmixer_blocks else set()
        self.layers = nn.ModuleList(
            [
                Sam3LiteTextRepMixerBlock(config) if i in repmixer_positions else Sam3LiteTextTextEncoderLayer(config)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3LiteTextTextEncoderOutput:
        # `input_ids` defaults to None in the signature but is required by both the embedding lookup and
        # the pooling step below; fail early with a clear message instead of an obscure error later.
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        hidden_states = self.embeddings(input_ids)
        attention_mask = create_bidirectional_mask(self.config, hidden_states, attention_mask)

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask=attention_mask, **kwargs)

        hidden_states = self.final_layer_norm(hidden_states)

        # Pool one position per sequence: the position holding the largest token id. The index tensor is
        # moved to the hidden-state device so both advanced-indexing operands live on the same device.
        batch_indices = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        pooled_positions = input_ids.to(hidden_states.device).argmax(dim=-1)
        pooled = self.projection(hidden_states[batch_indices, pooled_positions])
        return Sam3LiteTextTextEncoderOutput(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
        )


# Vision encoder output: the fields inherited from `BaseModelOutputWithPooling` plus the FPN feature maps
# and their position encodings. Both extra fields default to `None`.
@auto_docstring
@dataclass
class Sam3LiteTextVisionEncoderOutput(BaseModelOutputWithPooling):
    r"""
    fpn_hidden_states (`tuple[torch.FloatTensor]`):
        Tuple of multi-level FPN feature maps.
    fpn_position_encoding (`tuple[torch.FloatTensor]`):
        Tuple of position encodings for each FPN level.
    """

    fpn_hidden_states: tuple[torch.FloatTensor, ...] = None
    fpn_position_encoding: tuple[torch.FloatTensor, ...] = None


# Output container of the geometry encoder: the encoded prompt features and the mask marking which
# prompt positions are valid. Every field defaults to `None`.
@dataclass
@auto_docstring
class Sam3LiteTextGeometryEncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_prompts, hidden_size)`):
        Encoded geometry prompt features (boxes).
    attention_mask (`torch.BoolTensor` of shape `(batch_size, num_prompts)`, *optional*):
        Attention mask for geometry prompts where True indicates valid positions and False indicates padding.
    """

    last_hidden_state: torch.FloatTensor = None
    attention_mask: torch.BoolTensor | None = None


# Output container of the DETR encoder. Besides the encoded vision features it carries the flattened
# position embeddings, the text features and the per-level spatial shapes. Every field defaults to `None`.
@dataclass
@auto_docstring
class Sam3LiteTextDETREncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Encoded vision features (flattened from multi-level features).
    pos_embeds_flattened (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Flattened position embeddings for the vision features.
    text_features (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`, *optional*):
        Text features (may be pooled after encoder processing).
    spatial_shapes (`torch.LongTensor` of shape `(num_levels, 2)`, *optional*):
        Spatial shapes (height, width) for each feature pyramid level.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all encoder layers.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from all encoder layers.
    """

    last_hidden_state: torch.FloatTensor = None
    pos_embeds_flattened: torch.FloatTensor | None = None
    text_features: torch.FloatTensor | None = None
    spatial_shapes: torch.LongTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None


# Output container of the DETR decoder: per-layer hidden states, reference boxes and presence logits,
# plus optional hidden-state / attention tuples. Every field defaults to `None`.
@dataclass
@auto_docstring
class Sam3LiteTextDETRDecoderOutput(ModelOutput):
    r"""
    intermediate_hidden_states (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, hidden_size)`):
        Decoder hidden states from all layers.
    reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
        Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
    presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
        Presence logits from all decoder layers indicating object presence confidence.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all decoder layers.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from all decoder layers (self-attention and cross-attention).
    """

    intermediate_hidden_states: torch.FloatTensor = None
    reference_boxes: torch.FloatTensor = None
    presence_logits: torch.FloatTensor = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None


# Output container of the mask decoder: per-query masks, an optional semantic segmentation map and
# optional attention weights. Every field defaults to `None`.
@dataclass
@auto_docstring
class Sam3LiteTextMaskDecoderOutput(ModelOutput):
    r"""
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
        Predicted segmentation masks for each query.
    semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
        Semantic segmentation output.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from mask decoder cross-attention layers.
    """

    pred_masks: torch.FloatTensor = None
    semantic_seg: torch.FloatTensor | None = None
    attentions: tuple[torch.FloatTensor] | None = None


# Top-level segmentation output. It groups the final predictions (masks, boxes, logits) with the
# optional intermediate states and attention weights of each sub-module. Every field defaults to `None`.
@dataclass
@auto_docstring
class Sam3LiteTextImageSegmentationOutput(ModelOutput):
    r"""
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
        Predicted segmentation masks for each query.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Predicted bounding boxes in (x1, y1, x2, y2) format.
    pred_logits (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
        Classification confidence scores for each query, computed via dot product between
        decoder query features and text features.
    presence_logits (`torch.FloatTensor` of shape `(batch_size, 1)`, *optional*):
        Presence logits from the DETR decoder presence token (last layer only). These indicate whether objects
        are present in the scene. Can be used to compute final scores by multiplying with pred_logits:
        `final_scores = pred_logits.sigmoid() * presence_logits.sigmoid()`.
    semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
        Semantic segmentation output.
    decoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all DETR decoder layers. Each tensor has shape `(batch_size, num_queries, hidden_size)`.
    decoder_reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`, *optional*):
        Reference boxes from all DETR decoder layers.
    encoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all DETR encoder layers.
    vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all vision encoder (ViT) layers.
    vision_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from vision encoder (ViT) layers.
    detr_encoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from DETR encoder layers.
    detr_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from DETR decoder layers (self-attention and cross-attention).
    mask_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from mask decoder layers.
    """

    # Final predictions.
    pred_masks: torch.FloatTensor = None
    pred_boxes: torch.FloatTensor = None
    pred_logits: torch.FloatTensor | None = None
    presence_logits: torch.FloatTensor | None = None
    semantic_seg: torch.FloatTensor | None = None
    # Intermediate states of the sub-modules.
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_reference_boxes: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    vision_hidden_states: tuple[torch.FloatTensor] | None = None
    # Attention weights of the sub-modules.
    vision_attentions: tuple[torch.FloatTensor] | None = None
    detr_encoder_attentions: tuple[torch.FloatTensor] | None = None
    detr_decoder_attentions: tuple[torch.FloatTensor] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor] | None = None


class Sam3LiteTextMLP(nn.Module):
    """Two-layer feed-forward network with dropout applied to the first projection's output."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        hidden_size, inner_size = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(hidden_size, inner_size)
        self.fc2 = nn.Linear(inner_size, hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1` -> dropout -> activation -> `fc2` (dropout runs before the activation)."""
        projected = self.dropout(self.fc1(hidden_states))
        return self.fc2(self.activation_fn(projected))


class Sam3LiteTextAttention(nn.Module):
    """
    Multi-head attention with separate query / key / value inputs.
    Operates on standard [batch_size, seq_len, hidden_size] tensors.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        width = self.hidden_size
        self.q_proj = nn.Linear(width, width)
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.o_proj = nn.Linear(width, width)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query: [batch_size, query_len, hidden_size]
            key: [batch_size, key_len, hidden_size]
            value: [batch_size, value_len, hidden_size]
            attention_mask: [batch_size, num_heads, query_len, key_len] or broadcastable

        Returns:
            Tuple of (output, attention_weights)
                output: [batch_size, query_len, hidden_size]
                attention_weights: [batch_size, num_heads, query_len, key_len]
        """
        batch_size, query_len = query.shape[0], query.shape[1]
        key_len = key.shape[1]

        def split_heads(projected: torch.Tensor, length: int) -> torch.Tensor:
            # [batch, length, hidden] -> [batch, heads, length, head_dim]
            return projected.view(batch_size, length, self.num_attention_heads, self.head_dim).transpose(1, 2)

        query = split_heads(self.q_proj(query), query_len)
        key = split_heads(self.k_proj(key), key_len)
        # The value projection is reshaped with the key length.
        value = split_heads(self.v_proj(value), key_len)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        has_non_boolean_mask = attention_mask is not None and attention_mask.dtype != torch.bool
        if is_flash_attention_requested(self.config) and has_non_boolean_mask:
            # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention
            # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
            logger.warning_once(
                "Sam3LiteTextAttention: falling back to SDPA for relative-position cross-attention because "
                "Flash Attention does not support additive bias masks."
            )

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=attention_mask,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Merge the heads back into a single hidden axis before the output projection.
        merged_width = self.num_attention_heads * self.head_dim
        attn_output = attn_output.reshape(batch_size, query_len, merged_width).contiguous()
        return self.o_proj(attn_output), attn_weights


class Sam3LiteTextSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def _frequency_divisors(self, device: torch.device | str, dtype: torch.dtype) -> torch.Tensor:
        """Return the `num_pos_feats` divisors `temperature ** (2 * (i // 2) / num_pos_feats)`."""
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
        return self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

    @staticmethod
    def _interleave_sin_cos(angles: torch.Tensor) -> torch.Tensor:
        """Take sin of the even channels and cos of the odd channels, interleaved back along the last axis."""
        return torch.stack((angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=-1).flatten(-2)

    def encode_1d_positions(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode 1D coordinate pairs using sine/cosine positional embeddings.

        Args:
            x: 1D tensor of x coordinates (flattened)
            y: 1D tensor of y coordinates (flattened)

        Returns:
            Tuple of (pos_x, pos_y) positional embeddings
        """
        dim_t = self._frequency_divisors(x.device, x.dtype)
        pos_x = self._interleave_sin_cos((x * self.scale)[:, None] / dim_t)
        pos_y = self._interleave_sin_cos((y * self.scale)[:, None] / dim_t)
        return pos_x, pos_y

    def encode_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Encode 4D box coordinates (x, y, w, h) for decoder conditioning using sine/cosine embeddings.

        Args:
            boxes: Box coordinates [batch_size, num_queries, 4] in (x, y, w, h) format

        Returns:
            Position embeddings [batch_size, num_queries, num_pos_feats*4]

        Raises:
            ValueError: If the last dimension of `boxes` is not 4.
        """
        # Explicit raise instead of `assert`, so the check is not stripped when Python runs with -O.
        if boxes.size(-1) != 4:
            raise ValueError(f"Expected 4D box coordinates (x, y, w, h), got shape {boxes.shape}")
        dim_t = self._frequency_divisors(boxes.device, boxes.dtype)

        x_embed = boxes[:, :, 0] * self.scale
        y_embed = boxes[:, :, 1] * self.scale
        w_embed = boxes[:, :, 2] * self.scale
        h_embed = boxes[:, :, 3] * self.scale

        pos_x = self._interleave_sin_cos(x_embed[:, :, None] / dim_t)
        pos_y = self._interleave_sin_cos(y_embed[:, :, None] / dim_t)
        pos_w = self._interleave_sin_cos(w_embed[:, :, None] / dim_t)
        pos_h = self._interleave_sin_cos(h_embed[:, :, None] / dim_t)

        # The y embedding is placed before the x embedding in the output.
        return torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)

    @compile_compatible_method_lru_cache(maxsize=4)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: Tensor | None = None,
    ) -> Tensor:
        """
        Build the 2D sine/cosine position encoding for a feature map.

        Args:
            shape: Shape of the feature map; only indices 0, 2 and 3 are read (used as batch, height, width).
            device: Device on which the encoding is created.
            dtype: Floating dtype of the encoding.
            mask: Optional boolean tensor of shape `(shape[0], shape[2], shape[3])`; positions set to True
                are excluded from the cumulative position count. Defaults to an all-False mask.

        Returns:
            Tensor of shape `(shape[0], 2 * num_pos_feats, shape[2], shape[3])`.
        """
        if mask is None:
            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
        not_mask = (~mask).to(dtype)
        # Cumulative counts of unmasked cells give 1-based row / column coordinates.
        y_embed = not_mask.cumsum(1)
        x_embed = not_mask.cumsum(2)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = self._frequency_divisors(device, dtype)

        pos_x = self._interleave_sin_cos(x_embed[:, :, :, None] / dim_t)
        pos_y = self._interleave_sin_cos(y_embed[:, :, :, None] / dim_t)
        # Concatenate on the channel axis (y first), then move channels in front of the spatial axes.
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)


class Sam3LiteTextGeometryEncoderLayer(nn.Module):
    """
    Pre-norm transformer block for geometry prompts.

    Each of the three sub-layers (prompt self-attention, cross-attention to vision features, MLP) is preceded
    by its own LayerNorm and followed by dropout and a residual connection.
    """

    def __init__(self, config: Sam3LiteTextGeometryEncoderConfig):
        super().__init__()
        hidden_size = config.hidden_size

        # Self-attention sub-layer
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.self_attn = Sam3LiteTextAttention(config)
        # A single dropout module is shared by all three sub-layers.
        self.dropout = nn.Dropout(config.dropout)

        # Cross-attention sub-layer
        self.cross_attn = Sam3LiteTextAttention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size)

        # Feed-forward sub-layer
        self.mlp = Sam3LiteTextMLP(config)
        self.layer_norm3 = nn.LayerNorm(hidden_size)

    def forward(
        self,
        prompt_feats: Tensor,
        vision_feats: Tensor,
        vision_pos_encoding: Tensor,
        prompt_mask: Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Refine prompt features with self-attention, vision cross-attention and an MLP.

        Args:
            prompt_feats: Prompt features that act as the main hidden states.
            vision_feats: Vision features used as cross-attention values (and, with position added, as keys).
            vision_pos_encoding: Position encoding added to `vision_feats` to form the cross-attention keys.
            prompt_mask: Attention mask forwarded to the prompt self-attention.

        Returns:
            The updated prompt features.
        """
        # 1) Self-attention among prompt tokens
        normed = self.layer_norm1(prompt_feats)
        attn_output, _ = self.self_attn(
            query=normed, key=normed, value=normed, attention_mask=prompt_mask, **kwargs
        )
        after_self_attn = self.dropout(attn_output) + prompt_feats

        # 2) Cross-attention: prompt queries read from the vision features (keys carry position, values do not)
        normed = self.layer_norm2(after_self_attn)
        vision_keys = vision_feats + vision_pos_encoding
        attn_output, _ = self.cross_attn(query=normed, key=vision_keys, value=vision_feats, **kwargs)
        after_cross_attn = self.dropout(attn_output) + after_self_attn

        # 3) Feed-forward
        mlp_output = self.mlp(self.layer_norm3(after_cross_attn))
        return self.dropout(mlp_output) + after_cross_attn


def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """
    Concatenates two right-padded sequences, such that the resulting sequence
    is contiguous and also right-padded.

    Tensors are batch-first, masks are batch-first with True=valid, False=padding. Masks may be boolean,
    integer or floating 0/1 tensors; the valid lengths derived from them are cast to int64 before being used
    as scatter indices.

    Args:
        seq1: A tensor of shape (batch_size, seq1_length, hidden_size).
        mask1: A tensor of shape (batch_size, seq1_length) with True=valid, False=padding.
        seq2: A tensor of shape (batch_size, seq2_length, hidden_size).
        mask2: A tensor of shape (batch_size, seq2_length) with True=valid, False=padding.
        return_index: If True, also returns the index of the ids of the element of seq2
            in the concatenated sequence. This can be used to retrieve the elements of seq2.

    Returns:
        A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
        otherwise (concatenated_sequence, concatenated_mask, index).
        The concatenated_mask is boolean and uses True=valid, False=padding convention.
        The output has length seq1_length + seq2_length and takes its dtype and device from seq2.

    Raises:
        AssertionError: If the batch sizes, hidden sizes or mask lengths of the inputs do not agree.
    """
    batch_size, seq1_length, hidden_size = seq1.shape
    batch_size2, seq2_length, hidden_size2 = seq2.shape

    assert batch_size == batch_size2 == mask1.size(0) == mask2.size(0), "batch sizes of inputs do not match"
    assert hidden_size == hidden_size2, "hidden sizes of seq1 and seq2 do not match"
    assert seq1_length == mask1.size(1), "mask1 length does not match seq1 length"
    assert seq2_length == mask2.size(1), "mask2 length does not match seq2 length"

    device = seq2.device

    # scatter() requires int64 indices, so normalise the per-sample valid lengths regardless of mask dtype.
    actual_seq1_lengths = mask1.sum(dim=-1).to(torch.long)
    actual_seq2_lengths = mask2.sum(dim=-1).to(torch.long)

    final_lengths = actual_seq1_lengths + actual_seq2_lengths
    max_length = seq1_length + seq2_length

    # Broadcasting (1, max_length) against (batch_size, 1) yields the (batch_size, max_length) mask directly.
    concatenated_mask = torch.arange(max_length, device=device)[None, :] < final_lengths[:, None]

    concatenated_sequence = torch.zeros((batch_size, max_length, hidden_size), device=device, dtype=seq2.dtype)
    concatenated_sequence[:, :seq1_length, :] = seq1

    # Shift seq2 elements to start at the end of valid seq1 (this overwrites seq1 padding slots)
    index = torch.arange(seq2_length, device=device)[None, :] + actual_seq1_lengths[:, None]

    # Scatter seq2 into the right positions
    concatenated_sequence = concatenated_sequence.scatter(1, index[:, :, None].expand(-1, -1, hidden_size), seq2)

    if return_index:
        return concatenated_sequence, concatenated_mask, index

    return concatenated_sequence, concatenated_mask


def box_cxcywh_to_xyxy(x):
    """Turn center-format boxes (cx, cy, w, h) on the last axis into corner-format boxes (x1, y1, x2, y2)."""
    center_x, center_y, width, height = x.unbind(-1)
    half_width = 0.5 * width
    half_height = 0.5 * height
    corners = (
        center_x - half_width,
        center_y - half_height,
        center_x + half_width,
        center_y + half_height,
    )
    return torch.stack(corners, dim=-1)


class Sam3LiteTextGeometryEncoder(nn.Module):
    """
    Turns box prompts into a sequence of prompt embeddings.

    Every box contributes the sum of three signals:
     - a linear projection of its raw (cx, cy, w, h) coordinates,
     - image features pooled inside the box with ROI align and projected by a convolution,
     - a projection of the position encoding of the box center concatenated with the raw height/width.

    A label embedding is added on top, a CLS token is appended, and the resulting sequence is refined by
    transformer layers that cross-attend to the image features.
    """

    def __init__(self, config: Sam3LiteTextGeometryEncoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.roi_size = config.roi_size
        hidden_size = self.hidden_size

        self.position_encoding = Sam3LiteTextSinePositionEmbedding(
            num_pos_feats=config.hidden_size // 2, normalize=True
        )
        # Two label ids (positive / negative box) and a single CLS token
        self.label_embed = nn.Embedding(2, hidden_size)
        self.cls_embed = nn.Embedding(1, hidden_size)

        # The three box encoders: raw coordinates, ROI-pooled features, position encoding (+2 for raw h, w)
        self.boxes_direct_project = nn.Linear(4, hidden_size)
        self.boxes_pool_project = nn.Conv2d(hidden_size, hidden_size, self.roi_size)
        self.boxes_pos_enc_project = nn.Linear(hidden_size + 2, hidden_size)

        # Normalization of the image features that are fed to ROI align
        self.vision_layer_norm = nn.LayerNorm(hidden_size)

        # Projection + normalization applied to the assembled prompt sequence
        self.final_proj = nn.Linear(hidden_size, hidden_size)
        self.prompt_layer_norm = nn.LayerNorm(hidden_size)

        # Transformer stack and output normalization
        self.layers = nn.ModuleList([Sam3LiteTextGeometryEncoderLayer(config) for _ in range(config.num_layers)])
        self.output_layer_norm = nn.LayerNorm(hidden_size)

    def _encode_box_coordinates(
        self, center_x: torch.Tensor, center_y: torch.Tensor, width: torch.Tensor, height: torch.Tensor
    ) -> torch.Tensor:
        """
        Concatenate the position encoding of the box centers with the raw box heights and widths.

        Args:
            center_x: 1D tensor of box center x coordinates
            center_y: 1D tensor of box center y coordinates
            width: 1D tensor of box widths
            height: 1D tensor of box heights

        Returns:
            Tensor [N, embedding_dim] laid out as (y-encoding, x-encoding, height, width) along dim 1.
        """
        encoded_x, encoded_y = self.position_encoding.encode_1d_positions(center_x, center_y)
        pieces = (encoded_y, encoded_x, height[:, None], width[:, None])
        return torch.cat(pieces, dim=1)

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features):
        """
        Embed box prompts; `boxes_mask` is returned unchanged (True=valid, False=padding).

        Returns:
            Tuple (box embeddings [batch_size, num_boxes, hidden_size], boxes_mask).
        """
        batch_size, num_boxes = boxes.shape[:2]
        feat_height, feat_width = vision_features.shape[-2:]

        # (1) Direct projection of the coordinates
        direct_projection = self.boxes_direct_project(boxes)

        # (2) ROI-align pooling: boxes go from CxCyWH to xyxy and are scaled by the feature-map size
        corner_boxes = box_cxcywh_to_xyxy(boxes)
        denormalizer = torch.tensor(
            [feat_width, feat_height, feat_width, feat_height],
            dtype=corner_boxes.dtype,
            device=corner_boxes.device,
        ).view(1, 1, 4)
        corner_boxes = corner_boxes * denormalizer

        # roi_align only supports float16 and float32, so bfloat16 inputs are routed through float16
        roi_dtype = vision_features.dtype
        if roi_dtype == torch.bfloat16:
            roi_dtype = torch.float16
        # unbind(0) hands roi_align one box tensor per batch element
        roi_features = torchvision.ops.roi_align(
            vision_features.to(roi_dtype), corner_boxes.to(roi_dtype).unbind(0), self.roi_size
        ).to(vision_features.dtype)

        pooled_projection = self.boxes_pool_project(roi_features).view(batch_size, num_boxes, self.hidden_size)
        box_embeddings = direct_projection + pooled_projection

        # (3) Position encoding of the box centers together with raw width / height
        center_x, center_y, box_width, box_height = boxes.unbind(-1)
        position_features = self._encode_box_coordinates(
            center_x.flatten(), center_y.flatten(), box_width.flatten(), box_height.flatten()
        )
        position_features = position_features.view(batch_size, num_boxes, position_features.shape[-1])
        box_embeddings = box_embeddings + self.boxes_pos_enc_project(position_features)

        # Label embedding (positive/negative)
        label_embeddings = self.label_embed(boxes_labels.long())
        return label_embeddings + box_embeddings, boxes_mask

    def forward(
        self,
        box_embeddings: torch.Tensor,
        box_mask: torch.Tensor,
        box_labels: torch.Tensor,
        img_feats: tuple[torch.Tensor, ...],
        img_pos_embeds: tuple[torch.Tensor, ...] | None = None,
    ):
        """
        Encode geometric (box) prompts.

        Args:
            box_embeddings: Box coordinates in CxCyWH format [batch_size, num_boxes, 4]
            box_mask: Attention mask for boxes [batch_size, num_boxes]
            box_labels: Labels for boxes (positive/negative) [batch_size, num_boxes]
            img_feats: Image features from vision encoder; only the last entry is used
            img_pos_embeds: Optional position embeddings for image features; only the last entry is used,
                and zeros are used when omitted

        Returns:
            Sam3LiteTextGeometryEncoderOutput containing encoded geometry features and attention mask.
        """
        batch_size = box_embeddings.shape[0]
        last_level_feats = img_feats[-1]  # [B, C, H, W]

        # Token-major views of the vision features / position embeddings for cross-attention
        if img_pos_embeds is None:
            last_level_pos = torch.zeros_like(last_level_feats)
        else:
            last_level_pos = img_pos_embeds[-1]
        vision_tokens = last_level_feats.flatten(2).transpose(1, 2)  # [B, H*W, C]
        vision_pos_tokens = last_level_pos.flatten(2).transpose(1, 2)  # [B, H*W, C]

        # LayerNorm acts on the channel axis, so go channels-last and back for the pooling input
        channels_last = last_level_feats.permute(0, 2, 3, 1)  # [B, H, W, C]
        pooling_feats = self.vision_layer_norm(channels_last).permute(0, 3, 1, 2)  # [B, C, H, W]

        prompt_embeds, prompt_mask = self._encode_boxes(box_embeddings, box_mask, box_labels, pooling_feats)

        # Append the CLS token (always valid) right after the valid boxes of each sample
        cls_token = self.cls_embed.weight.view(1, 1, self.hidden_size).expand(batch_size, -1, -1)
        cls_valid = torch.ones(batch_size, 1, dtype=prompt_mask.dtype, device=prompt_mask.device)
        prompt_embeds, prompt_mask = concat_padded_sequences(prompt_embeds, prompt_mask, cls_token, cls_valid)

        prompt_embeds = self.final_proj(prompt_embeds)
        prompt_embeds = self.prompt_layer_norm(prompt_embeds)

        # Bidirectional attention mask for the prompt self-attention
        if prompt_mask is None:
            prompt_attention_mask = None
        else:
            prompt_attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_mask,
            )

        # Transformer layers with cross-attention to the vision tokens
        for encoder_layer in self.layers:
            prompt_embeds = encoder_layer(
                prompt_feats=prompt_embeds,
                vision_feats=vision_tokens,
                vision_pos_encoding=vision_pos_tokens,
                prompt_mask=prompt_attention_mask,
            )

        return Sam3LiteTextGeometryEncoderOutput(
            last_hidden_state=self.output_layer_norm(prompt_embeds),
            attention_mask=prompt_mask,
        )


class Sam3LiteTextDetrEncoderLayer(nn.Module):
    """Pre-norm DETR encoder block: vision self-attention, cross-attention to the prompt, then an MLP."""

    def __init__(self, config: Sam3LiteTextDETREncoderConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size

        # Self-attention sub-layer
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.self_attn = Sam3LiteTextAttention(config)
        # One dropout module is reused after every sub-layer.
        self.dropout = nn.Dropout(config.dropout)

        # Cross-attention sub-layer
        self.cross_attn = Sam3LiteTextAttention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size)

        # Feed-forward sub-layer
        self.mlp = Sam3LiteTextMLP(config)
        self.layer_norm3 = nn.LayerNorm(hidden_size)

    def forward(
        self,
        vision_feats: Tensor,
        prompt_feats: Tensor,
        vision_pos_encoding: Tensor,
        prompt_cross_attn_mask: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Run one encoder layer over the vision tokens.

        Args:
            vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
            prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
            vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
            prompt_cross_attn_mask: Cross-attention mask for prompt features

        Returns:
            Updated vision features [batch_size, vision_len, hidden_size]
        """
        # Self-attention: position encoding is added to queries and keys only, not to the values
        normed = self.layer_norm1(vision_feats)
        normed_with_pos = normed + vision_pos_encoding
        attn_output, _ = self.self_attn(
            query=normed_with_pos,
            key=normed_with_pos,
            value=normed,
            **kwargs,
        )
        after_self_attn = self.dropout(attn_output) + vision_feats

        # Cross-attention: vision queries attend to the text/prompt features
        attn_output, _ = self.cross_attn(
            query=self.layer_norm2(after_self_attn),
            key=prompt_feats,
            value=prompt_feats,
            attention_mask=prompt_cross_attn_mask,
            **kwargs,
        )
        after_cross_attn = self.dropout(attn_output) + after_self_attn

        # Feed-forward
        mlp_output = self.mlp(self.layer_norm3(after_cross_attn))
        return self.dropout(mlp_output) + after_cross_attn


class Sam3LiteTextDetrEncoder(Sam3LiteTextPreTrainedModel):
    """
    DETR-style encoder that processes multi-level vision features with text fusion.

    This encoder processes vision features from multiple levels (e.g., FPN features at different
    resolutions) and fuses them with text prompts through a stack of transformer encoder layers.
    """

    _can_record_outputs = {
        "hidden_states": Sam3LiteTextDetrEncoderLayer,
        "attentions": Sam3LiteTextAttention,
    }

    def __init__(self, config: Sam3LiteTextDETREncoderConfig):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size

        self.layers = nn.ModuleList([Sam3LiteTextDetrEncoderLayer(config) for _ in range(config.num_layers)])

        self.post_init()

    def _prepare_multilevel_features(
        self,
        vision_features: list[torch.Tensor],
        vision_pos_embeds: list[torch.Tensor],
    ):
        """
        Flatten the spatial dimensions of every level and concatenate all levels into one token sequence.

        Args:
            vision_features: List of vision features at different levels [batch_size, channels, height, width]
            vision_pos_embeds: List of position embeddings for each level [batch_size, channels, height, width]

        Returns:
            Tuple (features_flattened, pos_embeds_flattened, spatial_shapes) where the first two are
            [batch_size, sum(height*width), channels] and `spatial_shapes` is a long tensor [num_levels, 2]
            holding (height, width) per level.
        """
        features_flattened = []
        pos_embeds_flattened = []
        spatial_shapes = []

        for features, pos_embed in zip(vision_features, vision_pos_embeds):
            height, width = features.shape[-2:]
            spatial_shapes.append((height, width))

            # Flatten spatial dimensions: [batch_size, channels, height, width] -> [batch_size, height*width, channels]
            features = features.flatten(2).transpose(1, 2)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)

            features_flattened.append(features)
            pos_embeds_flattened.append(pos_embed)

        # Concatenate all levels into single sequence
        features_flattened = torch.cat(features_flattened, dim=1)
        pos_embeds_flattened = torch.cat(pos_embeds_flattened, dim=1)

        spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=features_flattened.device)

        return (
            features_flattened,
            pos_embeds_flattened,
            spatial_shapes,
        )

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        vision_features: list[torch.Tensor],
        text_features: torch.Tensor,
        vision_pos_embeds: list[torch.Tensor] | None = None,
        text_mask: torch.Tensor | None = None,
        spatial_sizes: list[tuple[int, int]] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3LiteTextDETREncoderOutput:
        """
        Forward pass for the DETR encoder.

        Args:
            vision_features: List of vision features at different levels. The list itself is not modified.
            text_features: Text prompt features [batch_size, seq_len, hidden_size]
            vision_pos_embeds: Optional list of position embeddings for each level. When omitted, zero
                position embeddings matching each feature level are used.
            text_mask: Optional text padding mask [batch_size, seq_len]
            spatial_sizes: Optional list of (height, width) tuples. When given, level `i` is expected in
                [height*width, batch_size, channels] layout and is reshaped to
                [batch_size, channels, height, width].

        Returns:
            Sam3LiteTextDETREncoderOutput containing encoded features and metadata.
        """
        # Work on shallow copies so the reshaping below never writes into the caller's containers
        # (this also makes tuple inputs acceptable).
        vision_features = list(vision_features)
        if vision_pos_embeds is None:
            # No positional information supplied: fall back to zeros with the same layout as the features.
            vision_pos_embeds = [torch.zeros_like(features) for features in vision_features]
        else:
            vision_pos_embeds = list(vision_pos_embeds)

        batch_size = vision_features[0].shape[0] if vision_features[0].dim() == 4 else vision_features[0].shape[1]

        # TODO: See if we can remove that reshaping and just use the features as is.
        if spatial_sizes is not None:
            for i, (height, width) in enumerate(spatial_sizes):
                # Reshape from [height*width, batch_size, channels] to [batch_size, channels, height, width]
                vision_features[i] = vision_features[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
                vision_pos_embeds[i] = vision_pos_embeds[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)

        # Flatten multi-level features for encoder processing
        (
            features_flattened,
            pos_embeds_flattened,
            spatial_shapes,
        ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)

        # Vision tokens are the queries and text tokens the keys of the cross-attention mask
        prompt_cross_attn_mask = None
        if text_mask is not None:
            prompt_cross_attn_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=features_flattened,
                attention_mask=text_mask,
                encoder_hidden_states=text_features,
            )

        hidden_states = features_flattened
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                prompt_feats=text_features,
                vision_pos_encoding=pos_embeds_flattened,
                prompt_cross_attn_mask=prompt_cross_attn_mask,
                **kwargs,
            )
        return Sam3LiteTextDETREncoderOutput(
            last_hidden_state=hidden_states,
            pos_embeds_flattened=pos_embeds_flattened,
            text_features=text_features,
            spatial_shapes=spatial_shapes,
        )


class Sam3LiteTextDecoderMLP(nn.Module):
    """
    Small ReLU MLP with either two or three linear layers, used by the decoder heads.

    `layer3` is `None` in the two-layer configuration. ReLU is applied after every layer except the last.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 2):
        super().__init__()
        if num_layers not in (2, 3):
            raise ValueError(f"Only 2 or 3 layers supported, got {num_layers}")

        self.layer1 = nn.Linear(input_dim, hidden_dim)
        if num_layers == 3:
            self.layer2 = nn.Linear(hidden_dim, hidden_dim)
            self.layer3 = nn.Linear(hidden_dim, output_dim)
        else:
            # Two layers: the second layer already maps to the output dimension
            self.layer2 = nn.Linear(hidden_dim, output_dim)
            self.layer3 = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of `x`."""
        hidden = F.relu(self.layer1(x))
        if self.layer3 is None:
            return self.layer2(hidden)
        return self.layer3(F.relu(self.layer2(hidden)))


class Sam3LiteTextDetrDecoderLayer(nn.Module):
    """
    Post-norm DETR decoder block.

    Applies, in order, query self-attention, cross-attention to text, cross-attention to vision and an MLP.
    Each sub-layer output goes through its own dropout, is added to the residual and is then layer-normalized.
    """

    def __init__(self, config: Sam3LiteTextDETRDecoderConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        dropout = config.dropout

        # Query self-attention
        self.self_attn = Sam3LiteTextAttention(config)
        self.self_attn_dropout = nn.Dropout(dropout)
        self.self_attn_layer_norm = nn.LayerNorm(hidden_size)

        # Cross-attention to text
        self.text_cross_attn = Sam3LiteTextAttention(config)
        self.text_cross_attn_dropout = nn.Dropout(dropout)
        self.text_cross_attn_layer_norm = nn.LayerNorm(hidden_size)

        # Cross-attention to vision
        self.vision_cross_attn = Sam3LiteTextAttention(config)
        self.vision_cross_attn_dropout = nn.Dropout(dropout)
        self.vision_cross_attn_layer_norm = nn.LayerNorm(hidden_size)

        # Feed-forward
        self.mlp = Sam3LiteTextMLP(config)
        self.mlp_layer_norm = nn.LayerNorm(hidden_size)
        self.mlp_dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        query_pos: torch.Tensor,
        text_features: torch.Tensor,
        vision_features: torch.Tensor,
        vision_pos_encoding: torch.Tensor,
        text_cross_attn_mask: torch.Tensor | None = None,
        vision_cross_attn_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Run one decoder layer.

        Args:
            hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0)
            query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
            text_features: Text features [batch_size, seq_len, hidden_size]
            vision_features: Vision features [batch_size, height*width, hidden_size]
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
            text_cross_attn_mask: Text cross-attention mask
            vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token

        Returns:
            Updated hidden states (including presence token at position 0)
        """
        # The presence token has no position: prepend one all-zero row along the query axis
        query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)

        # --- Self-attention (position added to queries and keys, not to values) ---
        positioned_queries = hidden_states + query_pos
        self_attn_output, _ = self.self_attn(
            query=positioned_queries,
            key=positioned_queries,
            value=hidden_states,
            attention_mask=None,
            **kwargs,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + self.self_attn_dropout(self_attn_output))

        # --- Text cross-attention: queries attend to text features ---
        text_attn_output, _ = self.text_cross_attn(
            query=hidden_states + query_pos,
            key=text_features,
            value=text_features,
            attention_mask=text_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.text_cross_attn_layer_norm(
            hidden_states + self.text_cross_attn_dropout(text_attn_output)
        )

        # --- Vision cross-attention: queries attend to vision features (with RPB) ---
        vision_attn_output, _ = self.vision_cross_attn(
            query=hidden_states + query_pos,
            key=vision_features + vision_pos_encoding,
            value=vision_features,
            attention_mask=vision_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.vision_cross_attn_layer_norm(
            hidden_states + self.vision_cross_attn_dropout(vision_attn_output)
        )

        # --- MLP ---
        mlp_output = self.mlp(hidden_states)
        return self.mlp_layer_norm(hidden_states + self.mlp_dropout(mlp_output))


def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """
    Compute the logit `log(x / (1 - x))`, i.e. the inverse of the sigmoid.

    `x` is first clamped to [0, 1]; numerator and denominator are then each floored at `eps` so the result
    stays finite at both ends of the range.
    """
    probability = torch.clamp(x, min=0, max=1)
    numerator = torch.clamp(probability, min=eps)
    denominator = torch.clamp(1 - probability, min=eps)
    return torch.log(numerator / denominator)


class Sam3LiteTextDetrDecoder(Sam3LiteTextPreTrainedModel):
    """
    DETR-style decoder with box refinement and presence token.

    Simplified version that assumes:
    - Box refinement is always enabled
    - Intermediate outputs are always returned
    - BoxRPB (relative position bias) with log-scale encoding
    - Presence token is used

    A learned presence token is prepended to the learned object queries. After every layer the reference
    boxes are refined from the query hidden states and a presence logit is read from the presence token.
    """

    # Module classes whose outputs are recorded under the given key.
    # NOTE(review): presumably consumed by the `capture_outputs` decorator on `forward` - confirm.
    _can_record_outputs = {
        "hidden_states": Sam3LiteTextDetrDecoderLayer,
        "attentions": Sam3LiteTextAttention,
    }

    def __init__(
        self,
        config: Sam3LiteTextDETRDecoderConfig,
    ):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size

        self.layers = nn.ModuleList([Sam3LiteTextDetrDecoderLayer(config) for _ in range(config.num_layers)])

        # Shared across layers: applied to the query hidden states after every decoder layer
        self.output_layer_norm = nn.LayerNorm(config.hidden_size)

        # Predicts a 4-dim box delta (added in logit space) from each query
        self.box_head = Sam3LiteTextDecoderMLP(config.hidden_size, config.hidden_size, 4, 3)

        # Learned object queries and their initial reference boxes (stored as logits, see `forward`)
        self.query_embed = nn.Embedding(config.num_queries, config.hidden_size)
        self.reference_points = nn.Embedding(config.num_queries, 4)

        # Single learned presence token, its own norm and a 1-dim logit head
        self.presence_token = nn.Embedding(1, config.hidden_size)
        self.presence_head = Sam3LiteTextDecoderMLP(config.hidden_size, config.hidden_size, 1, 3)
        self.presence_layer_norm = nn.LayerNorm(config.hidden_size)
        # Presence logits are clamped to [-clamp_presence_logit_max_val, +clamp_presence_logit_max_val]
        self.clamp_presence_logit_max_val = 10.0

        # Maps the box position encoding to a query position embedding.
        # NOTE(review): the 2 * hidden_size input presumably matches the width of `encode_boxes` - confirm.
        self.ref_point_head = Sam3LiteTextDecoderMLP(2 * config.hidden_size, config.hidden_size, config.hidden_size, 2)

        # Per-axis MLPs that turn the two (log-scaled) box-edge deltas into one bias value per attention head
        self.box_rpb_embed_x = Sam3LiteTextDecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
        self.box_rpb_embed_y = Sam3LiteTextDecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)

        self.position_encoding = Sam3LiteTextSinePositionEmbedding(
            num_pos_feats=config.hidden_size // 2, normalize=False
        )

        self.post_init()

    @compile_compatible_method_lru_cache(maxsize=1)
    def _get_coords(
        self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Generate normalized coordinate grids.

        Returns:
            Tuple (coords_h, coords_w) of 1D tensors holding i / height for i in [0, height) and
            j / width for j in [0, width).
        """
        coords_h = torch.arange(0, height, device=device, dtype=dtype) / height
        coords_w = torch.arange(0, width, device=device, dtype=dtype) / width
        return coords_h, coords_w

    def _get_rpb_matrix(
        self, reference_boxes: torch.Tensor, spatial_shape: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute box relative position bias (RPB) matrix using log-scale encoding.
        RPB helps the decoder attend to relevant spatial locations based on predicted box positions.

        Args:
            reference_boxes: Reference boxes [batch_size, num_queries, 4] in sigmoid space
            spatial_shape: (height, width) of the vision features as tensors

        Returns:
            RPB matrix [batch_size, num_heads, num_queries, height*width]
        """
        height, width = spatial_shape
        boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)
        batch_size, num_queries, _ = boxes_xyxy.shape

        # Generate coordinate grids
        coords_h, coords_w = self._get_coords(
            height, width, dtype=reference_boxes.dtype, device=reference_boxes.device
        )

        # Compute deltas between coordinates and box boundaries
        # The strided slice 1:4:2 picks indices 1 and 3 (y1, y2); 0:3:2 picks indices 0 and 2 (x1, x2).
        deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
        deltas_y = deltas_y.view(batch_size, num_queries, -1, 2)
        deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
        deltas_x = deltas_x.view(batch_size, num_queries, -1, 2)

        # Apply log-scale encoding
        # sign(8d) * log2(|8d| + 1) / log2(8): keeps the sign, compresses large distances
        deltas_x_log = deltas_x * 8
        deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / math.log2(8)
        deltas_y_log = deltas_y * 8
        deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / math.log2(8)

        # Embed deltas
        deltas_x = self.box_rpb_embed_x(deltas_x_log)  # [batch_size, num_queries, width, num_heads]
        deltas_y = self.box_rpb_embed_y(deltas_y_log)  # [batch_size, num_queries, height, num_heads]

        # Combine into 2D bias matrix
        # Broadcasting the per-row and per-column biases gives an additive bias for every (row, column) cell.
        rpb_matrix = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
            2
        )  # [batch_size, num_queries, height, width, num_heads]
        rpb_matrix = rpb_matrix.flatten(2, 3)  # [batch_size, num_queries, height*width, num_heads]
        rpb_matrix = rpb_matrix.permute(0, 3, 1, 2).contiguous()  # [batch_size, num_heads, num_queries, height*width]
        return rpb_matrix

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        vision_features: torch.Tensor,
        text_features: torch.Tensor,
        vision_pos_encoding: torch.Tensor,
        text_mask: torch.Tensor | None = None,
        spatial_shapes: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3LiteTextDETRDecoderOutput:
        """
        Forward pass for the DETR decoder.

        Args:
            vision_features: Vision features [batch_size, height*width, hidden_size]
            text_features: Text features [batch_size, seq_len, hidden_size]
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
            text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding
            spatial_shapes: Spatial shapes [num_levels, 2]. The box relative position bias is only
                computed when exactly one level is given; otherwise vision cross-attention is unmasked.

        Returns:
            Sam3LiteTextDETRDecoderOutput containing decoder outputs from all layers. Every field is stacked
            along a new leading layer dimension. `reference_boxes` holds the boxes each layer received as
            reference (the initial boxes plus the refined boxes of all layers but the last).
        """
        batch_size = vision_features.shape[0]

        # Broadcast the learned queries / reference boxes / presence token over the batch
        query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
        reference_boxes = reference_boxes.sigmoid()
        presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)

        # Concatenate presence token with query embeddings
        hidden_states = torch.cat([presence_token, query_embeds], dim=1)

        text_cross_attn_mask = None
        if text_mask is not None:
            text_cross_attn_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=hidden_states,
                attention_mask=text_mask,
                encoder_hidden_states=text_features,
            )

        intermediate_outputs = []
        # Seeded with the initial boxes so that entry i is the reference box fed into layer i
        intermediate_boxes = [reference_boxes]
        intermediate_presence_logits = []

        for layer in self.layers:
            # Generate sine embeddings for conditional queries
            # (unsqueeze(2) followed by [:, :, 0, :] hands the reference boxes to encode_boxes unchanged)
            reference_points_input = reference_boxes.unsqueeze(2)
            query_sine_embed = self.position_encoding.encode_boxes(reference_points_input[:, :, 0, :])
            query_pos = self.ref_point_head(query_sine_embed)

            # Compute box relative position bias (RPB) attention mask
            vision_cross_attn_mask = None
            if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
                spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
                rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
                # Prepend zeros row for presence token (it attends to all vision tokens equally)
                vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)

            hidden_states = layer(
                hidden_states,
                query_pos=query_pos,
                text_features=text_features,
                vision_features=vision_features,
                vision_pos_encoding=vision_pos_encoding,
                text_cross_attn_mask=text_cross_attn_mask,
                vision_cross_attn_mask=vision_cross_attn_mask,
                **kwargs,
            )

            # Extract query hidden states (without presence token) for box refinement
            query_hidden_states = hidden_states[:, 1:]

            # Box refinement: predict delta and update reference boxes
            # The delta is added in logit space; the next layer receives a detached copy, while the
            # non-detached boxes are the ones kept in `intermediate_boxes`.
            reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
            delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states))
            new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
            reference_boxes = new_reference_boxes.detach()

            intermediate_outputs.append(self.output_layer_norm(query_hidden_states))
            intermediate_boxes.append(new_reference_boxes)

            # Process presence token
            presence_hidden = hidden_states[:, :1]
            presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
            presence_logits = presence_logits.clamp(
                min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
            )
            intermediate_presence_logits.append(presence_logits)

        # Stack outputs from all layers
        intermediate_outputs = torch.stack(intermediate_outputs)
        # [:-1] drops the boxes refined by the last layer, leaving one reference box set per layer
        intermediate_boxes = torch.stack(intermediate_boxes[:-1])
        intermediate_presence_logits = torch.stack(intermediate_presence_logits)

        return Sam3LiteTextDETRDecoderOutput(
            intermediate_hidden_states=intermediate_outputs,
            reference_boxes=intermediate_boxes,
            presence_logits=intermediate_presence_logits,
        )


class Sam3LiteTextDotProductScoring(nn.Module):
    """
    Scores every decoder query against a single pooled text vector.

    The text features go through a residual MLP block (MLP -> dropout -> skip connection -> layer norm), are
    mean-pooled over the sequence dimension, and are then linearly projected. Decoder queries are projected as well,
    and the scaled dot product between the two gives one logit per query.
    """

    def __init__(self, config: Sam3LiteTextConfig):
        super().__init__()
        self.config = config
        decoder_config = config.detr_decoder_config
        hidden_size = decoder_config.hidden_size
        # The projection space has the same width as the decoder hidden states
        projection_dim = decoder_config.hidden_size

        # Residual block applied to the text features before pooling
        self.text_mlp = Sam3LiteTextDecoderMLP(
            input_dim=hidden_size,
            hidden_dim=decoder_config.intermediate_size,
            output_dim=hidden_size,
            num_layers=2,
        )
        self.text_mlp_dropout = nn.Dropout(decoder_config.dropout)
        self.text_mlp_out_norm = nn.LayerNorm(hidden_size)

        # Linear maps into the shared scoring space
        self.text_proj = nn.Linear(hidden_size, projection_dim)
        self.query_proj = nn.Linear(hidden_size, projection_dim)

        # Dot products are multiplied by 1 / sqrt(projection_dim)
        inverse_sqrt_dim = 1.0 / np.sqrt(projection_dim)
        self.scale = float(inverse_sqrt_dim)

        # Scores are clipped to [-clamp_max_val, clamp_max_val] when clamp_logits is set
        self.clamp_logits = True
        self.clamp_max_val = 12.0

    def _pool_text_features(self, text_features: torch.Tensor, text_mask: torch.Tensor | None) -> torch.Tensor:
        """
        Average the text features over the sequence dimension, skipping masked-out positions.

        Args:
            text_features: [batch_size, seq_len, hidden_size]
            text_mask: [batch_size, seq_len], True for tokens to keep and False for padding. When `None`, every
                position contributes to the mean.

        Returns:
            pooled_text: [batch_size, hidden_size]
        """
        if text_mask is None:
            return text_features.mean(dim=1)

        # Cast the mask to the feature dtype so it can be used as 0/1 weights: [batch_size, seq_len, 1]
        token_weights = text_mask.to(text_features.dtype).unsqueeze(-1)

        # Lower bound of 1 avoids a division by zero for rows without any valid token: [batch_size, 1]
        valid_token_count = token_weights.sum(dim=1).clamp(min=1.0)

        weighted_sum = (text_features * token_weights).sum(dim=1)  # [batch_size, hidden_size]
        return weighted_sum / valid_token_count

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Produce one score per decoder query from its dot product with the pooled text representation.

        Args:
            decoder_hidden_states: [num_layers, batch_size, num_queries, hidden_size]
            text_features: [batch_size, seq_len, hidden_size]
            text_mask: [batch_size, seq_len] where True=valid, False=padding

        Returns:
            scores: [num_layers, batch_size, num_queries, 1]
        """
        # Residual MLP block followed by layer norm
        skip = text_features
        refined_text = self.text_mlp_dropout(self.text_mlp(text_features))
        refined_text = self.text_mlp_out_norm(refined_text + skip)

        pooled_text = self._pool_text_features(refined_text, text_mask)

        text_vector = self.text_proj(pooled_text)
        query_vectors = self.query_proj(decoder_hidden_states)

        # [batch_size, dim] -> [1, batch_size, dim, 1] so that matmul broadcasts over the leading layer dimension
        text_column = text_vector.unsqueeze(-1).unsqueeze(0)
        scores = torch.matmul(query_vectors, text_column) * self.scale

        if not self.clamp_logits:
            return scores
        return scores.clamp(min=-self.clamp_max_val, max=self.clamp_max_val)


class Sam3LiteTextMaskEmbedder(nn.Module):
    """
    Three-layer perceptron that turns object queries into mask embeddings.
    Every linear layer keeps the width at `config.hidden_size`; ReLU is applied after all layers except the last.
    """

    def __init__(self, config: Sam3LiteTextMaskDecoderConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size

        self.layers = nn.ModuleList(nn.Linear(width, width) for _ in range(3))
        self.activation = nn.ReLU()

    def forward(self, queries: torch.Tensor) -> torch.Tensor:
        """
        Args:
            queries: Query embeddings [batch_size, num_queries, hidden_size]

        Returns:
            Mask embeddings [batch_size, num_queries, hidden_size]
        """
        # The final layer is kept apart because its output is returned without an activation
        *inner_layers, final_layer = self.layers

        embeddings = queries
        for linear in inner_layers:
            embeddings = self.activation(linear(embeddings))
        return final_layer(embeddings)


class Sam3LiteTextPixelDecoder(nn.Module):
    """
    FPN-style top-down decoder producing per-pixel features.

    Starting from the last feature map of the input list, each stage resizes the running map to the next feature map
    (nearest-neighbour), adds that map as a skip connection, and applies a 3x3 conv, GroupNorm (8 groups) and ReLU.
    """

    def __init__(self, config: Sam3LiteTextMaskDecoderConfig):
        super().__init__()
        self.config = config
        channels = config.hidden_size
        stages = range(config.num_upsampling_stages)

        # One (conv, norm) pair per fusion stage; the 3x3 / stride 1 / padding 1 conv keeps the spatial size
        self.conv_layers = nn.ModuleList(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) for _ in stages
        )
        self.norms = nn.ModuleList(nn.GroupNorm(8, channels) for _ in stages)

        self.out_channels = channels

    def forward(self, backbone_features: list[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            backbone_features: List of feature maps [batch_size, hidden_size, H_i, W_i], already projected to
                `hidden_size`. The list is consumed back to front: the last entry seeds the pyramid (it is treated as
                the coarsest level) and the entry at index 0 is merged last.

        Returns:
            Pixel embeddings [batch_size, hidden_size, H, W] with the spatial size of `backbone_features[0]`. If the
            list holds a single map, that map is returned unchanged.
        """
        fused = backbone_features[-1]

        # Walk the remaining levels from second-to-last down to index 0
        for stage, skip_feat in enumerate(backbone_features[-2::-1]):
            # Bring the running map to the resolution of the current level, then add the skip connection
            resized = F.interpolate(fused, size=skip_feat.shape[-2:], mode="nearest")
            merged = resized + skip_feat

            merged = self.conv_layers[stage](merged)
            fused = F.relu(self.norms[stage](merged))

        return fused


class Sam3LiteTextMaskDecoder(Sam3LiteTextPreTrainedModel):
    """
    Predicts instance masks as dot products between embedded object queries and per-pixel embeddings.
    A single-channel semantic segmentation map is produced from the same pixel embeddings, and the encoder states
    can optionally be refined by cross-attention over the prompt features before pixel embedding.
    """

    _can_record_outputs = {
        "attentions": Sam3LiteTextAttention,
    }

    def __init__(self, config: Sam3LiteTextMaskDecoderConfig):
        super().__init__(config)
        self.config = config
        hidden_size = config.hidden_size

        # FPN-style decoder that yields the pixel embeddings
        self.pixel_decoder = Sam3LiteTextPixelDecoder(config)

        # MLP applied to the decoder queries
        self.mask_embedder = Sam3LiteTextMaskEmbedder(config)

        pixel_channels = self.pixel_decoder.out_channels
        # 1x1 conv mapping pixel embeddings into the space shared with the mask embeddings
        self.instance_projection = nn.Conv2d(pixel_channels, hidden_size, kernel_size=1)
        # 1x1 conv producing the single-channel semantic segmentation logits
        self.semantic_projection = nn.Conv2d(pixel_channels, 1, kernel_size=1)

        # Pre-norm cross-attention block from encoder states to prompt features
        self.prompt_cross_attn = Sam3LiteTextAttention(config)
        self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
        self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        decoder_queries: torch.Tensor,
        backbone_features: list[torch.Tensor],
        encoder_hidden_states: torch.Tensor,
        prompt_features: torch.Tensor | None = None,
        prompt_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3LiteTextMaskDecoderOutput:
        """
        Args:
            decoder_queries: Decoder output queries [batch_size, num_queries, hidden_size]
            backbone_features: List of backbone features to process through FPN
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]
            prompt_features: Prompt features (text + geometry) for cross-attention [batch_size, prompt_len, hidden_size].
                When `None`, the cross-attention step is skipped.
            prompt_mask: Padding mask [batch_size, prompt_len] where True=valid, False=padding

        Returns:
            Sam3LiteTextMaskDecoderOutput with `pred_masks` [batch_size, num_queries, H, W] and `semantic_seg`
            [batch_size, 1, H, W].
        """
        if prompt_features is not None:
            encoder_hidden_states = self._cross_attend_to_prompt(
                encoder_hidden_states, prompt_features, prompt_mask, **kwargs
            )

        pixel_embed = self._embed_pixels(
            backbone_features=backbone_features,
            encoder_hidden_states=encoder_hidden_states,
        )

        # Instance masks: contract the channel dimension between query embeddings and pixel embeddings
        instance_embeds = self.instance_projection(pixel_embed)
        mask_embeddings = self.mask_embedder(decoder_queries)
        pred_masks = torch.einsum("bqc,bchw->bqhw", mask_embeddings, instance_embeds)

        semantic_seg = self.semantic_projection(pixel_embed)

        return Sam3LiteTextMaskDecoderOutput(
            pred_masks=pred_masks,
            semantic_seg=semantic_seg,
        )

    def _cross_attend_to_prompt(
        self,
        encoder_hidden_states: torch.Tensor,
        prompt_features: torch.Tensor,
        prompt_mask: torch.Tensor | None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Pre-norm residual cross-attention in which the encoder states query the prompt features.

        Args:
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]
            prompt_features: Keys/values for the attention [batch_size, prompt_len, hidden_size]
            prompt_mask: Optional padding mask [batch_size, prompt_len] where True=valid, False=padding

        Returns:
            Updated encoder states with the same shape as `encoder_hidden_states`.
        """
        normed_states = self.prompt_cross_attn_norm(encoder_hidden_states)

        # Only build an attention mask when the caller supplied a padding mask
        attention_mask = (
            None
            if prompt_mask is None
            else create_bidirectional_mask(
                config=self.config,
                inputs_embeds=normed_states,
                encoder_hidden_states=prompt_features,
                attention_mask=prompt_mask,
            )
        )

        attn_output, _ = self.prompt_cross_attn(
            query=normed_states,
            key=prompt_features,
            value=prompt_features,
            attention_mask=attention_mask,
            **kwargs,
        )
        return encoder_hidden_states + self.prompt_cross_attn_dropout(attn_output)

    def _embed_pixels(
        self,
        backbone_features: list[torch.Tensor],
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Build pixel embeddings from the backbone feature pyramid and the encoder output.
        The last backbone feature map is swapped for the encoder's vision tokens, reshaped to that map's
        height and width, before the pyramid is run through the pixel decoder.

        Args:
            backbone_features: List of backbone features [batch_size, C, H_i, W_i]
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]; only the first H * W tokens
                (H, W taken from the last backbone feature map) are used.

        Returns:
            Pixel embeddings [batch_size, hidden_size, H, W]
        """
        height, width = backbone_features[-1].shape[-2:]

        # Keep the leading height * width tokens and fold them back into a channels-first spatial map
        vision_tokens = encoder_hidden_states[:, : height * width, :]
        batch_size, _, hidden_size = vision_tokens.shape
        vision_map = vision_tokens.transpose(1, 2).reshape(batch_size, hidden_size, height, width)

        # Work on copies so the caller's tensors and list are left untouched
        pyramid = [level.clone() for level in backbone_features]
        pyramid[-1] = vision_map

        return self.pixel_decoder(pyramid)


class Sam3LiteTextModel(Sam3LiteTextPreTrainedModel):
    """
    Promptable detection and segmentation model.

    Combines a vision encoder, a text encoder, a geometry (box prompt) encoder, a DETR encoder/decoder pair, a mask
    decoder and a dot-product scoring head. Given an image (or pre-computed vision embeddings) and a text prompt (or
    pre-computed text embeddings), optionally with box prompts, it predicts boxes, per-query logits, presence logits,
    instance masks and a semantic segmentation map.
    """

    input_modalities = ["image", "text"]
    base_model_prefix = "detector_model"
    # Tracker weights found in a checkpoint are not part of this model and are ignored at load time
    _keys_to_ignore_on_load_unexpected = [
        r"^tracker_model.",
        r"^tracker_neck.",
    ]
    # DETR components create float masks from features, so flash/flex attention cannot be dispatched safely.
    _supports_flash_attn = False
    _supports_flex_attn = False

    def __init__(self, config: Sam3LiteTextConfig):
        # A config that nests a `detector_config` (e.g. a sam3_lite_text_video config) is unwrapped to the detector
        # config, which may still be a plain dict at this point
        if hasattr(config, "detector_config") and config.detector_config is not None:
            detector_config = config.detector_config
            if isinstance(detector_config, dict):
                detector_config = Sam3LiteTextConfig(**detector_config)
            config = detector_config
        super().__init__(config)
        self.vision_encoder = AutoModel.from_config(config.vision_config)
        self.text_encoder = Sam3LiteTextTextModel(config.text_config)
        self.vocab_size = config.text_config.vocab_size

        # Project text features from the text encoder hidden size to the DETR encoder hidden size
        self.text_projection = nn.Linear(config.text_config.hidden_size, config.detr_encoder_config.hidden_size)

        # Pass _attn_implementation to subconfigs BEFORE creating modules
        for sub_config in (
            config.geometry_encoder_config,
            config.detr_encoder_config,
            config.detr_decoder_config,
            config.mask_decoder_config,
        ):
            sub_config._attn_implementation = config._attn_implementation

        self.geometry_encoder = Sam3LiteTextGeometryEncoder(config.geometry_encoder_config)
        self.detr_encoder = Sam3LiteTextDetrEncoder(config.detr_encoder_config)
        self.detr_decoder = Sam3LiteTextDetrDecoder(config.detr_decoder_config)
        self.mask_decoder = Sam3LiteTextMaskDecoder(config.mask_decoder_config)

        # Dot product scoring to compute classification scores
        self.dot_product_scoring = Sam3LiteTextDotProductScoring(config)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3LiteTextModel, Sam3LiteTextProcessor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3LiteTextModel.from_pretrained("facebook/sam3_lite_text")
        >>> processor = Sam3LiteTextProcessor.from_pretrained("facebook/sam3_lite_text")

        >>> # Pre-compute text embeddings
        >>> text_inputs = processor(text="cat", return_tensors="pt")
        >>> text_embeds = model.get_text_features(**text_inputs).pooler_output

        >>> # Reuse text embeddings for multiple images
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> img_inputs = processor(images=image, return_tensors="pt")
        >>> outputs = model(pixel_values=img_inputs.pixel_values, text_embeds=text_embeds)
        ```
        """
        text_outputs = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs
        )
        # `pooler_output` is overwritten with the projected per-token hidden states; no pooling happens here
        text_outputs.pooler_output = self.text_projection(text_outputs.last_hidden_state)

        return text_outputs

    @auto_docstring
    def get_vision_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam3LiteTextVisionEncoderOutput:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3LiteTextModel, Sam3LiteTextProcessor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3LiteTextModel.from_pretrained("facebook/sam3_lite_text")
        >>> processor = Sam3LiteTextProcessor.from_pretrained("facebook/sam3_lite_text")

        >>> # Pre-compute vision embeddings
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> img_inputs = processor(images=image, return_tensors="pt")
        >>> vision_embeds = model.get_vision_features(pixel_values=img_inputs.pixel_values)

        >>> # Reuse vision embeddings for multiple text prompts
        >>> text_inputs = processor(text="cat", return_tensors="pt")
        >>> outputs = model(vision_embeds=vision_embeds, input_ids=text_inputs.input_ids)
        ```
        """
        return self.vision_encoder(pixel_values, **kwargs)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        vision_embeds: Sam3LiteTextVisionEncoderOutput | None = None,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        text_embeds: torch.FloatTensor | None = None,
        input_boxes: torch.FloatTensor | None = None,
        input_boxes_labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam3LiteTextImageSegmentationOutput:
        r"""
        vision_embeds (`Sam3LiteTextVisionEncoderOutput`, *optional*):
            Pre-computed vision embeddings. Can be used to easily reuse vision embeddings. If provided, `pixel_values`
            should not be passed. Mutually exclusive with `pixel_values`.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Pre-computed text embeddings. Can be used to easily reuse text embeddings. If provided, `input_ids`
            should not be passed. Mutually exclusive with `input_ids`.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`, *optional*):
            Normalized box coordinates in [0, 1] range, in (cx, cy, w, h) format.
        input_boxes_labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`, *optional*):
            Labels for boxes: 1 (positive), 0 (negative).

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoModel, AutoProcessor

        >>> model = AutoModel.from_pretrained("facebook/sam3_lite_text")
        >>> processor = AutoProcessor.from_pretrained("facebook/sam3_lite_text")

        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read())).convert("RGB")
        >>> text = "car"
        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> # Get segmentation output
        >>> outputs = model(**inputs)
        >>> pred_masks = outputs.pred_masks
        >>> pred_boxes = outputs.pred_boxes
        ```
        """
        if (pixel_values is None) == (vision_embeds is None):
            raise ValueError("You must specify exactly one of pixel_values or vision_embeds")

        if (input_ids is None) == (text_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or text_embeds")

        # Exactly one vision input is set (validated above), so a single branch resolves batch size, device and
        # the vision outputs
        if vision_embeds is None:
            batch_size = pixel_values.shape[0]
            device = pixel_values.device
            vision_outputs = self.vision_encoder(pixel_values, **kwargs)
        else:
            batch_size = vision_embeds.fpn_hidden_states[0].shape[0]
            device = vision_embeds.fpn_hidden_states[0].device
            vision_outputs = vision_embeds

        # The last FPN level is dropped; the remaining levels feed the geometry encoder, DETR encoder and mask decoder
        fpn_hidden_states = vision_outputs.fpn_hidden_states[:-1]
        fpn_position_encoding = vision_outputs.fpn_position_encoding[:-1]

        if text_embeds is None:
            text_features = self.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).pooler_output
        else:
            text_features = text_embeds

        text_mask = attention_mask.bool() if attention_mask is not None else None

        geometry_prompt_features = None
        geometry_prompt_mask = None

        # Box prompts are encoded only when at least one box element is present
        if input_boxes is not None and input_boxes.numel() > 0:
            if input_boxes_labels is None:
                # Without labels every box is treated as a valid, positive prompt
                box_labels = torch.ones_like(input_boxes[..., 0], dtype=torch.long)
                box_mask = torch.ones(batch_size, input_boxes.shape[1], dtype=torch.bool, device=device)
            else:
                box_labels = input_boxes_labels
                # Boxes labelled -10 are masked out (presumably the padding value — confirm against the processor)
                box_mask = input_boxes_labels != -10
            # Masked-out entries get label 0 so that only the values 0 and 1 documented above reach the encoder
            box_labels = torch.where(box_labels == -10, 0, box_labels)

            geometry_outputs = self.geometry_encoder(
                box_embeddings=input_boxes,  # [batch_size, num_boxes, 4]
                box_mask=box_mask,
                box_labels=box_labels,
                img_feats=fpn_hidden_states,
                img_pos_embeds=fpn_position_encoding,
            )

            geometry_prompt_features = geometry_outputs.last_hidden_state
            geometry_prompt_mask = geometry_outputs.attention_mask

        if geometry_prompt_features is not None:
            # A single text prompt is repeated so that it pairs with every geometry prompt in the batch. The batch
            # size is read from the geometry features (not the geometry mask) because the mask may be None; the
            # branches below explicitly handle that case.
            num_geometry_prompts = geometry_prompt_features.shape[0]
            if text_features.shape[0] == 1 and num_geometry_prompts > 1:
                text_features = text_features.repeat(num_geometry_prompts, 1, 1)
            combined_prompt_features = torch.cat([text_features, geometry_prompt_features], dim=1)
            if text_mask is not None and text_mask.shape[0] == 1 and num_geometry_prompts > 1:
                text_mask = text_mask.repeat(num_geometry_prompts, 1)

            # Concatenate the masks along the sequence axis; a missing mask is replaced by an all-valid one
            if text_mask is not None and geometry_prompt_mask is not None:
                combined_prompt_mask = torch.cat([text_mask, geometry_prompt_mask], dim=1)
            elif text_mask is not None:
                geo_valid_mask = torch.ones(
                    batch_size, geometry_prompt_features.shape[1], dtype=torch.bool, device=device
                )
                combined_prompt_mask = torch.cat([text_mask, geo_valid_mask], dim=1)
            elif geometry_prompt_mask is not None:
                text_valid_mask = torch.ones(batch_size, text_features.shape[1], dtype=torch.bool, device=device)
                combined_prompt_mask = torch.cat([text_valid_mask, geometry_prompt_mask], dim=1)
            else:
                combined_prompt_mask = None
        else:
            combined_prompt_features = text_features
            combined_prompt_mask = text_mask

        # Only the last retained FPN level is fused with the prompt in the DETR encoder
        encoder_outputs = self.detr_encoder(
            vision_features=[fpn_hidden_states[-1]],
            text_features=combined_prompt_features,
            vision_pos_embeds=[fpn_position_encoding[-1]],
            text_mask=combined_prompt_mask,
            **kwargs,
        )

        decoder_outputs = self.detr_decoder(
            vision_features=encoder_outputs.last_hidden_state,
            text_features=encoder_outputs.text_features,
            vision_pos_encoding=encoder_outputs.pos_embeds_flattened,
            text_mask=combined_prompt_mask,
            spatial_shapes=encoder_outputs.spatial_shapes,
            **kwargs,
        )

        # Refine boxes from decoder: offsets are added to the reference boxes in logit space, then squashed back
        all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states)
        reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes)
        all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid()
        all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh)

        all_pred_logits = self.dot_product_scoring(
            decoder_hidden_states=decoder_outputs.intermediate_hidden_states,
            text_features=encoder_outputs.text_features,
            text_mask=combined_prompt_mask,
        ).squeeze(-1)

        # Predictions are taken from the last decoder layer
        pred_logits = all_pred_logits[-1]
        pred_boxes = all_pred_boxes[-1]
        decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1]
        presence_logits = decoder_outputs.presence_logits[-1]

        mask_outputs = self.mask_decoder(
            decoder_queries=decoder_hidden_states,
            backbone_features=list(fpn_hidden_states),
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            prompt_features=combined_prompt_features,
            prompt_mask=combined_prompt_mask,
            **kwargs,
        )

        return Sam3LiteTextImageSegmentationOutput(
            pred_masks=mask_outputs.pred_masks,
            pred_boxes=pred_boxes,
            pred_logits=pred_logits,
            presence_logits=presence_logits,
            semantic_seg=mask_outputs.semantic_seg,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_reference_boxes=decoder_outputs.reference_boxes,
            encoder_hidden_states=encoder_outputs.hidden_states,
            vision_hidden_states=vision_outputs.hidden_states,
            vision_attentions=vision_outputs.attentions,
            detr_encoder_attentions=encoder_outputs.attentions,
            detr_decoder_attentions=decoder_outputs.attentions,
            mask_decoder_attentions=mask_outputs.attentions,
        )


# Public names exported when this module is star-imported.
__all__ = ["Sam3LiteTextModel", "Sam3LiteTextPreTrainedModel", "Sam3LiteTextTextModel"]
