
    Uh                    j   d dl Z d dlmZ d dlmZmZmZmZmZ d dl	Z	d dl
mc mZ d dlZ	d dl	mZ d dlmZmZmZ ddlmZ ddlmZ dd	lmZ dd
lmZmZmZmZmZ ddlm Z  ddl!m"Z"m#Z#m$Z$ ddl%m&Z& ddl'm(Z(m)Z)  e#       rd dl*m+Z+ d dl,m-Z- d dl.m/Z/ ne0Z- e$jb                  e2      Z3 G d de      Z4	 	 dPde	jj                  de	jj                  dee	jj                     dee	jj                     dee	jj                  e	jj                  e	jj                  e6ee	jj                     ee	jj                     f   f
dZ7de	jj                  de	jj                  de6de6de	jj                  f
dZ8 G d d e	jr                  jt                        Z;	 	 dPd!ee	jj                     d"ee6   fd#Z< G d$ d%e-      Z= G d& d'ej|                        Z? G d( d)ej|                        Z@ G d* d+e(      ZA	 dQd,d-d.e	jj                  de	jj                  d/e	jj                  dee	j                     d0ee6e6f   d1e6d2e6d3eeC   deee	jj                  e	jj                  f   ee	jj                     f   fd4ZDe	j                  fd,d-d.e	jj                  d5e=d!e	jj                  d"e6d0ee6e6f   d1e6d2e6d6e	j                  dee	jj                     fd7ZGd,d-d.e	jj                  de	jj                  d/e	jj                  dee	j                     d0ee6e6f   d1e6d2e6dee	jj                     fd8ZHeGeDeHd9ZI G d: d-ej|                        ZJ G d; d<ej|                        ZKe" G d= d>e              ZLe" G d? d@eL             ZM G dA dBej|                        ZN e"dCD       G dE dFeL             ZO e"dGD       G dH dIeL             ZP e"dJD       G dK dLeL             ZQe" G dM dNeL             ZRg dOZSy)R    N)nullcontext)DictLiteralOptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)PretrainedConfig)_prepare_4d_attention_mask)BaseModelOutputMaskedLMOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)auto_docstringis_flash_attn_2_availablelogging)is_triton_available   )GemmaRotaryEmbeddingapply_rotary_pos_emb) flash_attn_varlen_qkvpacked_func)RotaryEmbedding)apply_rotaryc                        e Zd ZdZdZdgZ	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 dded   f fdZ fdZ xZ	S )	ModernBertConfiga  
    This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ModernBERT-base.
    e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50368):
            Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ModernBertModel`]
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 1152):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 22):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer decoder.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
            if not specified.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
            The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        norm_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the normalization layers.
        pad_token_id (`int`, *optional*, defaults to 50283):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 50282):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 50281):
            Beginning of stream token id.
        cls_token_id (`int`, *optional*, defaults to 50281):
            Classification token id.
        sep_token_id (`int`, *optional*, defaults to 50282):
            Separation token id.
        global_rope_theta (`float`, *optional*, defaults to 160000.0):
            The base period of the global RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        global_attn_every_n_layers (`int`, *optional*, defaults to 3):
            The number of layers between global attention layers.
        local_attention (`int`, *optional*, defaults to 128):
            The window size for local attention.
        local_rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the local RoPE embeddings.
        embedding_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the embeddings.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the MLP layers.
        mlp_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the MLP layers.
        decoder_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the decoder layers.
        classifier_pooling (`str`, *optional*, defaults to `"cls"`):
            The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
            CLS token doesn't attend to all tokens on long sequences.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        classifier_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the classifier.
        classifier_activation (`str`, *optional*, defaults to `"gelu"`):
            The activation function for the classifier.
        deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
            Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
        sparse_prediction (`bool`, *optional*, defaults to `False`):
            Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
        sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
            The index to ignore for the sparse prediction.
        reference_compile (`bool`, *optional*):
            Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
            the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
            shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
            be faster in some scenarios.
        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
            When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.

    Examples:

    ```python
    >>> from transformers import ModernBertModel, ModernBertConfig

    >>> # Initializing a ModernBert style configuration
    >>> configuration = ModernBertConfig()

    >>> # Initializing a model from the modernbert-base style configuration
    >>> model = ModernBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
modernbertpast_key_valuesclassifier_poolingclsmeanc$           	      "   t        %|   d|||||d|$ || _        || _        || _        || _        || _        || _        || _        |	| _	        |
| _
        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        | | _        |!| _        |"| _        |#| _        | j.                  dvrtA        d| j.                   d      y )N)pad_token_idbos_token_ideos_token_idcls_token_idsep_token_idr&   zQInvalid value for `classifier_pooling`, should be either "cls" or "mean", but is . )!super__init__
