
    Uh                        d dl mZmZmZmZ d dlZd dlmZ ddlm	Z	 ddl
mZmZmZ ddlmZ ddlmZ ddlmZmZ dd	lmZ dd
lmZ ddlmZmZmZmZmZ ddlm Z  ddl!m"Z"m#Z# ddl$m%Z% ddl&m'Z'm(Z(m)Z) ddl*m+Z+m,Z,m-Z- ddl.m/Z/m0Z0m1Z1 ddl2m3Z3m4Z4  e)jj                  e6      Z7 G d de      Z8 G d dejr                        Z: G d dejr                        Z; G d de+      Z< G d de,      Z= G d de/      Z> G d  d!e      Z?e' G d" d#e#             Z@ G d$ d%e@      ZA G d& d'e0      ZB G d( d)e3      ZC e'd*+       G d, d-e@e             ZDg d.ZEy)/    )CallableOptionalTupleUnionN   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)PretrainedConfig)GenerationMixin)_prepare_4d_attention_mask#_prepare_4d_attention_mask_for_sdpa)FlashAttentionKwargs)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPast)BaseModelOutputWithPastAndCrossAttentionsSeq2SeqLMOutputSeq2SeqModelOutput)rope_config_validation)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)auto_docstringcan_return_tuplelogging   )GlmAttentionGlmRotaryEmbeddingapply_rotary_pos_emb)LlamaDecoderLayer
LlamaModeleager_attention_forward)WhisperModelshift_tokens_rightc                   j     e Zd ZdZdZdgZddddZ	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 d	 fd	Z xZS )
MoonshineConfiga"  
    This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Moonshine
    [UsefulSensors/moonshine-tiny](https://huggingface.co/UsefulSensors/moonshine-tiny).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32768):
            Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MoonshineModel`].
        hidden_size (`int`, *optional*, defaults to 288):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 1152):
            Dimension of the MLP representations.
        encoder_num_hidden_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        decoder_num_hidden_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer decoder.
        encoder_num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        encoder_num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `encoder_num_key_value_heads=encoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `encoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        decoder_num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `decoder_num_key_value_heads=decoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `decoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `decoder_num_attention_heads`.
        pad_head_dim_to_multiple_of (`int`, *optional*):
            Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
            optimized attention implementations.
        encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder.
        decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        decoder_start_token_id (`int`, *optional*, defaults to 1):
            Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
            are provided to the `generate` function. It is used to guide the model`s generation process depending on
            the task.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        partial_rotary_factor (`float`, *optional*, defaults to 0.9):
            Percentage of the query and keys which will have rotary embedding.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Whether the model is used as an encoder/decoder or not.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        bos_token_id (`int`, *optional*, defaults to 1):
            Denotes beginning of sequences token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            Denotes end of sequences token id.

    Example:

    ```python
    >>> from transformers import MoonshineModel, MoonshineConfig

    >>> # Initializing a Moonshine style configuration
    >>> configuration = MoonshineConfig().from_pretrained("UsefulSensors/moonshine-tiny")

    >>> # Initializing a model from the configuration
    >>> model = MoonshineModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```	moonshinepast_key_valuesencoder_num_key_value_headsencoder_num_attention_headsencoder_num_hidden_layers)num_key_value_headsnum_attention_headsnum_hidden_layersc                    || _         || _        || _        || _        || _        || _        || _        ||}|| _        |	|}	|	| _        |
