
    UhP~                        d Z ddlmZmZmZ ddlZddlZddlmZ ddl	m
Z
mZmZ ddlmZ ddlmZmZmZ dd	lmZ dd
lmZmZmZ ddlmZmZ ddlmZ  ej<                  e      Z d Z!d Z"d"dZ# G d dejH                        Z%d Z& G d dejH                        Z'e G d de             Z(e G d de(             Z) ed       G d de(e             Z* ed       G d d e(             Z+g d!Z,y)#zPyTorch CTRL model.    )OptionalTupleUnionN)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )GenerationMixin)BaseModelOutputWithPastCausalLMOutputWithPastSequenceClassifierOutput)PreTrainedModel)Conv1D find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringlogging   )
CTRLConfigc                 P    dt        j                  dd|dz  z  |z        z  }| |z  S )Nr   i'     )torchpow)posid_model_sizeangle_ratess       x/var/www/catia.catastroantioquia-mas.com/valormas/lib/python3.12/site-packages/transformers/models/ctrl/modeling_ctrl.py
angle_defnr    '   s/    eiiQ!V'DEEK    c                    t        t        j                  | t        j                        j	                  |      j                  d      t        j                  |t        j                        j	                  |      j                  d      |      }t        j                  |d d dd df         }t        j                  |d d dd df         }t        j                  ||gd      }|S )Ndtyper   r   r   dim)	r    r   arangeint64to	unsqueezesincoscat)positionr   r$   
angle_radssinescosinespos_encodings          r   positional_encodingr4   ,   s    XU[[144U;EEaH\588?II!LJ IIjADqD)*Eii
1add7+,G99eW-26Lr!   c           	         t        j                  | |j                  dddd            }|j                  d   }|t	        j
                  |      z  }|6|j                  d      |j                  d      }
}	|||
|	z
  |
d |
f   dz  z  }|||z   }t        j                  |d      }|||z  }t        j                  ||      }||fS )	Nr   r   r
   r   r%   g     r&   )r   matmulpermuteshapenpsqrtsizesoftmax)qkvmaskattention_mask	head_mask	matmul_qkdkscaled_attention_logitsndnsattention_weightsoutputs                r   scaled_dot_product_attentionrK   ;   s    Q		!Q1 56I	
B'"''"+5(--b13J3O3OPR3SB4R"crc(9#:T#AA!"9N"J&=2F -	9\\+Q/F$$$r!   c                   <     e Zd Z fdZd Zd Z	 	 	 	 	 ddZ xZS )MultiHeadAttentionc                 n   t         |           || _        || _        t	        || j                  z        | _        t        j                  ||      | _        t        j                  ||      | _	        t        j                  ||      | _
        t        j                  ||      | _        t               | _        y N)super__init__	num_headsr   intdepthr   LinearWqWkWvdensesetpruned_heads)selfr   rR   	__class__s      r   rQ   zMultiHeadAttention.__init__V   s    "(67
))L,7))L,7))L,7YY|\:
Er!   c                    | j                   | j                  z  }t        |      dk(  ry t        || j                  || j                        \  }}t        | j                  |      | _        t        | j                  |      | _        t        | j                  |      | _        t        | j                  |d      | _	        | j                  t        |      z
  | _        || j                  z  | _         | j                  j                  |      | _        y )Nr   r   r&   )r   rR   lenr   r[   r   rV   rW   rX   rY   union)r\   headsattention_head_sizeindexs       r   prune_headszMultiHeadAttention.prune_headsd   s    "//4>>Au:?7t~~Obdhduduvu %TWWe4$TWWe4$TWWe4'

EqA
 #e*4/$..@ --33E:r!   c                 x    |j                  |d| j                  | j                        }|j                  g d      S )Nr%   r   r   r   r
   )reshaperR   rT   r8   )r\   x
batch_sizes      r   split_into_headsz#MultiHeadAttention.split_into_headsu   s-    IIj"dnndjjAyy&&r!   c
                 x   |j                   d   }