vocab_sizemax_position_embeddingshidden_sizeintermediate_sizenum_hidden_layersnum_attention_headsinitializer_rangeinitializer_cutoff_factornorm_eps	norm_biasglobal_rope_thetaattention_biasattention_dropouthidden_activationglobal_attn_every_n_layerslocal_attentionlocal_rope_thetaembedding_dropoutmlp_biasmlp_dropoutdecoder_biasr%   classifier_dropoutclassifier_biasclassifier_activationdeterministic_flash_attnsparse_predictionsparse_pred_ignore_indexreference_compilerepad_logits_with_grad
ValueError)&selfr3   r5   r6   r7   r8   r@   r4   r9   r:   r;   r<   r*   r,   r+   r-   r.   r=   r>   r?   rA   rB   rC   rD   rE   rF   rG   r%   rH   rI   rJ   rK   rL   rM   rN   rO   kwargs	__class__s&                                        /var/www/catia.catastroantioquia-mas.com/valormas/lib/python3.12/site-packages/transformers/models/modernbert/modular_modernbert.pyr2   zModernBertConfig.__init__   sT   N 	 	
%%%%%	
 	
 %'>$&!2!2#6 !2)B& "!2,!2!2*D'. 0!2 &("4"4.%:"(@%!2(@%!2&<#""/9cdhd{d{c||}~  :    c                 H    t         |          }|j                  dd        |S )NrN   )r1   to_dictpop)rQ   outputrS   s     rT   rW   zModernBertConfig.to_dict   s#    "

&-rU   )#i  i   i        gelui    g{Gz?       @gh㈵>Fik  j  i  r_   r^   g     AF        r           @r`   Fr`   Tr'   r`   Fr\   FFiNF)
__name__
__module____qualname____doc__
model_typekeys_to_ignore_at_inferencer   r2   rW   __classcell__rS   s   @rT   r"   r"   5   s    eN J#4"5   $"%"#$ 5:$!&!%$IQ8 $M29Qf rU   r"   inputsattention_maskposition_idslabelsreturnc                    |j                  dt        j                        }t        j                  |j	                         d      j	                         }t        |j                         j                               }t        j                  j                  j                  t        j                  |dt        j                        d      }| j                         dk(  r| j	                         |   }n*| j                  ^}	}
}|	|
z  } | j                  |g| |   }||j	                         |   nd}||j	                         |   nd}||||||fS )	a  
    Remove padding from input sequences.

    Args:
        inputs: (batch, seqlen, ...) or (batch, seqlen)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        position_ids: (batch, seqlen), int, position ids
        labels: (batch, seqlen), int, labels

    Returns:
        unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz)
        cu_seqlens: (batch + 1), the cumulative sequence lengths
        max_seqlen_in_batch: int
        unpadded_position_ids: (total_nnz) or None
        unpadded_labels: (total_nnz) or None
    dimdtypeF)as_tupler   )   r   r   N)sumtorchint32nonzeroflattenintmaxitemr	   
functionalpadcumsumrs   shapeview)rk   rl   rm   rn   seqlens_in_batchindicesmax_seqlen_in_batch
cu_seqlensunpadded_inputsbatchseqlenrestr   unpadded_position_idsunpadded_labelss                  rT   _unpad_modernbert_inputr      s,   . &))b)DmmN224uEMMOG.22499;<$$((6FAUZU`U`)acijJzz|q ..*73%||v%&++e3d3G<?K?WL0027;]a393Efnn&w/4OGZ1DF[]lllrU   r   r   r   c                 l   | j                         dk(  rHt        j                  ||z  | j                  | j                        }| ||<   |j                  ||      }|S | j                  ^}}t        j                  ||z  g|| j                  | j                  d}| ||<    |j
                  ||g| }|S )aQ  
    Add padding to sequences.

    Args:
        inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz)
        batch: int, batch size
        seqlen: int, max sequence length

    Returns:
        padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
    rv   )rt   device)rs   rx   zerosrt   r   r   r   )rk   r   r   r   rY   padded_inputs_r   s           rT   _pad_modernbert_outputr   "  s    $ zz|qUV^6<<V wE62  <<DUV^]d]&,,v}}] w#E69D9rU   c                   \    e Zd Ze	 	 ddeej                     dee   fd       Zed        Z	y)ApplyRotaryEmbUnpadNr   
max_seqlenc           
          |j                         }|j                  \  }}}}	|d d d df   j                  |d|	      }
t        |
||d||dd       | j	                  |||       || _        |S )Nr   rq   r   FT)seqlen_offsetsr   r   interleavedinplace)
contiguousr   r   r    save_for_backwardr   )ctxqkvcossinr   r   	total_nnz_three_nheadsheaddimqks              rT   forwardzApplyRotaryEmbUnpad.forwardB  s     nn.1ii+	67G BQBZ__YG4!!		
 	c3
3#
rU   c                     | j                   \  }}}|j                         }|j                  \  }}}}|d d d df   j                  |d|      }	t	        |	||d|| j
                  ddd	       |d d d d d d fS )Nr   rq   r   FT)r   r   r   r   r   	conjugate)saved_tensorsr   r   r   r    r   )
r   dor   r   r   r   r   r   r   dqks
             rT   backwardzApplyRotaryEmbUnpad.backwarda  s    "00S*]]_.0hh+	67G BQBinnYG4!~~
	
 4tT455rU   NN)
rc   rd   re   staticmethodr   rx   Tensorr|   r   r   r0   rU   rT   r   r   A  sQ     .2$(
 U\\* SM < 6 6rU   r   r   r   c                 4    t         j                  | ||||      S )a  
    Arguments:
        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (total_nnz, dim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    )r   apply)r   r   r   r   r   s        rT   apply_rotary_unpaddedr   x  s    . $$S#sJ
KKrU   c                   "    e Zd ZdZ	 	 	 	 ddededee   deej                     deej                     f
 fdZ
	 ddej                  d	ej                  dee   d
eej                  eej                  ej                  f   f   fdZd
efdZ xZS )!ModernBertUnpaddedRotaryEmbeddingzP
    The rotary position embeddings applied directly to unpadded sequences.
    rs   baser   r   rt   c                 v    t         |   ||d|d       || _        |||| j                  |||       yyyy)a  
        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
            up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
            the cos_sin_cache will be recomputed during the forward pass.
        TF)rs   r   pos_idx_in_fp32r   r   Nr   rt   )r1   r2   r   _update_cos_sin_cache)rQ   rs   r   r   r   rt   rS   s         rT   r2   z*ModernBertUnpaddedRotaryEmbedding.__init__  sV     	StT&^cd$!f&8U=N&&z&&N >O&8!rU   r   r   ro   c                     |(| j                  ||j                  |j                         t        || j                  | j
                  ||      }|S )z
        Apply rotary embedding *inplace* to qkv.
        qkv: (total_nnz, 3, nheads, headdim)
        cu_seqlens: (batch + 1,) cumulative sequence lengths
        max_seqlen: int max seq length in the batch
        r   r   r   )r   r   rt   r   _cos_cached_sin_cached)rQ   r   r   r   s       rT   r   z)ModernBertUnpaddedRotaryEmbedding.forward  sS     !&&z#**CII&V#!!
 
rU   c                 T    d| j                    d| j                   d| j                   S )Nzdim=z, base=z, scale_base=)rs   r   
scale_baserQ   s    rT   
extra_reprz,ModernBertUnpaddedRotaryEmbedding.extra_repr  s(    dhhZwtyykt>OPPrU   )rb   NNNN)rc   rd   re   rf   r|   floatr   rx   r   rt   r2   r   r   r   r   strr   ri   rj   s   @rT   r   r     s     $()-'+OO O SM	O
 &O $O. %)	\\ LL SM	
 
