
"""PyTorch Siglip model."""

import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.init import _calculate_fan_in_and_fan_out

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig


logger = logging.get_logger(__name__)


def _trunc_normal_(tensor, mean, std, a, b):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1]
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")
eeej                  df      ed<   dZeeej                  df      ed<   y)SiglipVisionModelOutputa  
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class SiglipTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class SiglipOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and no class embeddings.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[0]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings


class SiglipTextEmbeddings(nn.Module):
    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        max_position_embedding = self.position_embedding.weight.shape[0]

        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights


class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = SiglipAttention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


@auto_docstring
class SiglipPreTrainedModel(PreTrainedModel):
    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    _no_split_modules = [
        "SiglipTextEmbeddings",
        "SiglipEncoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, SiglipVisionEmbeddings):
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, SiglipForImageClassification):
            nn.init.normal_(
                module.classifier.weight,
                std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
            )
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].

    Args:
        config: SiglipConfig
    """

    def __init__(self, config: SiglipConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class SiglipTextTransformer(nn.Module):
    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = SiglipTextEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.head = nn.Linear(embed_dim, config.projection_size)
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
        # expand attention_mask
        if attention_mask is not None and not self._use_flash_attention_2:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # The model uses the last token's hidden state, assuming the EOS token is at the end.
        pooled_output = last_hidden_state[:, -1, :]
        pooled_output = self.head(pooled_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    The text model from SigLIP without any head or projection on top.
    """
)
class SiglipTextModel(SiglipPreTrainedModel):
    config_class = SiglipTextConfig

    def __init__(self, config: SiglipTextConfig):
        super().__init__(config)
        self.text_model = SiglipTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, SiglipTextModel

        >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )


class SiglipVisionTransformer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(config)

    @can_return_tuple
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
    ) -> BaseModelOutputWithPooling:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.post_layernorm(last_hidden_state)

        pooler_output = self.head(last_hidden_state) if self.use_head else None

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


@auto_docstring(
    custom_intro="""
    The vision model from SigLIP without any head or projection on top.
    """
)
class SiglipVisionModel(SiglipPreTrainedModel):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)

        self.vision_model = SiglipVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, SiglipVisionModel

        >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )


@auto_docstring
class SiglipModel(SiglipPreTrainedModel):
    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        super().__init__(config)

        if not isinstance(config.text_config, SiglipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, SiglipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        # First, initialize the text and vision models with proper attention implementation
        text_model = SiglipTextModel._from_config(text_config)
        vision_model = SiglipVisionModel._from_config(vision_config)

        # Second, get the text and vision submodules (for backward compatibility)
        self.text_model = text_model.text_model
        self.vision_model = vision_model.vision_model

        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = text_outputs.pooler_output

        return pooled_output

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        pooled_output = vision_outputs.pooler_output

        return pooled_output

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> SiglipOutput:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        image_embeds = vision_outputs.pooler_output
        text_embeds = text_outputs.pooler_output

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))

        logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
        logits_per_text = logits_per_text * logit_scale.exp() + logit_bias

        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
            nll = -torch.sum(loglik, dim=-1)
            loss = nll.mean()

        return SiglipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )


@auto_docstring(
    custom_intro="""
    SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
    the patch tokens) e.g. for ImageNet.
    """
)
class SiglipForImageClassification(SiglipPreTrainedModel):
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels

        # Create the vision model with the proper attention implementation
        vision_model = SiglipVisionModel._from_config(config.vision_config)
        self.vision_model = vision_model.vision_model

        # Classifier head
        self.classifier = (
            nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, SiglipForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a `SiglipModel` from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the two classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: LABEL_1
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        sequence_output = outputs.last_hidden_state

        # average pool the patch tokens
        sequence_output = torch.mean(sequence_output, dim=1)
        # apply classifier
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "SiglipModel",
    "SiglipPreTrainedModel",
    "SiglipTextModel",
    "SiglipVisionModel",
    "SiglipForImageClassification",
]