| _	        || _
        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        t-        |        t/        | `  d||||d| y )N)bos_token_ideos_token_idis_encoder_decoderdecoder_start_token_id )
vocab_sizehidden_sizeintermediate_sizer-   decoder_num_hidden_layersr,   decoder_num_attention_headsr+   decoder_num_key_value_headspad_head_dim_to_multiple_ofencoder_hidden_actdecoder_hidden_actmax_position_embeddingsinitializer_ranger5   	use_cache
rope_thetarope_scalingpartial_rotary_factorr4   attention_biasattention_dropoutr   super__init__)selfr7   r8   r9   r-   r:   r,   r;   r+   r<   r=   r>   r?   r@   rA   r5   rB   rC   rD   rE   r4   rF   rG   r2   r3   kwargs	__class__s                             /var/www/catia.catastroantioquia-mas.com/valormas/lib/python3.12/site-packages/transformers/models/moonshine/modular_moonshine.pyrI   zMoonshineConfig.__init__   s    8 %&!2)B&)B&+F(+F(&.*E'+F(&.*E'+F(+F("4"4'>$!2&<#"$(%:""4,!2 	t$ 	
%%1#9		

 	
    )i   i   i     rO      rP   NNNgelusilui   g{Gz?   Tg     @Ng?TF        rS   r   )	__name__
__module____qualname____doc__
model_typekeys_to_ignore_at_inferenceattribute_maprI   __classcell__rL   s   @rM   r(   r(   .   s    {z J#4"5<<8M "#"#$%$%$($($(!! # !3D
 D
rN   r(   c                   V     e Zd Z fdZdej
                  dej
                  fdZ xZS )MoonshineEncoderMLPc                    t         |           || _        t        |   | _        t        j                  |j                  |j                        | _	        t        j                  |j                  |j                        | _
        y NrH   rI   configr   activation_fnnnLinearr8   r9   fc1fc2rJ   rc   
hidden_actrL   s      rM   rI   zMoonshineEncoderMLP.__init__   s^    #J/99V//1I1IJ99V55v7I7IJrN   hidden_statesreturnc                 l    | j                  |      }| j                  |      }| j                  |      }|S ra   )rg   rd   rh   )rJ   rk   s     rM   forwardzMoonshineEncoderMLP.forward  s4    /**=9/rN   rU   rV   rW   rI   torchTensorrn   r\   r]   s   @rM   r_   r_      s$    KU\\ ell rN   r_   c                   V     e Zd Z fdZdej
                  dej
                  fdZ xZS )MoonshineDecoderMLPc                    t         |           || _        t        |   | _        t        j                  |j                  |j                  dz        | _	        t        j                  |j                  |j                        | _
        y )Nr   rb   ri   s      rM   rI   zMoonshineDecoderMLP.__init__  sc    #J/99V//1I1IA1MN99V55v7I7IJrN   rk   rl   c                     | j                  |      }|j                  dd      \  }}| j                  |      |z  }| j                  |      }|S )Nr   )dim)rg   chunkrd   rh   )rJ   rk   gates      rM   rn   zMoonshineDecoderMLP.forward  sS    /+11!1<t**40=@/rN   ro   r]   s   @rM   rs   rs   
  s$    KU\\ ell rN   rs   c                   h    e Zd Zdededededef
 fdZ	 	 	 	 	 ddej                  de	e
ej                  ej                  f      d	e	ej                     d
e	e   de	ej                     de	ej                     dee   de
ej                  e	ej                     e	e
ej                        f   fdZ xZS )MoonshineAttentionrc   	layer_idx	is_causalr/   r.   c                 n   |j                  ||d       t        | 	  ||       || _        t	        |d|j
                  |j                  z        | _        | j                  j                  C| j                  j                  }|| j                  |z   dz
  |z  z  }|| j                  z
  | _
        y d| _
        y )N)r/   r.   head_dimrS   r   )updaterH   rI   r}   getattrr8   r/   r   rc   r=   head_dim_padding)	rJ   rc   r|   r}   r/   r.   target_multipletarget_head_dimrL   s	           rM   rI   zMoonshineAttention.__init__  s     	.AZmno+"
F4F4F&JdJd4de ;;22>"kkEEO-$--/2QTU2UZi1ijO$3dmm$CD!$%D!rN   rk   position_embeddingsattention_maskpast_key_valuecache_positionkey_value_statesrK   rl   c                    |j                   d d \  }}	| j                  |      j                  ||	| j                  j                  | j
                        j                  dd      }
|d u}|Y|j                  j                  | j                        }|r&d|j                  | j                  <   |j                  }n|j                  }||n|}|r7|r5r3|j                  | j                     }|j                  | j                     }n| j                  |      j                  |d| j                  j                  | j
                        j                  dd      }| j                  |      j                  |d| j                  j                  | j
                        j                  dd      }|r%|#|j!                  ||| j                  d|i      \  }}|s?|\  }}t#        |
|||      \  }
}|'|||d}|j!                  ||| j                  |      \  }}t$        }| j                  j&                  dk7  r^| j                  j&                  dk(  r(|j                  d	d
      rt(        j+                  d       nt,        | j                  j&                     }| j.                  r	||	dkD  rdnd
}| j0                  dkD  rt2        j4                  j6                  j9                  |
d| j0                  f      }
t2        j4                  j6                  j9                  |d| j0                  f      }t2        j4                  j6                  j9                  |d| j0                  f      } || |
|||f| j:                  sdn| j<                  | j>                  |d|\  }}| j0                  dkD  r|dd | j0                   f   }|jA                  ||	d      jC                         }| jE                  |      }||fS )Nrv   rS   r   Tr   )sincosr   eagersdpaoutput_attentionsFz`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.r   rT   )dropoutscalingr}   .)#shapeq_projviewrc   r.   r   	transpose
is_updatedgetr|   cross_attention_cacheself_attention_cache	key_cachevalue_cachek_projv_projr   r!   r$   _attn_implementationloggerwarning_oncer   r}   r   rp   re   
functionalpadtrainingrG   r   reshape
contiguouso_proj)rJ   rk   r   r   r   r   r   rK   bszq_lenquery_statesis_cross_attentionr   current_states
key_statesvalue_statesr   r   cache_kwargsattention_interfacer}   attn_outputattn_weightss                          rM   rn   zMoonshineAttention.forward0  s    #(("-
U KK&++C8W8WY]YfYfgqqrsuvw 	 .T9%'2266t~~FJ!<@))$..9!/!E!E!/!D!D .>-I)}.Z'11$..AJ)55dnnEL N+c2t{{>>N1a  N+c2t{{>>N1a 
 "n&@+9+@+@dnn?OQ_>`,(
L "*HC';L*VY[^'_$L*)'*3.Y+9+@+@dnnl,(
L )@;;++w6{{//69fjjI\^c>d##L
 '>dkk>^>^&_# NN~/E%RS)DY^	  1$ 88..22<!TEZEZA[\L,,00aAVAV=WXJ 88..22<!TEZEZA[\L$7
%
  $}}C$2H2HLL
%
 
%
!\   1$%c+Cd.C.C-C+C&CDK!))#ub9DDFkk+.L((rN   )NNNNN)rU   rV   rW   r(   intboolrI   rp   rq   r   r   r	   
LongTensorr   r   rn   r\   r]   s   @rM   r{   r{     s   && & 	&
 !& !&0 LP15*.5937[)||[) &eELL%,,,F&GH[) !.	[)
 ![) !!1!12[) #5<<0[) -.[) 
u||Xell3XeELL>Q5RR	S[)rN   r{   c                       e Zd Zy)MoonshineRotaryEmbeddingN)rU   rV   rW   r6   rN   rM   r   r     s    rN   r   c                   (     e Zd Zdedef fdZ xZS )MoonshineEncoderLayerrc   r|   c                 F   t         |   ||       t        ||d|j                  |j                        | _        t        ||j                        | _        t        j                  |j                  d      | _        t        j                  |j                  d      | _        y )NFrc   r|   r}   r/   r.   bias)rH   rI   r{   r,   r+   	self_attnr_   r>   mlpre   	LayerNormr8   input_layernormpost_attention_layernormrJ   rc   r|   rL   s      rM   rI   zMoonshineEncoderLayer.__init__  s    ++ & B B & B B
 'vv/H/HI!||F,>,>UK(*V5G5Ge(T%rN   )rU   rV   rW   r(   r   rI   r\   r]   s   @rM   r   r     s    U U3 U UrN   r   c                        e Zd Zddedee   f fdZ	 	 	 	 	 	 	 	 	 	 	 ddej                  deej                     deej                     deej                     deej                     d	eej                     d
ee
   dee   dee   deej                     deeej                  ej                  f      deeej                  ej                  f      deej                  eeej                  ej                  f      f   fdZ xZS )MoonshineDecoderLayerrc   r|   c                    t         |           |j                  | _        t        ||d|j                  |j
                        | _        t        ||d|j                  |j
                        | _        t        ||j                        | _
        t        j                  |j                  d      | _        t        j                  |j                  d      | _        t        j                  |j                  d      | _        y )NTr   Fr   )rH   rI   r8   r{   r;   r<   r   encoder_attnrs   r?   r   re   r   r   r   final_layernormr   s      rM   rI   zMoonshineDecoderLayer.__init__  s    !--+ & B B & B B
 / & B B & B B
 'vv/H/HI!||F,>,>UK(*V5G5Ge(T%!||F,>,>UKrN   rk   r   encoder_hidden_statesencoder_attention_maskposition_idsencoder_position_idsr   r   rB   r   r   encoder_position_embeddingsrl   c                 H   |}| j                  |      } | j                  d||||||	|
|d|\  }}||z   }d }|2|}| j                  |      }| j                  ||||||	      \  }}||z   }|}| j	                  |      }| j                  |      }||z   }|f}|r|||fz  }|S )N)rk   r   r   r   r   rB   r   r   )rk   r   r   r   r   rB   r6   )r   r   r   r   r   r   )rJ   rk   r   r   r   r   r   r   r   rB   r   r   r   rK   residualself_attn_weightscross_attn_weightsoutputss                     rM   rn   zMoonshineDecoderLayer.forward  s     !,,]; ,:4>> 
,
')%)/) 3
,
 