u||U5<<#=>>	?2QC QrU   r   c                        e Zd ZdZdef fdZ ej                  d      dej                  dej                  fd       Z
	 ddeej                     d	eej                     dej                  fd
Z xZS )ModernBertEmbeddingszV
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    configc                 d   t         |           || _        t        j                  |j
                  |j                  |j                        | _        t        j                  |j                  |j                  |j                        | _        t        j                  |j                        | _        y )N)padding_idxepsbias)r1   r2   r   r	   	Embeddingr3   r5   r*   tok_embeddings	LayerNormr;   r<   normDropoutrD   droprQ   r   rS   s     rT   r2   zModernBertEmbeddings.__init__  sw     ll6+<+<f>P>P^d^q^qrLL!3!3vO_O_`	JJv778	rU   Tdynamic	input_idsro   c                 `    | j                  | j                  | j                  |                  S r   )r   r   r   )rQ   r   s     rT   compiled_embeddingsz(ModernBertEmbeddings.compiled_embeddings  s%    yy4#6#6y#ABCCrU   inputs_embedsc                     |"| j                  | j                  |            }|S | j                  j                  r| j	                  |      n.| j                  | j                  | j                  |                  }|S r   )r   r   r   rN   r   r   )rQ   r   r   hidden_statess       rT   r   zModernBertEmbeddings.forward  su     $ IIdii&>?M  ;;00 ((3YYtyy)<)<Y)GHI 
 rU   r   )rc   rd   re   rf   r"   r2   rx   compile
LongTensorr   r   r   r   ri   rj   s   @rT   r   r     s    9/ 9 U]]4 DU-=-= D%,, D !D ei!%"2"23KSTYT`T`Ka	rU   r   c                   `     e Zd ZdZdef fdZdej                  dej                  fdZ xZ	S )ModernBertMLPa6  Applies the GLU at the end of each ModernBERT layer.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
    r   c                    t         |           || _        t        j                  |j
                  t        |j                        dz  |j                        | _	        t        |j                     | _        t        j                  |j                        | _        t        j                  |j                  |j
                  |j                        | _        y )Nr   r   )r1   r2   r   r	   Linearr5   r|   r6   rE   Wir   r@   actr   rF   r   Wor   s     rT   r2   zModernBertMLP.__init__  s    ))F..F4L4L0MPQ0QX^XgXgh&223JJv112	))F44f6H6Hv_rU   r   ro   c                     | j                  |      j                  dd      \  }}| j                  | j                  | j	                  |      |z              S )Nr   rq   rs   )r   chunkr   r   r   )rQ   r   inputgates       rT   r   zModernBertMLP.forward  sI    ggm,221"2=twwtyy%4!7899rU   )
rc   rd   re   rf   r"   r2   rx   r   r   ri   rj   s   @rT   r   r     s2    `/ `:U\\ :ell :rU   r   c            
       L     e Zd Zddedededeej                     f fdZ	 xZ
S )ModernBertRotaryEmbeddingr   rs   r   r   c                 d    t         |   | ||       | j                  d |||      \  }| _        y )N)r   r   )rs   r   )r1   r2   rope_init_fnattention_scaling)rQ   r   rs   r   r   inv_freqrS   s         rT   r2   z"ModernBertRotaryEmbedding.__init__  s9    fV<+/+<+<T6sY]+<+^($(rU   r   )rc   rd   re   r"   r|   r   r   rx   r   r2   ri   rj   s   @rT   r   r     s=    _/ _c _ _PXY^YeYePf _ _rU   r   moduleModernBertAttentionr   sliding_window_maskrB   bsrs   output_attentionsc	                    | j                  ||      \  }
}|j                  dd      j                  d      \  }}}t        |||
|      \  }}| j                  dz  }t        j                  ||j                  dd            |z  }|dk7  r|}||z   }t        j                  j                  |dt
        j                  	      j                  |j                        }t        j                  j                  || j                  | j                  
      }t        j                  ||      }|j                  dd      j!                         }|j#                  |d|      }|r||fS |fS )Nrm   r   rv   r   r         ࿩rq   rq   rq   rr   )ptraining)
rotary_emb	transposeunbindr   head_dimrx   matmulr	   r   softmaxfloat32tort   dropoutr?   r   r   r   )r   r   rl   r   rm   rB   r   rs   r   _kwargsr   r   querykeyvaluescaleattn_weightsattn_outputs                     rT   eager_attention_forwardr    sK      < @HCa+22q29E3%eS#s;JE3OOT!E<<s}}Q':;eCL(",.0L ==((2U]](SVVW\WbWbcL==((9Q9Q\b\k\k(lL,,|U3K''1-88:K""2r3/K\**>rU   r   target_dtypec	                     ||||      }|j                   t        j                  t        j                  fv}
|
rb|j                   }|j	                  |      }t        |||| j                  r| j                  nd| j                  |      }|j	                  |      }n3t        |||| j                  r| j                  nd| j                  |      }|j                  ||      fS )Nr   r`   )r   r   	dropout_pdeterministicwindow_size)
rt   rx   float16bfloat16r  r   r   r?   rK   r   )r   r   r   r   r   rB   r   rs   r  r  convert_dtype
orig_dtypeattns                rT   flash_attention_forwardr  $  s     SZJ
GCIIemmU^^%DDM YY
ff\"/!!28//f..s 99'
 wwz"/!!28//f..s 99'
 IIb#  rU   c                 v   | j                  ||      \  }	}
|j                  dd      j                  d      \  }}}t        |||	|
      \  }}|dk7  r|}t	        j
                  |||| j                  r| j                  nd|      j                  dd      j                         }|j                  |d	|      }|fS )