| j                  |      }| j                  |      }| j                  |      }| j	                  ||
      }| j	                  ||
      }| j	                  ||
      }|<|d   |d   }}t        j                  ||fd      }t        j                  ||fd      }|du rt        j                  ||f      }nd}t        ||||||      }|d   j                  g d      }|d   }|j                  |
d| j                        }| j                  |      }||f}|	r||fz   }|S )	Nr   r   r6   r&   TrO   rf   r%   )r9   rV   rW   rX   rj   r   r.   stackrK   r8   rg   r   rY   )r\   r@   r?   r>   rA   
layer_pastrB   rC   	use_cacheoutput_attentionsri   past_key
past_valuepresentrJ   scaled_attentionattnoriginal_size_attentionoutputss                      r   forwardzMultiHeadAttention.forwardy   sF    WWQZ
GGAJGGAJGGAJ!!!Z0!!!Z0!!!Z0!#-a=*Q-jH		8Q-R0A		:q/r2Akk1a&)GG-aAt^YW!!9,,\:ay"2":"::r4K\K\"]347#'Gr!   NNNFF)__name__
__module____qualname__rQ   rd   rj   rw   __classcell__r]   s   @r   rM   rM   U   s(    ";"' (r!   rM   c                     t        j                  t        j                  | |      t        j                         t        j                  ||             S rO   )r   
SequentialrU   ReLU)r   dffs     r   point_wise_feed_forward_networkr      s2    ==<5rwwy"))CQ]B^__r!   c                   *     e Zd Zd fd	Z	 ddZ xZS )EncoderLayerc                 >   t         |           t        ||      | _        t	        ||      | _        t        j                  |d      | _        t        j                  |d      | _	        t        j                  |      | _        t        j                  |      | _        y )Ngư>eps)rP   rQ   rM   multi_head_attentionr   ffnr   	LayerNorm
layernorm1
layernorm2Dropoutdropout1dropout2)r\   r   rR   r   rater]   s        r   rQ   zEncoderLayer.__init__   so    $6|Y$O!2<E,,|>,,|>

4(

4(r!   c                    | j                  |      }| j                  |||||||||	      }	|	d   }
| j                  |
      }
||
z   }| j                  |      }| j	                  |      }| j                  |      }||z   }|f|	dd  z   }|S )Nrm   rB   rC   rn   ro   r   r   )r   r   r   r   r   r   )r\   rh   rA   rm   rB   rC   rn   ro   normedattn_outputsattn_outputout1out2
ffn_outputrv   s                  r   rw   zEncoderLayer.forward   s     #00!)/ 1 

 #1ommK0;t$XXd^
]]:.
j 'L,,r!   )g?rx   )ry   rz   r{   rQ   rw   r|   r}   s   @r   r   r      s    
) qvr!   r   c                       e Zd ZeZdZd Zy)CTRLPreTrainedModeltransformerc                    t        |t        j                  t        f      rm|j                  j
                  j                  d| j                  j                         |j                  %|j                  j
                  j                          yyt        |t        j                        rz|j                  j
                  j                  d| j                  j                         |j                  2|j                  j
                  |j                     j                          yyt        |t        j                        rJ|j                  j
                  j                          |j                  j
                  j                  d       yy)zInitialize the weights.g        )meanstdN      ?)