,
(( !=0 " ,$H 99-HM040A0A+!65-"3# 1B 1-M- %}4M !,,];/ =0 ")+=>>GrN   ra   )NNNNNNFFNNN)rU   rV   rW   r(   r   r   rI   rp   rq   r   r	   r   r   FloatTensorrn   r\   r]   s   @rM   r   r     sj   L L8C= L6 268<9=37;?*.,1$)59KOSW<||< !.<  (5	<
 !) 6< u//0< 'u'7'78< !< $D>< D>< !!1!12< &eELL%,,,F&GH< &.eELL%,,4N.O%P< 
u  (51B1BEDUDU1U+V"WW	X<rN   r   c                   Z    e Zd ZeZdZdZdZddgZdZ	dZ
dZdZd Zdej                  fdZy	)
MoonshinePreTrainedModelmodelinput_valuesTr   r   c                 8   | j                   j                  }t        |t        j                  t        j
                  f      rY|j                  j                  j                  d|       |j                  %|j                  j                  j                          y y t        |t        j                  t        j                  f      rW|j                  j                  j                  d       |j                  %|j                  j                  j                          y y t        |t        j                        rf|j                  j                  j                  d|       |j                  2|j                  j                  |j                     j                          y y y )NrT   )meanstdg      ?)rc   rA   