Nr   r   rv   r   r   r   r`   )r  	attn_maskrq   )
r   r   r   r   Fscaled_dot_product_attentionr   r?   r   r   )r   r   rl   r   rm   rB   r   rs   r  r   r   r  r  r	  r  s                  rT   sdpa_attention_forwardr  O  s       < @HCa+22q29E3%eS#s;JE3(", 	
&&28//f..s$	
 
1a	  ""2r3/K>rU   )flash_attention_2eagersdpac                   z     e Zd ZdZd	dedee   f fdZ	 d
dej                  dee
   dej                  fdZ xZS )r   a  Performs multi-headed self attention on a batch of unpadded sequences.

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
    If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
    which requires padding and unpadding inputs, adding some overhead.

    See `forward` method for additional details.
    r   layer_idc                    t         |           || _        || _        |j                  |j
                  z  dk7  r&t        d|j                   d|j
                   d      |j                  | _        |j                  | _        |j
                  | _	        |j                  |j
                  z  | _
        | j                  | j                  z  | _        t        j                  |j                  d| j                  z  |j                        | _        ||j                   z  dk7  r$|j"                  dz  |j"                  dz  f| _        nd| _        |j$                  }|j&                  }| j"                  dk7  r$|j(                  |j(                  }|j"                  }|j*                  d	k(  rt-        | j                  ||
      | _        nt1        || j                  |      | _        t        j                  |j                  |j                  |j                        | _        |j                  dkD  rt        j4                  |j                        nt        j6                         | _        t;               | _        y )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()r   r   r   r   r  )rs   r   r   )r   rs   r   r`   )r1   r2   r   r"  r5   r8   rP   r?   rK   	num_headsr   all_head_sizer	   r   r>   WqkvrA   rB   r=   r4   rC   _attn_implementationr   r   r   r   r   Identityout_dropsetpruned_heads)rQ   r   r"  
rope_thetar4   rS   s        rT   r2   zModernBertAttention.__init__  s      : ::a?#F$6$6#77mnt  oI  oI  nJ  JK  L  "(!9!9(.(G(G%33**f.H.HH!]]T^^;IIf00!d6H6H2HvOdOde	f7771<$*$:$:a$?AWAW[\A\#]D #+D --
"("@"@8+&&2#44
&,&<&<#&&*==?MM.EJDO 8v4==_ijDO))F..0B0BI^I^_@F@X@X[^@^

6#;#;<dfdododqErU   r   r   ro   c           
         | j                  |      }|j                  d   }| j                  j                  dk(  r)|j	                  dd| j
                  | j                        }n)|j	                  |dd| j
                  | j                        }t        | j                  j                     | f|| j                  | j                  || j                  |d|}|d   }| j                  | j                  |            }|f|dd  z   S )Nr   r  rq   r   )r   r   rB   r   rs   r   rv   )r'  r   r   r(  r   r%  r   MODERNBERT_ATTENTION_FUNCTIONr   rB   r&  r*  r   )rQ   r   r   rR   r   r   attn_outputss          rT   r   zModernBertAttention.forward  s     ii&  #;;++/BB((2q$..$--@C((2r1dnndmmDC4T[[5U5UV	
 00""/	
 	
 %Qdggm&<=,qr"222rU   r   F)rc   rd   re   rf   r"   r   r|   r2   rx   r   boolr   ri   rj   s   @rT   r   r   y  sS    &"/ &"8C= &"V -23||3 $D>3
 
3rU   c                   f    e Zd Zddedee   f fdZ ej                  d      dej                  dej                  fd       Z
	 	 	 	 	 	 ddej                  d	eej                     d
eej                     deej                     deej                     dee   dee   dej                  fdZ xZS )ModernBertEncoderLayerr   r"  c                    t         |           || _        |dk(  rt        j                         | _        n;t        j                  |j                  |j                  |j                        | _        t        ||      | _        t        j                  |j                  |j                  |j                        | _        t        |      | _        y )Nr   r   )r   r"  )r1   r2   r   r	   r)  	attn_normr   r5   r;   r<   r   r  mlp_normr   mlprQ   r   r"  rS   s      rT   r2   zModernBertEncoderLayer.__init__  s    q=[[]DN\\&*<*<&//X^XhXhiDN'vI	V%7%7V__SYScScd (rU   Tr   r   ro   c                 B    | j                  | j                  |            S r   )r8  r7  rQ   r   s     rT   compiled_mlpz#ModernBertEncoderLayer.compiled_mlp  s    xxm455rU   rl   r   rm   r   r   r   c           	      
   | j                  | j                  |      ||||||      }||d   z   }| j                  j                  r| j	                  |      n| j                  | j                  |            }	||	z   }|f|dd  z   S )Nrl   r   rm   r   r   r   r   rv   )r  r6  r   rN   r<  r8  r7  )