isinstancer   rU   r   weightdatanormal_configinitializer_rangebiaszero_	Embeddingpadding_idxr   fill_)r\   modules     r   _init_weightsz!CTRLPreTrainedModel._init_weights   s   fryy&12 MM&&CT[[5R5R&S{{&  &&( '-MM&&CT[[5R5R&S!!-""6#5#56<<> .-KK""$MM$$S) .r!   N)ry   rz   r{   r   config_classbase_model_prefixr    r!   r   r   r      s    L%*r!   r   c                       e Zd Z fdZd Zd Zd Ze	 	 	 	 	 	 	 	 	 	 	 ddee	j                     deeee	j                           dee	j                     dee	j                     d	ee	j                     d
ee	j                     dee	j                     dee   dee   dee   dee   deee	j                     ef   fd       Z xZS )	CTRLModelc                    t         |   |       |j                  | _        |j                  | _        t        |j                  | j                  t        j                        | _
        t        j                  |j                  |j                        | _        t        j                  |j                         | _        t        j$                  t'        |j                        D cg c]8  }t)        |j                  |j*                  |j,                  |j.                        : c}      | _        t        j2                  |j                  |j4                        | _        | j9                          y c c}w )Nr   )rP   rQ   n_embdr   n_layer
num_layersr4   n_positionsr   floatr3   r   r   
vocab_sizewr   
embd_pdropdropout
ModuleListranger   n_headr   resid_pdrophr   layer_norm_epsilon	layernorm	post_init)r\   r   _r]   s      r   rQ   zCTRLModel.__init__   s     "MM ../0B0BDDUDUW\WbWbcf//?zz&"3"34afgmguguavw\]\&--

FDVDVWw
 fmm9R9RS 	 xs    =E*c                     | j                   S rO   r   r\   s    r   get_input_embeddingszCTRLModel.get_input_embeddings   s    vvr!   c                     || _         y rO   r   r\   new_embeddingss     r   set_input_embeddingszCTRLModel.set_input_embeddings   s	    r!   c                     |j                         D ]-  \  }}| j                  |   j                  j                  |       / y)zv
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        N)itemsr   r   rd   )r\   heads_to_prunelayerra   s       r   _prune_headszCTRLModel._prune_heads  s>     +002 	BLE5FF5M..::5A	Br!   	input_idspast_key_valuesrB   token_type_idsposition_idsrC   inputs_embedsrn   ro   output_hidden_statesreturn_dictreturnc           
         |	|	n| j                   j                  }	||n| j                   j                  }|
|
n| j                   j                  }
||n| j                   j                  }||t        d      |G| j                  ||       |j                         }|j                  d|d         }|j                  d   }n0|#|j                         dd }|j                  d   }nt        d      ||j                  n|j                  }|%d}t        dgt        | j                        z        }n|d   d   j                  d      }|>t        j                  ||d   |z   t        j                   |      }|j#                  d      }||dk  rt        d      |j                  |d      }|j#                  d	      j#                  d
      }|j%                  | j&                        }d|z
  t        j(                  | j&                        j*                  z  }| j-                  || j                   j.                        }|I|j                  d|d         }| j1                  |      }|t3        j4                  | j6                        z  }nd}|| j1                  |      }|d   }t        j8                  t        j:                  ||z   ||z         d	      j%                  |      }|t3        j4                  | j6                        z  }| j<                  j%                  |      | _        | j<                  |ddf   }||z   |z   }| j?                  |      }|rdnd}|
rdnd}|	rdnd}tA        tC        | j                  |            D ]@  \  }\  }}|
r||fz   } |||||||   ||	      }|dd
 \  }}|du r||fz   }|	s8||d
   fz  }B | jE                  |      }|