isinstancere   rf   Conv1dweightdatanormal_r   zero_	GroupNormr   fill_	Embeddingpadding_idx)rJ   moduler   s      rM   _init_weightsz&MoonshinePreTrainedModel._init_weights  s    kk++fryy"))45MM&&CS&9{{&  &&( 'r|| <=MM$$S){{&  &&( '-MM&&CS&9!!-""6#5#56<<> . .rN   input_lengthsc                 ~    t        |dz
  dz  dz         }t        |dz
  dz  dz         }t        |dz
  dz  dz         }|S )zH
        Computes the output length of the convolutional layers
           @   rS      r   r   )r   )rJ   r   output_conv1_lengthoutput_conv2_lengthoutput_conv3_lengths        rM    _get_feat_extract_output_lengthsz9MoonshinePreTrainedModel._get_feat_extract_output_lengths  sZ     "=3#6""<q"@A!#6#:a"?!"CD!#6#:a"?!"CD""rN   N)rU   rV   rW   r(   config_classbase_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modules_supports_flash_attn_2_supports_sdpa_supports_cache_class_supports_static_cacher   rp   r   r   r6   rN   rM   r   r     sR    "L$O&*#02IJ!N !?#e>N>N #rN   r   c                        e Zd ZdZdZdef fdZdej                  fdZ	dej                  fdZ