rQ   r   rl   r   rm   r   r   r   r0  
mlp_outputs
             rT   r   zModernBertEncoderLayer.forward  s     yyNN=)) 3%!!/ ! 
 &Q7 {{,, m,$--67 	
 &
2,qr"222rU   r   )NNNNNF)rc   rd   re   r"   r   r|   r2   rx   r   r   r<  r   r2  r   ri   rj   s   @rT   r4  r4    s    	)/ 	)8C= 	) U]]4 6%,, 65<< 6 !6 266:37-1$(,13||3 !.3 &ell3	3
 u//03 U\\*3 SM3 $D>3 
3rU   r4  c                        e Zd ZeZdZdZddgZdZdZ	dZ
dej                  fdZe	 	 	 	 dded	eej$                     d
eeeeeef   f      def fd       Zd Z fdZ xZS )ModernBertPreTrainedModelmodelTr   r4  Fr   c                    | j                   j                  ddt        j                  dt        ffd}| j                   j
                  | j                   j
                  t        j                  d| j                   j                  z        z  | j                   j
                  | j                   j                  dz  d}t        |t              r ||j                  |d          y t        |t              r- ||j                  |d	           ||j                  |d
          y t        |t               r- ||j"                  |d	           ||j                  |d
          y t        |t$              r ||j&                  |d
          y t        |t(              r ||j*                  |d
          y t        |t,        t.        t0        f      r ||j2                  |d          y t        |t        j4                        rW|j6                  j8                  j;                  d       |j<                  %|j<                  j8                  j?                          y y y )Nr   r   stdc                    t         j                  j                  | j                  d| |z  |z         t	        | t         j
                        r7| j                  *t         j                  j                  | j                         y y y )Nr`   )r(   rD  ab)r	   inittrunc_normal_weight
isinstancer   r   zeros_)r   rD  cutoff_factors     rT   init_weightz<ModernBertPreTrainedModel._init_weights.<locals>.init_weight  sq    GG!! .3&#% "  &")),;;*GGNN6;;/ + -rU   r]   r   )inout	embedding	final_outrQ  rO  rP  rR  g      ?) r   r:   r	   Moduler   r9   mathsqrtr7   r5   rK  r   r   r   r   r   r   r'  ModernBertPredictionHeaddenseModernBertForMaskedLMdecoder#ModernBertForSequenceClassification ModernBertForTokenClassificationModernBertForQuestionAnswering
classifierr   rJ  datafill_r   zero_)rQ   r   rN  stdsrM  s       @rT   _init_weightsz'ModernBertPreTrainedModel._init_weights  s   == M	0		 	0 	0 ++//;;00499S4;;C`C`=`3aa6600$6	
 f23--tK/@A.		4:.		4;/ 34T$Z0		4;/ 89d5k2 56U402RTrs
 ))4+<=-MM$$S){{&  &&( ' .rU   use_flash_attention_2torch_dtype
device_mapcheck_device_mapc                     |j                   ,d|_         	 | j                  |t        j                  |d|      S t        |   ||t        j                  ||      S # t        t
        f$ r
 d |_         Y :w xY w)Nr  F)rd  re  hard_check_onlyrf  )rc  rd  re  rf  )_attn_implementation_internal_check_and_enable_flash_attn_2rx   r  rP   ImportErrorr1   _autoset_attn_implementation)r'   r   rc  rd  re  rf  rS   s         rT   rl  z6ModernBertPreTrainedModel._autoset_attn_implementation0  s     //73FF0	<99 %)$)%5 :   w3"7!- 4 
 	
 , <7;4<s   #A A54A5c                    | j                   j                  du ry t        | d      rTt        | j                        dkD  r<| j                   j                  rt
        j                  d       d| j                   _        | j                  j                  dk(  r<| j                   j                  rt
        j                  d       d| j                   _        | j                  j                  dk(  r<| j                   j                  rt
        j                  d       d| j                   _        | j                   j                  t               | j                   _        y y )	NFhf_device_maprv   zqIf `accelerate` split the model across devices, `torch.compile` will not work. Falling back to non-compiled mode.mpsz|Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. Falling back to non-compiled mode.cpuz|Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. Falling back to non-compiled mode.)
r   rN   hasattrlenrn  loggerwarning_oncer   typer   r   s    rT   _maybe_set_compilez,ModernBertPreTrainedModel._maybe_set_compileQ  s   ;;((E14)c$2D2D.E.I{{,,##9 -2DKK);;u${{,,##9 -2DKK);;u${{,,##9 -2DKK);;((0,?,ADKK) 1rU   c                     t        |   |i |}| j                  j                  dv r<| j                  j                  rt        j                  d       d| j                  _        |S )N>   NTzcResizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode.F)r1   resize_token_embeddingsr   rN   rs  rt  )rQ   argsrR   model_embedsrS   s       rT   rx  z1ModernBertPreTrainedModel.resize_token_embeddingsp  s[    w6GG;;((L8{{,,##y -2DKK)rU   )FNNT)rc   rd   re   r"   config_classbase_model_prefixsupports_gradient_checkpointing_no_split_modules_supports_flash_attn_2_supports_sdpa_supports_flex_attnr	   rS  rb  classmethodr2  r   rx   rt   r   r   r   r|   rl  rv  rx  ri   rj   s   @rT   rA  rA    s    #L&*#/1IJ!N-)BII -)^  ',-1;?!%
  $
 ekk*	

 U3S#X#678
 
 
@B>
 
rU   rA  c            !           e Zd Zdef fdZd Zd Ze	 	 	 	 	 	 	 	 	 	 	 	 	 ddee	j                     dee	j                     dee	j                     dee	j                     d	ee	j                     d
ee	j                     dee	j                     dee   dee   dee   dee   dee   dee   deee	j                  df   ef   fd       Zde	j                  dede	j                  fdZ xZS )ModernBertModelr   c           	         t         |   |       || _        t        |      | _        t        j                  t        |j                        D cg c]  }t        ||       c}      | _
        t        j                  |j                  |j                  |j                        | _        d| _        | j#                          y c c}w )Nr   F)r1   r2   r   r   
embeddingsr	   
ModuleListranger7   r4  layersr   r5   r;   r<   
final_normgradient_checkpointing	post_initr9  s      rT   r2   zModernBertModel.__init__  s     .v6mmFKFLdLdFef(#FH5f
 ,,v'9'9vU[UeUef&+#	 gs   C c                 .    | j                   j                  S r   r  r   r   s    rT   get_input_embeddingsz$ModernBertModel.get_input_embeddings  s    ---rU   c                 &    || j                   _        y r   r  )rQ   r	  s     rT   set_input_embeddingsz$ModernBertModel.set_input_embeddings  s    ).&rU   r   rl   r   rm   r   r   r   r   
batch_sizeseq_lenr   output_hidden_statesreturn_dictro   .c                 j  	
 ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|du |duz  rt	        d      |rdnd}|rdnd}| j                          || j                  ||       	)
'||j                  dd \  	
n|j                  dd \  	
||j                  n|j                  }|(t        j                  	
f|t        j                        }d}| j                   j                  dk(  rM||d}|0t        j                         5  t        ||	      ^}}}}ddd       nQt        ||	      ^}}}}n>|&t        j                  
|
      j!                  d      }| j#                  ||      \  }}| j%                  ||      }| j&                  D ]t  }|r||fz   }| j(                  r/| j*                  r#| j-                  |j.                  |||||||      }n ||||||||      }|d   }|s]t1        |      dkD  sl||d   fz   }v |r||fz   }| j3                  |      }|r't5        |	
      }|t7        	
