#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py.
#            Do NOT edit this file manually, as any edits will be overwritten when the file is
#          regenerated from the modular source. If a change is needed, please apply it to the
#                  modular_qwen2_5_omni.py file directly. One of our CI checks enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub.dataclasses import strict

from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters
from ...utils import auto_docstring, logging


logger = logging.get_logger(__name__)


@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B")
@strict
class Qwen2_5OmniVisionEncoderConfig(PreTrainedConfig):
    r"""
    window_size (`int`, *optional*, defaults to 112):
        The size of the attention window used by the windowed-attention blocks.
    out_hidden_size (`int`, *optional*, defaults to 3584):
        The output hidden size of the vision model.
    fullatt_block_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`):
        Indices of the blocks that use full attention instead of windowed attention.

    Example:

    ```python
    >>> from transformers import Qwen2_5OmniVisionEncoderConfig, Qwen2_5OmniVisionEncoder

    >>> # Initializing a Qwen2_5OmniVisionEncoderConfig
    >>> configuration = Qwen2_5OmniVisionEncoderConfig()

    >>> # Initializing a Qwen2_5OmniVisionEncoder (with random weights)
    >>> model = Qwen2_5OmniVisionEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_5_omni_vision_encoder"
    base_config_key = "vision_config"

    depth: int = 32
    hidden_size: int = 3584
    hidden_act: str = "silu"
    intermediate_size: int = 3420
    num_heads: int = 16
    in_channels: int = 3
    patch_size: int | list[int] | tuple[int, int] = 14
    spatial_merge_size: int = 2
    temporal_patch_size: int | list[int] | tuple[int, int] = 2
    window_size: int = 112
    out_hidden_size: int = 3584
    fullatt_block_indexes: list[int] | tuple[int, ...] = (7, 15, 23, 31)
    initializer_range: float = 0.02


@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B")
@strict
class Qwen2_5OmniAudioEncoderConfig(PreTrainedConfig):
    r"""
    max_source_positions (`int`, *optional*, defaults to 1500):
        The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
    n_window (`int`, *optional*, defaults to 100):
        The chunk size used for convolution and flash attention in the audio encoder.
    output_dim (`int`, *optional*, defaults to 3584):
        The output dimension of AudioEncoder.

    Example:

    ```python
    >>> from transformers import Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniAudioEncoder

    >>> # Initializing a Qwen2_5OmniAudioEncoderConfig
    >>> configuration = Qwen2_5OmniAudioEncoderConfig()

    >>> # Initializing a Qwen2_5OmniAudioEncoder (with random weights)
    >>> model = Qwen2_5OmniAudioEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_5_omni_audio_encoder"
    attribute_map = {"num_hidden_layers": "encoder_layers"}

    num_mel_bins: int = 128
    encoder_layers: int = 32
    encoder_attention_heads: int = 20
    encoder_ffn_dim: int = 5120
    d_model: int = 1280
    dropout: float | int = 0.0
    attention_dropout: float | int = 0.0
    activation_function: str = "gelu"
    activation_dropout: float | int = 0.0
    scale_embedding: bool = False
    initializer_range: float = 0.02
    max_source_positions: int = 1500

    n_window: int = 100
    output_dim: int = 3584


@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B")
@strict
class Qwen2_5OmniTextConfig(PreTrainedConfig):
    r"""
    Example:

    ```python
    >>> from transformers import Qwen2_5OmniTextConfig

    >>> # Initializing a Qwen2_5OmniText configuration (used as the thinker's `text_config`)
    >>> configuration = Qwen2_5OmniTextConfig()

    >>> # Accessing a configuration value
    >>> hidden_size = configuration.hidden_size
    ```"""

    model_type = "qwen2_5_omni_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    default_theta = 1000000.0

    # Default tensor parallel plan for base model `Qwen25OmniText`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    ignore_keys_at_rope_validation = {"mrope_section"}

    vocab_size: int = 152064
    hidden_size: int = 3584
    intermediate_size: int = 18944
    num_hidden_layers: int = 28
    num_attention_heads: int = 28
    num_key_value_heads: int | None = 4
    hidden_act: str = "silu"
    max_position_embeddings: int = 32768
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    rope_parameters: RopeParameters | dict | None = None
    use_sliding_window: bool = False
    sliding_window: int | None = 32768
    max_window_layers: int = 28
    layer_types: list[str] | None = None
    attention_dropout: float | int = 0.0
    pad_token_id: int | None = None
    bos_token_id: int | None = None
    eos_token_id: int | list[int] | None = None
    tie_word_embeddings: bool = True

    def __post_init__(self, **kwargs):
        self.sliding_window = self.sliding_window if self.use_sliding_window else None
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.layer_types is None:
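            # Derive per-layer attention types: layers with index >= max_window_layers use
            # sliding-window attention (when enabled), while earlier layers use full attention.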
            self.layer_types = [
                "sliding_attention"
                if self.sliding_window is not None and i >= self.max_window_layers
                else "full_attention"
                for i in range(self.num_hidden_layers)
            ]

        super().__post_init__(**kwargs)


@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B")
@strict
class Qwen2_5OmniThinkerConfig(PreTrainedConfig):
    r"""
    position_id_per_seconds (`int`, *optional*, defaults to 25):
        The increment of position id per second.
    seconds_per_chunk (`int`, *optional*, defaults to 2):
        The duration in seconds of the chunk of audio and video data.
    audio_start_token_id (`int`, *optional*, defaults to 151647):
        The audio start token index to encode the audio prompt.
    audio_end_token_id (`int`, *optional*, defaults to 151648):
        The audio end token index to encode the audio prompt.
    user_token_id (`int`, *optional*, defaults to 872):
        The token index used to encode the user role.

    Example:

    ```python
    >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig

    >>> # Initializing a Qwen2_5OmniAudioEncoder config
    >>> audio_config = Qwen2_5OmniAudioEncoderConfig()

    >>> # Initializing a Qwen2_5OmniVisionEncoder config
    >>> vision_config = Qwen2_5OmniVisionEncoderConfig()

    >>> # Initializing a Qwen2_5OmniTextConfig config
    >>> text_config = Qwen2_5OmniTextConfig()

    >>> # Initializing a Qwen2_5OmniThinker configuration
    >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config, text_config)

    >>> # Initializing a model from the Qwen-Omni style configuration
    >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_5_omni_thinker"
    attribute_map = {
        "image_token_id": "image_token_index",
        "video_token_id": "video_token_index",
        "audio_token_id": "audio_token_index",
    }
    sub_configs = {
        "audio_config": Qwen2_5OmniAudioEncoderConfig,
        "vision_config": Qwen2_5OmniVisionEncoderConfig,
        "text_config": Qwen2_5OmniTextConfig,
    }

    audio_config: dict | PreTrainedConfig | None = None
    vision_config: dict | PreTrainedConfig | None = None
    text_config: dict | PreTrainedConfig | None = None
    audio_token_index: int = 151646
    image_token_index: int = 151655
    video_token_index: int = 151656
    position_id_per_seconds: int = 25
    seconds_per_chunk: int = 2
    audio_start_token_id: int = 151647
    audio_end_token_id: int = 151648
    user_token_id: int = 872
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False

    def __post_init__(self, **kwargs):
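        # Coerce sub-configs passed as plain dicts (e.g. loaded from a serialized config)
        # into their config classes, falling back to defaults when they are omitted.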
        if isinstance(self.vision_config, dict):
            self.vision_config = Qwen2_5OmniVisionEncoderConfig(**self.vision_config)
        elif self.vision_config is None:
            self.vision_config = Qwen2_5OmniVisionEncoderConfig()

        if isinstance(self.audio_config, dict):
            self.audio_config = Qwen2_5OmniAudioEncoderConfig(**self.audio_config)
        elif self.audio_config is None:
            self.audio_config = Qwen2_5OmniAudioEncoderConfig()

        if isinstance(self.text_config, dict):
            self.text_config = Qwen2_5OmniTextConfig(**self.text_config)
        elif self.text_config is None:
            self.text_config = Qwen2_5OmniTextConfig()

        super().__post_init__(**kwargs)


@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B")
@strict
class Qwen2_5OmniTalkerConfig(PreTrainedConfig):
    r"""
    tts_text_start_token_id (`int`, *optional*, defaults to 151860):
        The tts text start token index to encode the start of tts text.
    tts_text_end_token_id (`int`, *optional*, defaults to 151861):
        The tts text end token index to encode the end of tts text.
    tts_text_pad_token_id (`int`, *optional*, defaults to 151859):
        The tts text pad token index to encode the pad of tts text.
    tts_codec_start_token_id (`int`, *optional*, defaults to 8293):
        The tts codec start token index to encode the start of tts codec.
    tts_codec_end_token_id (`int`, *optional*, defaults to 8294):
        The tts codec end token index to encode the end of tts codec.
    tts_codec_pad_token_id (`int`, *optional*, defaults to 8292):
        The tts codec pad token index to encode the pad of tts codec.
    tts_codec_mask_token_id (`int`, *optional*, defaults to 8296):
        The tts codec mask token index to encode the mask of tts codec.
    position_id_per_seconds (`int`, *optional*, defaults to 25):
        The increment of position id per second.
    seconds_per_chunk (`int`, *optional*, defaults to 2):
        The duration in seconds of the chunk of audio and video data.
    audio_start_token_id (`int`, *optional*, defaults to 151647):
        The audio start token index to encode the audio prompt.
    audio_end_token_id (`int`, *optional*, defaults to 151648):
        The audio end token index to encode the audio prompt.

    Example:

    ```python
    >>> from transformers import Qwen2_5OmniTalkerForConditionalGeneration, Qwen2_5OmniTalkerConfig

    >>> # Initializing a Qwen2_5OmniTalker configuration
    >>> configuration = Qwen2_5OmniTalkerConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = Qwen2_5OmniTalkerForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_5_omni_talker"
    default_theta = 1000000.0
    attribute_map = {
        "image_token_id": "image_token_index",
        "video_token_id": "video_token_index",
        "audio_token_id": "audio_token_index",
    }
    ignore_keys_at_rope_validation = {"mrope_section"}

    audio_token_index: int = 151646
    image_token_index: int = 151655
    video_token_index: int = 151656
    vocab_size: int = 8448
    tts_text_start_token_id: int = 151860
    tts_text_end_token_id: int = 151861
    tts_text_pad_token_id: int = 151859
    tts_codec_start_token_id: int = 8293
    tts_codec_end_token_id: int = 8294
    tts_codec_pad_token_id: int = 8292
    tts_codec_mask_token_id: int = 8296
    vision_start_token_id: int = 151652
    vision_end_token_id: int = 151653
    embedding_size: int = 3584
    hidden_size: int = 3584
    intermediate_size: int = 18944
    num_hidden_layers: int = 28
    num_attention_heads: int = 28
    num_key_value_heads: int = 4
    hidden_act: str = "silu"
    max_position_embeddings: int = 32768
    rms_norm_eps: float = 1e-06
    head_dim: int = 128
    use_cache: bool = True
    tie_word_embeddings: bool = False
    use_sliding_window: bool = False
    sliding_window: int | None = 32768
    max_window_layers: int = 28
    attention_dropout: float | int = 0.0
    rope_parameters: RopeParameters | dict | None = None
    position_id_per_seconds: int = 25
    seconds_per_chunk: int = 2
    audio_start_token_id: int = 151647
    audio_end_token_id: int = 151648
    initializer_range: float = 0.02
    spatial_merge_size: int = 2
    layer_types: list[str] | None = None
    pad_token_id: int | None = None

    def __post_init__(self, **kwargs):
        self.sliding_window = self.sliding_window if self.use_sliding_window else None

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.layer_types is None:
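            # Same derivation as in Qwen2_5OmniTextConfig: layers with index >= max_window_layers
            # use sliding-window attention (when enabled), while earlier layers use full attention.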
            self.layer_types = [
                "sliding_attention"
                if self.sliding_window is not None and i >= self.max_window_layers
                else "full_attention"
                for i in range(self.num_hidden_layers)
            ]

        super().__post_init__(**kwargs)


@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B")
@strict
class Qwen2_5OmniDiTConfig(PreTrainedConfig):
    r"""
    ff_mult (`int`, *optional*, defaults to 2):
        The multiplier for the feedforward layer in each transformer block.
    emb_dim (`int`, *optional*, defaults to 512):
        The dimension of the embedding layer.
    block_size (`int`, *optional*, defaults to 24):
        Number of tokens (frames) in each processing block.
    look_ahead_layers (`list[int]`, *optional*, defaults to `[10]`):
        Indices of the transformer layers that are permitted to attend to future blocks.
    look_backward_layers (`list[int]`, *optional*, defaults to `[0, 20]`):
        Indices of the transformer layers that attend to past blocks beyond the current block boundary.
    repeats (`int`, *optional*, defaults to 2):
        The number of times the codec embeddings are repeated.
    num_embeds (`int`, *optional*, defaults to 8193):
        The number of unique embeddings in the codec.
    mel_dim (`int`, *optional*, defaults to 80):
        The dimension of the mel-spectrogram.
    enc_emb_dim (`int`, *optional*, defaults to 192):
        The dimension of the pre-trained speaker embedding.
    enc_dim (`int`, *optional*, defaults to 128):
        The dimension of the encoder output.
    enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
        A list of output channels for each TDNN/SERes2Net layer in the encoder.
    enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
        A list of kernel sizes for each layer in the encoder.
    enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
        A list of dilations for each layer in the encoder.
    enc_attention_channels (`int`, *optional*, defaults to 64):
        The number of attention channels in the SqueezeExcitationBlock.
    enc_res2net_scale (`int`, *optional*, defaults to 2):
        The scale of the Res2Net block in the encoder.
    enc_se_channels (`int`, *optional*, defaults to 64):
        The number of output channels after squeeze in the SqueezeExcitationBlock.
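
    Example (a minimal sketch using only the defaults defined in this file):

    ```python
    >>> from transformers import Qwen2_5OmniDiTConfig

    >>> # Initializing a Qwen2_5OmniDiT configuration with default values
    >>> configuration = Qwen2_5OmniDiTConfig()
    ```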
    """

    model_type = "qwen2_5_omni_dit"

    hidden_size: int = 1024
    num_hidden_layers: int = 22
    num_attention_heads: int = 16
    ff_mult: int = 2
    emb_dim: int = 512
    head_dim: int = 64
    rope_parameters: RopeParameters | dict | None = None
    max_position_embeddings: int = 32768
    block_size: int = 24
    look_ahead_layers: list[int] | tuple[int, ...] = (10,)
    look_backward_layers: list[int] | tuple[int, ...] = (0, 20)
    repeats: int = 2
    num_embeds: int = 8193
    mel_dim: int = 80
    dropout: float | int = 0.1
    enc_emb_dim: int = 192
    enc_dim: int = 128
    enc_channels: list[int] | tuple[int, ...] = (256, 256, 256, 256, 768)
    enc_kernel_sizes: list[int] | tuple[int, ...] = (5, 3, 3, 3, 1)
    enc_dilations: list[int] | tuple[int, ...] = (1, 2, 3, 4, 1)
    enc_attention_channels: int = 64
    enc_res2net_scale: int = 2
    enc_se_channels: int = 64


@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B")
@strict
class Qwen2_5OmniBigVGANConfig(PreTrainedConfig):
    r"""
    mel_dim (`int`, *optional*, defaults to 80):
        The dimension of the mel-spectrogram.
    upsample_initial_channel (`int`, *optional*, defaults to 1536):
        The number of channels in the initial upsampling layer.
    resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
        A list of kernel sizes for each residual block.
    resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
        A list of dilation sizes for each residual block.
    upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`):
        A list of upsampling rates for each upsampling layer.
    upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`):
        A list of kernel sizes for each upsampling layer.
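
    Example (a minimal sketch using only the defaults defined in this file):

    ```python
    >>> from transformers import Qwen2_5OmniBigVGANConfig

    >>> # Initializing a Qwen2_5OmniBigVGAN configuration with default values
    >>> configuration = Qwen2_5OmniBigVGANConfig()
    ```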
    """

    model_type = "qwen2_5_omni_bigvgan"

    mel_dim: int = 80
    upsample_initial_channel: int = 1536
    resblock_kernel_sizes: list[int] | tuple[int, ...] = (3, 7, 11)
    resblock_dilation_sizes: list | tuple = ((1, 3, 5), (1, 3, 5), (1, 3, 5))
    upsample_rates: list[int] | tuple[int, ...] = (5, 3, 2, 2, 2, 2)
    upsample_kernel_sizes: list[int] | tuple[int, ...] = (11, 7, 4, 4, 4, 4)


@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B")
@strict
class Qwen2_5OmniToken2WavConfig(PreTrainedConfig):
    r"""
    dit_config ([`Qwen2_5OmniDiTConfig`], *optional*):
        Configuration for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
    bigvgan_config ([`Qwen2_5OmniBigVGANConfig`], *optional*):
        Configuration for the BigVGAN module responsible for converting mel-spectrograms to waveforms.

    Example:

    ```python
    >>> from transformers import (
    ...     Qwen2_5OmniToken2WavModel,
    ...     Qwen2_5OmniToken2WavConfig,
    ...     Qwen2_5OmniDiTConfig,
    ...     Qwen2_5OmniBigVGANConfig,
    ... )

    >>> # Initialize the DiT configuration
    >>> dit_config = Qwen2_5OmniDiTConfig(
    ...     hidden_size=1024,
    ...     num_hidden_layers=22,
    ...     num_attention_heads=16,
    ...     ff_mult=2
    ... )

    >>> # Initialize the BigVGAN configuration
    >>> bigvgan_config = Qwen2_5OmniBigVGANConfig(
    ...     mel_dim=80,
    ...     upsample_rates=[5, 3, 2, 2, 2, 2]
    ... )

    >>> # Initialize the main configuration
    >>> config = Qwen2_5OmniToken2WavConfig(dit_config=dit_config, bigvgan_config=bigvgan_config)

    >>> # Initialize the model with the config
    >>> model = Qwen2_5OmniToken2WavModel(config)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "qwen2_5_omni_token2wav"
    sub_configs = {
        "dit_config": Qwen2_5OmniDiTConfig,
        "bigvgan_config": Qwen2_5OmniBigVGANConfig,
    }

    dit_config: dict | PreTrainedConfig | None = None
    bigvgan_config: dict | PreTrainedConfig | None = None

    def __post_init__(self, **kwargs):
        if self.dit_config is None:
            self.dit_config = Qwen2_5OmniDiTConfig()
        elif isinstance(self.dit_config, dict):
            self.dit_config = Qwen2_5OmniDiTConfig(**self.dit_config)

        if self.bigvgan_config is None:
            self.bigvgan_config = Qwen2_5OmniBigVGANConfig()
        elif isinstance(self.bigvgan_config, dict):
            self.bigvgan_config = Qwen2_5OmniBigVGANConfig(**self.bigvgan_config)

        super().__post_init__(**kwargs)


@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B")
@strict
class Qwen2_5OmniConfig(PreTrainedConfig):
    r"""
    thinker_config (`dict`, *optional*):
        Configuration of the underlying thinker sub-model.
    talker_config (`dict`, *optional*):
        Configuration of the underlying talker sub-model.
    token2wav_config (`dict`, *optional*):
        Configuration of the underlying codec sub-model.
    enable_audio_output (`bool`, *optional*, defaults to `True`):
        Whether to enable audio output and load the talker and token2wav modules.

    Example:

    ```python
    >>> from transformers import (
    ...     Qwen2_5OmniThinkerConfig,
    ...     Qwen2_5OmniTalkerConfig,
    ...     Qwen2_5OmniToken2WavConfig,
    ...     Qwen2_5OmniForConditionalGeneration,
    ...     Qwen2_5OmniConfig,
    ... )

    >>> # Initializing the sub-module configurations.
    >>> thinker_config = Qwen2_5OmniThinkerConfig()
    >>> talker_config = Qwen2_5OmniTalkerConfig()
    >>> token2wav_config = Qwen2_5OmniToken2WavConfig()


    >>> # Initializing a Qwen2_5Omni configuration
    >>> configuration = Qwen2_5OmniConfig(
    ...     thinker_config, talker_config, token2wav_config
    ... )

    >>> # Initializing a model (with random weights)
    >>> model = Qwen2_5OmniForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "qwen2_5_omni"
    sub_configs = {
        "thinker_config": Qwen2_5OmniThinkerConfig,
        "talker_config": Qwen2_5OmniTalkerConfig,
        "token2wav_config": Qwen2_5OmniToken2WavConfig,
    }

    thinker_config: dict | PreTrainedConfig | None = None
    talker_config: dict | PreTrainedConfig | None = None
    token2wav_config: dict | PreTrainedConfig | None = None
    enable_audio_output: bool = True

    def __post_init__(self, **kwargs):
        if self.thinker_config is None:
            self.thinker_config = Qwen2_5OmniThinkerConfig()
            logger.info("thinker_config is None. Initializing thinker model with default values")
        elif isinstance(self.thinker_config, dict):
            self.thinker_config = Qwen2_5OmniThinkerConfig(**self.thinker_config)

        if self.talker_config is None:
            self.talker_config = Qwen2_5OmniTalkerConfig()
            logger.info("talker_config is None. Initializing talker model with default values")
        elif isinstance(self.talker_config, dict):
            self.talker_config = Qwen2_5OmniTalkerConfig(**self.talker_config)

        if self.token2wav_config is None:
            self.token2wav_config = Qwen2_5OmniToken2WavConfig()
            logger.info("token2wav_config is None. Initializing token2wav model with default values")
        elif isinstance(self.token2wav_config, dict):
            self.token2wav_config = Qwen2_5OmniToken2WavConfig(**self.token2wav_config)

        super().__post_init__(**kwargs)

    def get_text_config(self, *args, **kwargs):
        """
        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
        itself. On specific composite models, it is found under a set of valid names.

        Args:
            decoder (`Optional[bool]`, *optional*, defaults to `False`):
                If set to `True`, then only search for decoder config names.
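
        Example (a minimal sketch; with the default sub-configs this delegates to the thinker's text config):

        ```python
        >>> from transformers import Qwen2_5OmniConfig

        >>> config = Qwen2_5OmniConfig()
        >>> text_config = config.get_text_config()
        ```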
        """
        # Overridden for deeply nested configs like Qwen2-Omni. We don't have any omni model
        # except for Qwen yet, so this will have to be generalized if more deeply nested configs
        # are added. NOTE: this method is currently used only by vLLM.
        return self.thinker_config.get_text_config(*args, **kwargs)


__all__ = [
    "Qwen2_5OmniConfig",
    "Qwen2_5OmniThinkerConfig",
    "Qwen2_5OmniTalkerConfig",
    "Qwen2_5OmniToken2WavConfig",
    "Qwen2_5OmniAudioEncoderConfig",
    "Qwen2_5OmniBigVGANConfig",
    "Qwen2_5OmniDiTConfig",
    "Qwen2_5OmniTextConfig",
    "Qwen2_5OmniVisionEncoderConfig",
]