e	 	 	 	 ddeej                     d	eej                     d
ee   dee   dee   defd       Z xZS )MoonshineEncoderz
    Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineEncoderLayer`]

    Args:
        config: MoonshineConfig
    r   rc   c           	      b   t         |   |       || _        |j                  }t	        j
                  d|ddd      | _        t	        j
                  |d|z  dd	      | _        t	        j
                  d|z  |dd	      | _        t	        j                  d|d
      | _
        t        |      | _        t	        j                  t        |j                        D cg c]  }t!        ||       c}      | _        t	        j$                  |d      | _        d| _        | j+                          y c c}w )NrS   r   r   F)kernel_sizestrider   r   r   r   )r   r   gh㈵>)
num_groupsnum_channelseps)rc   r   )rH   rI   rc   r8   re   r   conv1conv2conv3r   	groupnormr   
rotary_emb
ModuleListranger-   r   layersr   
layer_normgradient_checkpointing	post_init)rJ   rc   	embed_dimidxrL   s       rM   rI   zMoonshineEncoder.__init__+  s     &&	YYq)ReT
YYy!i-QqQ
YYq9}iQqQ
PTU2&Amm;@AaAa;bcC"63/c
 ,,yu=&+# ds   D,rl   c                     | j                   S ra   r   rJ   s    rM   get_input_embeddingsz%MoonshineEncoder.get_input_embeddings?  s    zzrN   valuec                     || _         y ra   r  )rJ   r  s     rM   set_input_embeddingsz%MoonshineEncoder.set_input_embeddingsB  s	    
rN   r   r   output_hidden_statesflash_attn_kwargsc           	         ||n| j                   j                  }||n| j                   j                  }|t        d      |j	                  d      }t
        j                  j                  | j                  |            }| j                  |      }t
        j                  j                  | j                  |            }t
        j                  j                  | j                  |            }|j                  ddd      }|| j                  |j                  d         }d}|ddd|f   dd|f   }| j                   j                   d	k(  r|d
k(  j#                         r|nd}nH| j                   j                   dk(  r|st%        ||j&                        }nt)        ||j&                        }t+        j,                  d|j                  d   |j.                        j	                  d      }	| j1                  ||	      }
|rdnd}|rdnd}| j2                  D ])  }|r||fz  } ||f||	||
d|}|d   }|s!||d   fz  }+ | j5                  |      }|r||fz  }t7        |||      S )a  
        Args:
            input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
                Float values of the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_values`, the [`AutoFeatureExtractor`] should be used for padding
                and conversion into a tensor of type `torch.FloatTensor`.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
                tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
                more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        NzYou must specify input_values.rS   r   r   rv     .flash_attention_2rT   r   devicer6   )r   r   r   r   last_hidden_staterk   