fd|D              }|st7        d |||fD              S t9        |||      S # 1 sw Y   xY w)  
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        Nz:You must specify exactly one of input_ids or inputs_embedsr0   r   r   Fr  T)rk   rl   )r   r   )r   )r   r   r>  rv   rk   r   r   r   c              3   <   K   | ]  }t        |         yw)r  N)r   ).0hsr  r   r  s     rT   	<genexpr>z*ModernBertModel.forward.<locals>.<genexpr>
  s(      * +"gZ`ghh*s   c              3   &   K   | ]	  }||  y wr   r0   )r  vs     rT   r  z*ModernBertModel.forward.<locals>.<genexpr>  s     mq_`_lms   )last_hidden_stater   
attentions)r   r   r  use_return_dictrP   rv  %warn_if_padding_and_no_attention_maskr   r   rx   onesr2  r(  no_gradr   arange	unsqueeze_update_attention_maskr  r  r  r   _gradient_checkpointing_func__call__rr  r  r   tupler   )rQ   r   rl   r   rm   r   r   r   r   r  r  r   r  r  all_hidden_statesall_self_attentionsr   repadr   r   encoder_layerlayer_outputss         `  ``           rT   r   zModernBertModel.forward  sH   B 2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++B]B]-t";<YZZ"6BD$5b4! 66y.Q'/(&3&9&9"1&=#
G&/oobq&9#
G%.%:!!@T@T!"ZZW(=fTYT^T^_N;;++/BB:#5*:L ( I`#,^JF	7J
Q 
 Ja,^JFM7J
Q #$||GFCMMaP262M2M2C 3N 3/N/ )=Y![[ 	PM#$58H$H!**t}} $ A A!**!"' %	! !.!#1(;!-))&7! *!,M S%7!%;&9]1=M<O&O#7	P:   1]4D D62$gZPWM !,$) */* %!
 m]4EGZ$[mmm++*
 	
A s   >J((J2c                     |r| j                   j                  dk(  r't        j                  d       d| j                   _        nF| j                   j                  dk7  r-t        j                  d| j                   j                   d       t	        || j
                        }t        j                  |j                  d         j                  d      }t        j                  ||j                  z
        }|| j                   j                  dz  k  j                  d      j                  d      j                  |j                        }|j                  |j!                         t        j"                  | j
                        j$                        }||fS )Nr   zOutputting attentions is only supported with the 'eager' attention implementation, not with "sdpa". Falling back to `attn_implementation="eager"`.r  zZOutputting attentions is only supported with the eager attention implementation, not with zT. Consider setting `attn_implementation="eager"`. Setting `output_attentions=False`.r   r   )r   r(  rs  rt  r   rt   rx   r  r   r  absTrB   r  r   masked_filllogical_notfinfomin)rQ   rl   r   global_attention_maskrowsdistancewindow_maskr   s           rT   r  z&ModernBertModel._update_attention_mask  sS   {{//69##V 4;011W<##  $ @ @A B:: !;>4:: V ||177:;EEaH99TDFF]+ 4499DDQGQQRSTWWXfXmXmn 	 4??@W@W@Y[`[f[fgkgqgq[r[v[vw$&999rU   NNNNNNNNNNNNN)rc   rd   re   r"   r2   r  r  r   r   rx   r   r   r|   r2  r   r   r   r   r  ri   rj   s   @rT   r  r  }  s   	/ 	./  15156:3704*.-1$($(!%,0/3&*D
E,,-D
 !.D
 &ell3	D

 u//0D
  -D
 %,,'D
 U\\*D
 SMD
 SMD
 #D
 $D>D
 'tnD
 d^D
 
uU\\3&'8	9D
 D
L:U\\ :VZ :_d_k_k :rU   r  c                   \     e Zd Zdef fdZdej                  dej                  fdZ xZS )rV  r   c                 J   t         |           || _        t        j                  |j
                  |j
                  |j                        | _        t        |j                     | _
        t        j                  |j
                  |j                  |j                        | _        y )Nr   )r1   r2   r   r	   r   r5   rI   rW  r   rJ   r   r   r;   r<   r   r   s     rT   r2   z!ModernBertPredictionHead.__init__8  sq    YYv1163E3EvG]G]^
&667LL!3!3vO_O_`	rU   r   ro   c                 `    | j                  | j                  | j                  |                  S r   )r   r   rW  r;  s     rT   r   z ModernBertPredictionHead.forward?  s#    yy$**]";<==rU   )	rc   rd   re   r"   r2   rx   r   r   ri   rj   s   @rT   rV  rV  7  s-    a/ a>U\\ >ell >rU   rV  zd
    The ModernBert Model with a decoder head on top that is used for masked language modeling.
    )custom_introc            "       8    e Zd ZdgZdef fdZd Zdej                  fdZ	 e