r||fz   }|st        d ||||fD              S tG        ||||      S )aE  
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CTRLModel
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
        >>> model = CTRLModel.from_pretrained("Salesforce/ctrl")

        >>> # CTRL was trained with control codes as the first token
        >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
        >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()

        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 5, 1280]
        ```NzDYou cannot specify both input_ids and inputs_embeds at the same timer%   r   z5You have to specify either input_ids or inputs_embedsr6   )r$   devicez$batch_size has to be defined and > 0r   r   r#   r   r   r   Tc              3   &   K   | ]	  }||  y wrO   r   ).0r@   s     r   	<genexpr>z$CTRLModel.forward.<locals>.<genexpr>  s     rqdedqrs   )last_hidden_stater   hidden_states
attentions)$r   ro   rn   r   use_return_dict
ValueError%warn_if_padding_and_no_attention_maskr<   viewr9   r   tupler_   r   r   r(   longr+   r*   r$   finfominget_head_maskr   r   r:   r;   r   triuonesr3   r   	enumeratezipr   r   )r\   r   r   rB   r   r   rC   r   rn   ro   r   r   kwargsinput_shaperi   r   past_lengthtoken_type_embedsseq_lenrA   
pos_embedsr   presentsall_hidden_statesall_attentionsr   r   rm   rv   rr   s                                 r   rw   zCTRLModel.forward	  s9   ` 2C1N-TXT_T_TqTq!*!6IDKK<Q<Q	$8$D $++JjJj 	 &1%<k$++B]B] ]%>cdd"66y.Q#..*K!r;r?;I"+J&',,.s3K&,,Q/JTUU%.%:!!@T@T"K#TFS[$89O)!,Q/44R8K <<[_{5RZ_ZdZdmstL'11!4L %Q !GHH+00R@N ,55a8BB1EN ,..TZZ.@N!N2ekk$**6M6Q6QQN &&y$++2E2EF	%+00[_EN $~ 6):):!;; !  FF9-Mb/zz%**W{%:Gk<QRTUVYYZ`a!2!233 !--008&&|Q7
%
25FF]3"2"6BD0d"+C,H"I 	0A:#$58H$H!%-#A,#"3G &-Ra["M7D #wj0 71:-/#	0& }5 1]4D Dr]H>OQ_$`rrr&+$+%	
 	
r!   )NNNNNNNNNNN)ry   rz   r{   rQ   r   r   r   r   r   r   
LongTensorr   FloatTensorboolr   Tensorr   rw   r|   r}   s   @r   r   r      sN   & B  15EI6:59371559$(,0/3&*^
E,,-^
 "%e.?.?(@"AB^
 !!2!23	^

 !!1!12^
 u//0^
 E--.^
   1 12^
 D>^
 $D>^
 'tn^
 d^^
 
uU\\"$;;	<^
 ^
r!   r   z
    The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    )custom_introc                   "    e Zd ZdgZ fdZd Zd Ze	 	 	 	 	 	 	 	 	 	 	 	 ddee	j                     deeee	j                           dee	j                     dee	j                     d	ee	j                     d
ee	j                     dee	j                     dee	j                     dee   dee   dee   dee   deee	j                     ef   fd       ZddZedeee	j                        de	j                  deee	j                        fd       Z xZS )CTRLLMHeadModelzlm_head.weightc                     t         |   |       t        |      | _        t	        j
                  |j                  |j                  d      | _        | j                          y )NTr   )