attentions)rc   r   r  
ValueError	unsqueezere   r   tanhr   r  rQ   r  r  permuter   r   r   anyr   dtyper   rp   aranger  r  r  r  r   )rJ   r   r   r   r  r  rk   mask_lendownsample_strider   r   all_hidden_statesall_self_attnsencoder_layerlayer_outputss                  rM   rn   zMoonshineEncoder.forwardE  s   > 2C1N-TXT_T_TqTq$8$D $++JjJj 	 =>> $--a0**4::l+CD}5**4::m+DE**4::m+DE%--aA6 %<<^=Q=QRT=UVH *+C1D3D1D,DEc9H9nUN{{//3FF4Bc4I3N3N3PVZ 11V;DU!D^UbUhUh!i "<NML_L_!`||A}':':1'=mFZFZ[eefgh #oom\J #7BD0d![[ 	6M#!m%55!)-)"3$7 $M *!,M =#3"55!	6$ 6  -!11&++%
 	
rN   )NNNN)rU   rV   rW   rX   r   r(   rI   re   Moduler  r  r   r   rp   r   rq   r   r   r   r   rn   r\   r]   s   @rM   r   r   !  s     %O (bii "))   5915,0/3c
u001c
 !.c
 $D>	c

 'tnc
 $$89c
 
!c
 c
rN   r   c                   Z    e Zd ZdZdef fdZ	 	 	 	 	 	 	 	 	 	 	 ddeej                     deej                     deej                     dee
   deej                     dee   d	ee   d
ee   deej                     deej                     deej                     dee   deeef   fdZ xZS )MoonshineDecoder	input_idsrc   c           	         t         |   |       t        j                  |j                  d      | _        t        j                  t        |j                        D cg c]  }t        ||       c}      | _
        y c c}w NFr   )rH   rI   re   r   r8   normr  r  r:   r   r  )rJ   rc   r  rL   s      rM   rI   zMoonshineDecoder.__init__  s\     LL!3!3%@	mm;@AaAa;bcC"63/c
cs   A=r   r   r*   inputs_embedsrB   r   r  r   r   r   r  rl   c                 $   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|du |duz  rt	        d      | j
                  r%| j                  r|rt        j                  d       d}|| j                  |      }|r"| t               }t               }t        ||      }|	F||j                         nd}t        j                  |||j                  d   z   |j                         }	||	j#                  d      }| j%                  |||	||      }|}| j'                  ||      }|rdnd}|rdnd}|r|