j                  d      d	e
j                  d
e
j                  fd       Ze	 	 	 	 	 	 	 	 	 	 	 	 	 	 ddee
j                      dee
j                     dee
j                     dee
j                     dee
j                     dee
j                     dee
j                     dee
j                     dee   dee   dee   dee   dee   dee   d
eee
j                     ef   fd       Z xZS )rX  zdecoder.weightr   c                 t   t         |   |       || _        t        |      | _        t        |      | _        t        j                  |j                  |j                  |j                        | _        | j                  j                  | _        | j                  j                  | _        | j                          y )Nr   )r1   r2   r   r  rB  rV  headr	   r   r5   r3   rG   rY  rL   rM   r  r   s     rT   r2   zModernBertForMaskedLM.__init__K  s     $V,
,V4	yy!3!3V5F5FVM`M`a!%!>!>(,(L(L% 	rU   c                     | j                   S r   rY  r   s    rT   get_output_embeddingsz+ModernBertForMaskedLM.get_output_embeddingsX  s    ||rU   new_embeddingsc                     || _         y r   r  )rQ   r  s     rT   set_output_embeddingsz+ModernBertForMaskedLM.set_output_embeddings[  s	    %rU   Tr   rY   ro   c                 B    | j                  | j                  |            S r   )rY  r  )rQ   rY   s     rT   compiled_headz#ModernBertForMaskedLM.compiled_head^  s    ||DIIf-..rU   r   rl   r   rm   r   rn   r   r   r   r  r  r   r  r  c                 H   ||n| j                   j                  }| j                          | j                   j                  dk(  r|||	|
)|'||j                  dd \  }
}n|j                  dd \  }
}||j
                  n|j
                  }|(t        j                  |
|f|t        j                        }|4t        j                         5  t        ||||      \  }}}}	}}ddd       nt        ||||      \  }}}}	}}| j                  ||||||||	|
||||      }|d   }| j                  rK|I|j                  d      }|j                  |j                  d   d      }|| j                  k7  }||   }||   }| j                   j                  r| j!                  |      n| j#                  | j%                  |            }d}|(| j'                  ||| j                   j(                  	      }| j                   j                  dk(  rN| j                   j*                  s|
t-               nt        j                         5  t/        |||
|
      }ddd       |s|f}||f|z   S |S t1        |||j2                  |j4                        S # 1 sw Y   xY w# 1 sw Y   HxY w)r  Nr  r   r   )rk   rl   rm   rn   r   rl   r   rm   r   r   r   r   r  r  r   r  r  r   rq   )r3   r  losslogitsr   r  )r   r  rv  r(  r   r   rx   r  r2  r  r   rB  rL   r   rM   rN   r  rY  r  loss_functionr3   rO   r   r   r   r   r  )rQ   r   rl   r   rm   r   rn   r   r   r   r  r  r   r  r  rR   r   outputsr  mask_tokensr  r  rY   s                          rT   r   zModernBertForMaskedLM.forwardb  s   F &1%<k$++B]B]!;;++/BB:#5*:L%'/$0.;.A.A"1.E+
G.7oobq.A+
G-6-B))H\H\!)%*ZZW0Ef\a\f\f%gN ( [r#,^Zfou\X	7J
LRX 
 \s,^Zfou\XM7J
LRX **) 3%'!!!/!5#  
 $AJ!!f&8[[_F 1 6 6v||A K !D$A$AAK 1+ >K(F {{,, 01dii(9:; 	 %%ffAWAW%XD;;++/BB"&++"D"D\a\i\i\k r/vwV`ipqr YF)-)9TGf$EvE!//))	
 	
m ^r rs   JJJJ!NNNNNNNNNNNNNN)rc   rd   re   _tied_weights_keysr"   r2   r  r	   r   r  rx   r   r   r  r   r   r   r|   r2  r   r   r   r   ri   rj   s   @rT   rX  rX  C  s    ++/ &BII & U]]4 /ELL /U\\ / !/  15156:/304)-*.-1$($(!%,0/3&*m
E,,-m
 !.m
 &ell3	m

 u||,m
  -m
 &m
 %,,'m
 U\\*m
 SMm
 SMm
 #m
 $D>m
 'tnm
 d^m
" 
uU\\"N2	3#m
 m
rU   rX  z`
    The ModernBert Model with a sequence classification head on top that performs pooling.
    c            "           e Zd Zdef fdZe	 	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                     deej                     deej                     deej                     deej                     deej                     d	eej                     d
eej                     dee
   dee
   dee
   dee   dee   dee   deeej                     ef   fd       Z xZS )rZ  r   c                 n   t         |   |       |j                  | _        || _        t	        |      | _        t        |      | _        t        j                  j                  |j                        | _        t        j                  |j                  |j                        | _        | j!                          y r   )r1   r2   
num_labelsr   r  rB  rV  r  rx   r	   r   rH   r   r   r5   r]  r  r   s     rT   r2   z,ModernBertForSequenceClassification.__init__  s      ++$V,
,V4	HH$$V%>%>?	))F$6$68I8IJ 	rU   r   rl   r   rm   r   rn   r   r   r   r  r  r   r  r  ro   c                 f   ||n| j                   j                  }| j                          | j                  ||||||||	|
||||      }|d   }| j                   j                  dk(  r
|dddf   }nQ| j                   j                  dk(  r8||j                  d      z  j                  d      |j                  dd	
      z  }| j                  |      }| j                  |      }| j                  |      }d}|| j                   j                  | j                  dk(  rd| j                   _
        nl| j                  dkD  rL|j                  t        j                  k(  s|j                  t        j                  k(  rd| j                   _
        nd| j                   _
        | j                   j                  dk(  rIt!               }| j                  dk(  r& ||j#                         |j#                               }n |||      }n| j                   j                  dk(  r=t%               } ||j'                  d| j                        |j'                  d            }n,| j                   j                  dk(  rt)               } |||      }|s|f}||f|z   S |S t+        |||j,                  |j.                        S )aB  
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        Nr  r   r'   r(   rq   rv   r   T)rs   keepdim
regressionsingle_label_classificationmulti_label_classificationr  )r   r  rv  rB  r%   r  rw   r  r   r]  problem_typer  rt   rx   longr|   r   squeezer   r   r
   r   r   r  )rQ   r   rl   r   rm   r   rn   r   r   r   r  r  r   r  r  rR   r  r  pooled_outputr  r  loss_fctrY   s                          rT   r   z+ModernBertForSequenceClassification.forward  s   N &1%<k$++B]B]!**) 3%'!!!/!5#  
 $AJ;;))U2 1!Q$ 7[[++v5!2^5M5Mb5Q!Q V V[\ V ]`n`r`rt as a ! 		"34		-0/{{''/??a'/;DKK,__q(fllejj.HFLL\a\e\eLe/LDKK,/KDKK,{{''<7"9??a'#FNN$4fnn6FGD#FF3D))-JJ+-B @&++b/R))-II,./YF)-)9TGf$EvE'!//))	
 	
rU   r  )rc   rd   re   r"   r2   r   r   rx   r   r   r|   r2  r   r   r   r   ri   rj   s   @rT   rZ  rZ    sk   /   15156:/304)-*.-1$($(!%,0/3&*e
E,,-e
 !.e
 &ell3	e

 u||,e
  -e
 &e
 %,,'e
 U\\*e
 SMe
 SMe
 #e
 $D>e
 'tne
 d^e
" 
uU\\"$<<	=#e
 e
rU   rZ  zv
    The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
    c            "           e Zd Zdef fdZe	 	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                     deej                     deej                     deej                     deej                     deej                     d	eej                     d
eej                     dee
   dee
   dee
   dee   dee   dee   deeej                     ef   fd       Z xZS )r[  r   c                 `   t         |   |       |j                  | _        t        |      | _        t        |      | _        t        j                  j                  |j                        | _        t        j                  |j                  |j                        | _        | j                          y r   r1   r2   r  r  rB  rV  r  rx   r	   r   rH   r   r   r5   r]  r  r   s     rT   r2   z)ModernBertForTokenClassification.__init__U  s{      ++$V,
,V4	HH$$V%>%>?	))F$6$68I8IJ 	rU   r   rl   r   rm   r   rn   r   r   r   r  r  r   r  r  ro   c                    ||n| j                   j                  }| j                          | j                  ||||||||	|
||||      }|d   }| j	                  |      }| j                  |      }| j                  |      }d}|<t               } ||j                  d| j                        |j                  d            }|s|f|dd z   }||f|z   S |S t        |||j                  |j                        S )a  
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        Nr  r   rq   rv   r  )r   r  rv  rB  r  r   r]  r   r   r  r   r   r  )rQ   r   rl   r   rm   r   rn   r   r   r   r  r  r   r  r  r  r  r  r  r  rY   s                        rT   r   z(ModernBertForTokenClassification.forwarda  s#   H &1%<k$++B]B]!**) 3%'!!!/!5#  
 $AJ II&78 II&78!23')HFKKDOO<fkk"oNDY,F)-)9TGf$EvE$!//))	
 	
rU   r  )rc   rd   re   r"   r2   r   r   rx   r   r   r|   r2  r   r   r   r   ri   rj   s   @rT   r[  r[  O  sk   
/ 
  15156:/304)-*.-1$($(!%,0/3&*I
E,,-I
 !.I
 &ell3	I

 u||,I
  -I
 &I
 %,,'I
 U\\*I
 SMI
 SMI
 #I
 $D>I
 'tnI
 d^I
  
uU\\"$99	:!I
 I
rU   r[  c            "           e Zd Zdef fdZe	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                     deej                     deej                     deej                     deej                     deej                     d	eej                     d
eej                     dee	   dee	   dee	   dee
   dee
   dee
   deeej                     ef   fd       Z xZS )r\  r   c                 `   t         |   |       |j                  | _        t        |      | _        t        |      | _        t        j                  j                  |j                        | _        t        j                  |j                  |j                        | _        | j                          y r   r  r   s     rT   r2   z'ModernBertForQuestionAnswering.__init__  sy      ++$V,
,V4	HH$$V%>%>?	))F$6$68I8IJrU   r   rl   r   rm   start_positionsend_positionsr   r   r   r  r  r   r  r  ro   c                 T   ||n| j                   j                  }| j                          | j                  |||||||	|
||||      }|d   }| j	                  |      }| j                  |      }| j                  |      }|j                  dd      \  }}|j                  d      j                         }|j                  d      j                         }d}|| | j                  ||||fi |}|s||f|dd z   }||f|z   S |S t        ||||j                  |j                        S )r  N)rl   r   rm   r   r   r   r  r  r   r  r  r   rv   rq   r   )r  start_logits
end_logitsr   r  )r   r  rv  rB  r  r   r]  splitr  r   r  r   r   r  )rQ   r   rl   r   rm   r  r  r   r   r   r  r  r   r  r  rR   r  r  r  r  r  r  rY   s                          rT   r   z&ModernBertForQuestionAnswering.forward  sh   F &1%<k$++B]B]!**) 3%!!!/!5#  
 $AJ II&78 II&78!23#)<<r<#: j#++B/::<''+668