rP   rQ   r   r   r   rU   r   r   lm_headr   r\   r   r]   s     r   rQ   zCTRLLMHeadModel.__init__  sG     $V,yy0A0AM 	r!   c                     | j                   S rO   r  r   s    r   get_output_embeddingsz%CTRLLMHeadModel.get_output_embeddings  s    ||r!   c                     || _         y rO   r  r   s     r   set_output_embeddingsz%CTRLLMHeadModel.set_output_embeddings  s	    %r!   r   r   rB   r   r   rC   r   labelsrn   ro   r   r   r   c                 |   ||n| j                   j                  }| j                  ||||||||	|
||      }|d   }| j                  |      }d}|* | j                  ||fd| j                   j
                  i|}|s|f|dd z   }||f|z   S |S t        |||j                  |j                  |j                        S )a
  
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, CTRLLMHeadModel

        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
        >>> model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")

        >>> # CTRL was trained with control codes as the first token
        >>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
        >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()

        >>> sequence_ids = model.generate(inputs["input_ids"])
        >>> sequences = tokenizer.batch_decode(sequence_ids)
        >>> sequences
        ['Wikipedia The llama is a member of the family Bovidae. It is native to the Andes of Peru,']

        >>> outputs = model(**inputs, labels=inputs["input_ids"])
        >>> round(outputs.loss.item(), 2)
        9.21

        >>> list(outputs.logits.shape)
        [1, 5, 246534]
        ```N
r   rB   r   r   rC   r   rn   ro   r   r   r   r   r   )losslogitsr   r   r   )
r   r   r   r  loss_functionr   r   r   r   r   )r\   r   r   rB   r   r   rC   r   r  rn   ro   r   r   r   transformer_outputsr   	lm_logitsr  rJ   s                      r   rw   zCTRLLMHeadModel.forward  s   v &1%<k$++B]B]"..+))%'/!5# / 
 ,A.LL/	%4%%  ;;11 	D \$7$;;F)-)9TGf$EvE%/??-;;*55
 	
r!   c                     |G|d   d   j                   d   }|j                   d   |kD  r|}n|j                   d   dz
  }|d d |d f   }|||dS )Nr   r   r   )r   r   rn   )r9   )r\   r   r   rn   r   r   remove_prefix_lengths          r   prepare_inputs_for_generationz-CTRLLMHeadModel.prepare_inputs_for_generation&  st     &)!,Q/55a8K q!K/'2$ (1q'9A'=$!!%9%:":;I&?Ybccr!   beam_idxc                 ,    t        fd| D              S )a  
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        c              3   F   K   | ]  }t        fd |D                yw)c              3   t   K   | ]/  }|j                  d j                  |j                               1 yw)r   N)index_selectr*   r   )r   
past_stater  s     r   r   z;CTRLLMHeadModel._reorder_cache.<locals>.<genexpr>.<genexpr>B  s.     jQ[*))!X[[9J9J-KLjs   58Nr   )r   rm   r  s     r   r   z1CTRLLMHeadModel._reorder_cache.<locals>.<genexpr>A  s%      
 j_ijj
s   !r  )r   r  s    `r   _reorder_cachezCTRLLMHeadModel._reorder_cache8  s      
-
 
 	
r!   NNNNNNNNNNNNNN)ry   rz   r{   _tied_weights_keysrQ   r  r  r   r   r   r   r   r   r   r   r   r   rw   r  staticmethodr  r|   r}   s   @r   r   r     s    ++&  15EI6:59371559-1$(,0/3&*a
E,,-a
 "%e.?.?(@"ABa
 !!2!23	a

 !!1!12a
 u//0a
 E--.a
   1 12a
 ))*a
 D>a
 $D>a
 'tna
 d^a
 
uU\\"$::	;a
 a
Fd$ 
uU\\23
?D||
	uU\\"	#
 
r!   r   a  
    The CTRL Model transformer with a sequence classification head on top (linear layer).
    [`CTRLForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last
    token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in
    each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
    guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last
    value in each row of the batch).
    c                       e Zd Z fdZe	 	 	 	 	 	 	 	 	 	 	 	 ddeej                     deeeej                           deej                     deej                     deej                     deej                     deej                     d	eej                     d
ee
   dee
   dee
   dee
   deeej                     ef   fd       Z xZS )CTRLForSequenceClassificationc                     t         |   |       |j                  | _        t        |      | _        t        j                  |j                  | j                  d      | _        | j                          y )NFr   )
rP   rQ   
num_labelsr   r   r   rU   r   
classifierr   r  s     r   rQ   z&CTRLForSequenceClassification.__init__S  sR      ++$V,))FMM4??O 	r!   r   r   rB   r   r   rC   r   r  rn   ro   r   r   r   c                    ||n| j                   j                  }| j                  ||||||||	|
||      }|d   }| j                  |      }||j                  dd \  }}n|j                  dd \  }}| j                   j
                  |dk7  rt        d      | j                   j
                  d}n||| j                   j
                  k7  j                  |j                  t        j                        }t        j                  |j                  d   |j                  t        j                        }||z  j                  d      }n.d}t        j                  | j                  j                    d	       |t        j                  ||j                  
      |f   }d}|| j                   j"                  | j$                  dk(  rd| j                   _        nl| j$                  dkD  rL|j&                  t        j(                  k(  s|j&                  t        j*                  k(  rd| j                   _        nd| j                   _        | j                   j"                  dk(  rIt-               }| j$                  dk(  r& ||j/                         |j/                               }n |||      }n| j                   j"                  dk(  r=t1               } ||j3                  d| j$                        |j3                  d            }n,| j                   j"                  dk(  rt5               } |||      }|s|f|dd z   }||f|z   S |S t7        |||j8                  |j:                        S )a2  
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Example of single-label classification:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, CTRLForSequenceClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
        >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl")

        >>> # CTRL was trained with control codes as the first token
        >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
        >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()

        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> predicted_class_id = logits.argmax().item()
        >>> model.config.id2label[predicted_class_id]
        'LABEL_0'
        ```

        ```python
        >>> import torch

        >>> torch.manual_seed(42)  # doctest: +IGNORE_RESULT
        >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
        >>> num_labels = len(model.config.id2label)
        >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)

        >>> labels = torch.tensor(1)
        >>> loss = model(**inputs, labels=labels).loss
        >>> round(loss.item(), 2)
        0.93
        ```

        Example of multi-label classification:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, CTRLForSequenceClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
        >>> model = CTRLForSequenceClassification.from_pretrained(
        ...     "Salesforce/ctrl", problem_type="multi_label_classification"
        ... )

        >>> # CTRL was trained with control codes as the first token
        >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
        >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()

        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> predicted_class_id = logits.argmax().item()
        >>> model.config.id2label[predicted_class_id]
        'LABEL_0'
        ```

        ```python
        >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
        >>> num_labels = len(model.config.id2label)
        >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)

        >>> num_labels = len(model.config.id2label)
        >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
        ...     torch.float
        ... )
        >>> loss = model(**inputs, labels=labels).loss
        >>> loss.backward()  # doctest: +IGNORE_RESULT
        ```Nr
  r   r   r   z=Cannot handle batch sizes > 1 if no padding token is defined.r%   )r   r$   z will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`)r   