dnd}||
j                  d	   }d
}|ddd|f   dd|f   }| j                   j(                  dk(  r|dk(  j+                         r|nd}nd| j                   j(                  dk(  r'|s%t-        ||j.                  |j                  d	         }n$t1        ||j.                  |j                  d	         }| j2                  D ]:  }|r||fz  } ||f|||
|||||	|d	|}|d   }|s&||d   fz  }|
2||d   fz  }< | j5                  |      }|r||fz  }t7        ||r|nd|||      S )a  
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            of the decoder.
        encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        Nz:You must specify exactly one of input_ids or inputs_embedszX`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.Fr   rS   r  r6   r  .r  rT   r   )	r   r   r   r   r   r   rB   r   r   r   )r  r*   rk   r  cross_attentions)rc   r   r  rB   r  r	  r   r   r   embed_tokensr
   r   get_seq_lengthrp   r$  r   r  r  _update_causal_maskr  r   r"  r   r#  r   r  r1  r   )rJ   r.  r   r   r*   r2  rB   r   r  r   r   r   r  r   r   past_seen_tokenscausal_maskrk   r   r'  r(  all_cross_attentionsr%  r&  decoder_layerr*  s                             rM   rn   zMoonshineDecoder.forward  sF   2 2C1N-TXT_T_TqTq$8$D $++JjJj 	 "+!6IDKK<Q<Q	-t";<YZZ&&4==Yj I  --i8M0#/> $0N!12FH]^O!CRC^==?de"\\ "2]5H5H5K"KTaThThN )33A6L..M>?L]
 & #oom\J #7BD0d&7<Q<]rdh "-,2226H *%;CATCTAT<T%UVY[d\d[dVd%e"{{//3FFDZ^aDaCfCfCh)?nr& 11V;DU)L*M,?,?ATATUWAX*&
 *D*M,?,?ATATUWAX*& "[[ 	@M#!m%55!)*'=&;)."3#-$7 $M *!,M =#3"55(4(]1-=,??(1	@4 		-0  -!118+/8Od+%1
 	
rN   )NNNNNNNNNNN)rU   rV   rW   r   r(   rI   r   rp   r   rq   r	   r   r   r   r   r   r   r   rn   r\   r]   s   @rM   r-  r-    s4   !O
 
 151537+/59$(,0/359=A9=A
E,,-A
 !.A
 u//0	A

 "%A
   1 12A
 D>A
 $D>A
 'tnA
 !!1!12A
  ((9(9:A
 !) 6A
 $$89A
 
u--	.A
rN   r-  c                      e Zd Zee	 	 	 	 	 	 	 	 	 	 	 	 ddeej                     deej                     deej                     deej                     dee	e	ej                           dee
ee	ej                     f      dee	ej                        d	ee	ej                        d
ee   dee   dee   deej                     defd              Zy)MoonshineModelNr   r   decoder_input_idsdecoder_attention_maskencoder_outputsr*   decoder_inputs_embedsdecoder_position_idsrB   r   r  r   rl   c                 n   |
|
n| j                   j                  }
||n| j                   j                  }|	|	n| j                   j                  }	|| j	                  |||
|      }nGt        |t              s7t        |d   t        |      dkD  r|d   ndt        |      dkD  r|d   nd      }| j                  ||||j                  ||||	|
||      }t        |j                  |j                  |j                  |j                  |j                  |j                  |j                  |j                        S )	a\  
        input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
            Float values of the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
            `input_values`, the [`AutoFeatureExtractor`] should be used for padding
            and conversion into a tensor of type `torch.FloatTensor`.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, MoonshineModel
        >>> from datasets import load_dataset

        >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_values = inputs.input_values
        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
        >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
        >>> list(last_hidden_state.shape)
        [1, 2, 288]
        ```
        N)r   r   r  r   rS   r   r  )r.  r   r   r   r*   r2  r   rB   r   r  r   )r  r*   decoder_hidden_statesdecoder_attentionsr5  encoder_last_hidden_stater   encoder_attentions)rc   r   r  rB   encoderr   r   lendecoderr  r   r*   rk   r  r5  )rJ   r   r   r?  r@  rA  r*   rB  rC  rB   r   r  r   decoder_outputss                 rM   rn   zMoonshineModel.forward;  s[   ` 2C1N-TXT_T_TqTq$8$D $++JjJj 	 "+!6IDKK<Q<Q	"/3||-"3%9	 0< 0O O_=-"1!"4474H14Loa0RV14_1E1I?1-tO FJ\\'1#1"1"C"C+/-/!5) FR F
 "-??+;;"1"?"?.99,==&5&G&G"1"?"?.99	
 		
rN   )NNNNNNNNNNNN)rU   rV   rW   r   r   r   rp   r   r   r   r   r   r   r   rn   r6   rN   rM   r>  r>  :  sd    59598<=AEIZ^DHBF$(,0/359{
u001{
 !!1!12{
 $E$4$45	{

 !))9)9 :{
 "%e.?.?(@"AB{
 "%(;U5CTCT=U(U"VW{
  (e.?.?(@A{
 'uU-=-='>?{
 D>{
 $D>{
 'tn{
 !!1!12{
 
{
  {
rN   r>  zj
    The Moonshine Model with a language modeling head. Can be used for automatic speech recognition.
    )custom_introc                       e Zd ZdgZdef fdZd Zd Zd Zd Z	de
j                  fd	Zee	 	 	 	 	 	 	 	 	 	 	 	 	 dd
eej"                     deej$                     deej$                     deej$                     deeeej"                           deeeeej"                     f      deeej"                        deeej$                        dee   dee   dee   deej$                     deej$                     defd              Z xZS )!MoonshineForConditionalGenerationzproj_out.weightrc   c                     t         |   |       t        |      | _        t	        j
                  |j                  |j                  d      | _        | j                          y r0  )
rH   rI   r>  r   re   rf   r8   r7   proj_outr
  )rJ   rc   rL   s     rM   rI   z*MoonshineForConditionalGeneration.__init__  sH     #F+
		&"4"4f6G6GeT 	rN   c                 6    | j                   j                         S ra   )r   get_encoderr  s    rM   rS  z-MoonshineForConditionalGeneration.get_encoder      zz%%''rN   c                 6    | j                   j                         S ra   )r   get_decoderr  s    rM   rV  z-MoonshineForConditionalGeneration.get_decoder  rT  rN   c                     | j                   S ra   rQ  r  s    rM   get_output_embeddingsz7MoonshineForConditionalGeneration.get_output_embeddings  s    }}rN   c                     || _         y ra   rX  )rJ   new_embeddingss     rM   set_output_embeddingsz7MoonshineForConditionalGeneration.set_output_embeddings  s	    &rN   rl   c                 6    | j                   j                         S ra   )r   r  r  s    rM   r  z6MoonshineForConditionalGeneration.get_input_embeddings  s    zz..00rN   r   r   r?  r@  rA  r*   rB  rC  rB   r   r  r   labelsc                    |9|7|5t        || j                  j                  | j                  j                        }| j	                  |||||||||	|
||      }| j                  |j                        }d}|(| j                  ||| j                  j                        }t        |||j                  |j                  |j                  |j                  |j                  |j                  |j                   	      S )a  
        input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
            Float values of the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
            `input_values`, the [`AutoFeatureExtractor`] should be used for padding
            and conversion into a tensor of type `torch.FloatTensor`.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
        >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_values = inputs.input_values

        >>> generated_ids = model.generate(input_values, max_new_tokens=100)

        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
        ```N)r   r?  rA  r@  r*   rB  rC  rB   r   r  r   )logitsr^  r7   )	lossr`  r*   rE  rF  r5  rG  r   rH  )r&   rc   pad_token_idr5   r   rQ  r  loss_functionr7   r   r*   rE  rF  r5  rG  r   rH  )rJ   r   r   r?  r@  rA  r*   rB  rC  rB   r   r  r   r^  r   r`  ra  s                    rM   rn   z)MoonshineForConditionalGeneration.forward  s
   r  (-B-J$6DKK44dkk6X6X%! '+jj)/+#9+"7!5/!5) '1 '
 w889%%VFt{{OeOe%fD#33")"?"?&99$55&-&G&G")"?"?&99

 
	
rN   )NNNNNNNNNNNNN)rU   rV   rW   _tied_weights_keysr(   rI   rS  rV  rY  r\  re   r+  r  r   r   r   rp   r   r   r   r   r   r   r   rn   r\   r]   s   @rM   rO  rO    s    ,, (('1bii 1  59598<=AEIZ^DHBF$(,0/359-1{
u001{
 !!1!12{
 $E$4$45	{

 !))9)9 :{
 "%e.?.?(@"AB{
 "%(;U5CTCT=U(U"VW{
  (e.?.?(@A{
 'uU-=-='>?{
 D>{
 $D>{
 'tn{
 !!1!12{
 ))*{
 
{
  {
rN   rO  )r(   r>  r   rO  )Ftypingr   r   r   r   rp   torch.nnre   activationsr   cache_utilsr	   r
   r   configuration_utilsr   
generationr   modeling_attn_mask_utilsr   r   modeling_flash_attention_utilsr   modeling_layersr   modeling_outputsr   r   r   r   r   modeling_rope_utilsr   modeling_utilsr   r   processing_utilsr   utilsr   r   r   glm.modeling_glmr   r    r!   llama.modeling_llamar"   r#   r$   whisper.modeling_whisperr%   r&   
get_loggerrU   r   r(   r+  r_   rs   r{   r   r   r   r   r   r-  r>  rO  __all__r6   rN   rM   <module>rx     sQ   4 3   ! C C 3 ) g B 9  : F & > > U U Y Y G 
		H	%J
& J
Z")) "))  q) q)h	1 	U- U"U6 Up "# "# "#JH
/ H
VK
z K
\~
\ ~
B 
W
(@/ W

W
trN   