&=+D%4%%lJQ^ibhiD"J/'!"+=F)-)9TGf$EvE+%!!//))
 	
rU   r  )rc   rd   re   r"   r2   r   r   rx   r   r|   r2  r   r   r   r   ri   rj   s   @rT   r\  r\    sf   	/ 	  266:/32604*.-1$($(!%,0/3&*K
ELL)K
 !.K
 &ell3	K

 u||,K
 "%,,/K
  -K
 %,,'K
 U\\*K
 SMK
 SMK
 #K
 $D>K
 'tnK
 d^K
" 
uU\\"$@@	A#K
 K
rU   r\  )r"   r  rA  rX  rZ  r[  r\  r   r1  )TrT  
contextlibr   typingr   r   r   r   r   rx   torch.nn.functionalr	   r   r  torch.utils.checkpointtorch.nnr
   r   r   activationsr   configuration_utilsr   modeling_attn_mask_utilsr   modeling_outputsr   r   r   r   r   modeling_utilsr   utilsr   r   r   utils.import_utilsr   gemma.modeling_gemmar   r   flash_attn.flash_attn_interfacer   flash_attn.layers.rotaryr   flash_attn.ops.triton.rotaryr    object
get_loggerrc   rs  r"   r   r|   r   r   autogradFunctionr   r   r   rS  r   r   r   r   r2  r  r  rt   r  r  r/  r   r4  rA  r  rV  rX  rZ  r[  r\  __all__r0   rU   rT   <module>r     s     " 8 8      A A ! 3 B  . G G 5 M P89O 
		H	%A' AN ,0%)	&mLL&mLL&m 5<<(&m U\\"	&m
 5<<u||S(5<<:PRZ[`[g[gRhhi&mRLL\\  	
 \\>46%..11 46v *. $L &	L
 L42Q 2Qj299 <:BII :(_ 4 _ )."!"	" LL" 	"
 5++," 38_" 	" 
"  ~" 5u||+,eELL.AAB"\ !&(!!(!	(! 2(! 	(!
 (! 38_(! 	(! 
(! ++(! 5<<(!V ! 	  LL  	 
 5++,  38_  	  
  5<< H 1$"! M3")) M3`+3RYY +3\ B B BJ v:/ v: v:r	>ryy 	> 
H
5 H

H
V 
t
*C t

t
n 
W
'@ W

W
t X
%> X
 X
vrU   