regressionsingle_label_classificationmulti_label_classification)r  r  r   r   )r   r   r   r#  r9   pad_token_idr   r*   r   r   int32r(   argmaxloggerwarning_oncer]   ry   problem_typer"  r$   r   rS   r	   squeezer   r   r   r   r   r   )r\   r   r   rB   r   r   rC   r   r  rn   ro   r   r   r  r   r  ri   sequence_lengthlast_non_pad_tokennon_pad_masktoken_indicespooled_logitsr  loss_fctrJ   s                            r   rw   z%CTRLForSequenceClassification.forward\  s   P &1%<k$++B]B]"..+))%'/!5# / 
 ,A./ *3//"1*='J*7*=*=bq*A'J;;##+
a\]];;##+!#"%)A)AAEEfmmUZU`U`aL!LL)<V]]Z_ZeZefM"/,">!F!Fr!J!#>>**+ ,Z Z
 u||Jv}}MOaab{{''/??a'/;DKK,__q(fllejj.HFLL\a\e\eLe/LDKK,/KDKK,{{''<7"9??a'#M$9$9$;V^^=MND#M6:D))-JJ+- 2 22t GUWY))-II,.v6#%(;AB(??F)-)9TGf$EvE' -;;*55	
 	
r!   r  )ry   rz   r{   rQ   r   r   r   r   r   r   r   r   r   r   rw   r|   r}   s   @r   r   r   G  sW     15EI6:59371559-1$(,0/3&*p
E,,-p
 "%e.?.?(@"ABp
 !!2!23	p

 !!1!12p
 u//0p
 E--.p
   1 12p
 ))*p
 D>p
 $D>p
 'tnp
 d^p
 
uU\\"$<<	=p
 p
r!   r   )r   r   r   r   r  )-__doc__typingr   r   r   numpyr:   r   r   torch.nnr   r   r	   
generationr   modeling_outputsr   r   r   modeling_utilsr   pytorch_utilsr   r   r   utilsr   r   configuration_ctrlr   
get_loggerry   r+  r    r4   rK   ModulerM   r   r   r   r   r   r   __all__r   r!   r   <module>rB     s!     ) )    A A ) i i - Y Y + 
		H	%
%4L L^`&299 &R */ * ** @
# @
 @
F S
)? S
S
l 
{
$7 {

{
| cr!   