Self-Attention Mechanism and LoRA (Low-Rank Adaptation)



What is LoRA?
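Before I cover LoRA itself, I want to discuss the problem of training Large Language Models (LLMs).
As you know, LLMs are extremely big: they can contain billions of parameters and have many possible uses. Fine-tuning all of those parameters for a new task needs a lot of GPU memory and compute, and produces a full-size copy of the model for every task.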

 

LoRA is a parameter-efficient fine-tuning method. Instead of fine-tuning all the weights of a large model, LoRA:

  • Freezes the original model weights (keeps them unchanged).

  • Adds small trainable low-rank matrices to specific parts of the model (usually the attention layers).

  • During training, only these small matrices are updated.

 

How it works...

In a previous article, I briefly discussed attention layers. Since LoRA modifies these layers, you need to understand what an attention layer is before you can understand how LoRA works.

An attention layer is a core part of transformer models, such as GPT, BERT, LLaMA, etc.

It allows the model to focus on different parts of the input when generating output, like how we focus on certain words when reading a sentence.

Inside the attention layer, the model computes three key components:

  • Query (Q)

  • Key (K)

  • Value (V)

Each of these is generated by multiplying the input with a weight matrix (\( W^Q \), \( W^K \), \( W^V \)); these weight matrices are large, trainable parameters.

So the attention layer is built around dense linear projections that heavily influence how the model processes information. They are important, and they are expensive to train.
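To put a rough number on "expensive", here is a small sketch that counts the weights in the four attention projections (\( W^Q \), \( W^K \), \( W^V \), \( W^O \)). The sizes are an assumption chosen to match GPT-2 small (12 layers, hidden size 768), bias terms are ignored, and the helper name is hypothetical; it only does the arithmetic.

```python
# Hypothetical helper: counts the weights in the four attention projections
# (W^Q, W^K, W^V, W^O) of one layer, ignoring biases.
def attention_projection_weights(d_model: int) -> int:
    # Each projection maps d_model -> d_model, i.e. a d_model x d_model matrix.
    return 4 * d_model * d_model

# Assumed configuration, matching GPT-2 small: 12 layers, d_model = 768.
n_layers, d_model = 12, 768
per_layer = attention_projection_weights(d_model)
total = n_layers * per_layer
print(f"Per layer: {per_layer:,} weights")
print(f"All {n_layers} layers: {total:,} weights")
```

GPT-2 actually fuses \( W^Q \), \( W^K \) and \( W^V \) into a single 768 × 2304 matrix, which holds the same number of weights as three separate 768 × 768 matrices.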

Self-Attention Step-by-Step (Matrix Version)

1. Input Features (X)

This is your input; each column vector (\( x_1, x_2, x_3 \)) can be thought of as a feature vector or an embedding of a particular input element (e.g. a word in a sentence or a patch in an image).

We have 3 input elements, each with a 4-dimensional feature vector.

$$ \begin{array}{r@{\quad}c} & \begin{array}{ccc} x_1 & x_2 & x_3 \\ \Uparrow & \Uparrow & \Uparrow \end{array} \\[0.5em] X = & \begin{bmatrix} 2 & 0 & 0 \\ 0 & 1 & 0 \\ 2 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix} \end{array} $$

If the matrix X above is the input for an attention mechanism, it provides the raw feature representations for:

  • 3 individual input elements (if columns are elements), or
  • 4 input elements (if rows are elements).

Be careful about which convention you are using, as you may need to transpose your matrix!

    How do you transpose a matrix? The transpose of a matrix is obtained by swapping its rows and columns, so that row i of the original becomes column i of the transpose:

    $$ (X^T)_{ij} = X_{ji} $$

    Here is your matrix X, a 4 × 3 matrix (4 rows and 3 columns). When X is transposed, we get a 3 × 4 matrix (3 rows and 4 columns).

    $$ X = \begin{bmatrix} 2 & 0 & 0 \\ 0 & 1 & 0 \\ 2 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix} $$

    $$ X^T = \begin{bmatrix} 2 & 0 & 2 & 0 \\ 0 & 1 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} $$
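If you want to check a transpose quickly, the sketch below does it in plain Python (no libraries assumed). zip(*m) walks the matrix column by column, so each column of X becomes a row of \( X^T \).

```python
# The 4 x 3 input matrix X from step 1, stored as a list of rows.
X = [
    [2, 0, 0],
    [0, 1, 0],
    [2, 1, 0],
    [0, 0, 1],
]

def transpose(m):
    """Swap rows and columns: element (i, j) moves to position (j, i)."""
    return [list(col) for col in zip(*m)]

X_T = transpose(X)
print(f"X is {len(X)} x {len(X[0])}, X^T is {len(X_T)} x {len(X_T[0])}")
for row in X_T:
    print(row)
```

Transposing twice gives back the original matrix, which is a handy sanity check.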

    These representations will be transformed into Queries (Q), Keys (K), and Values (V).

    2. Create Query, Key, Value Matrices (Linear Projections)

    Important: Each column in the matrix is treated as a separate element. We use the X matrix in its original form, without transposing it.

    At the very beginning of training an attention-based model (like a Transformer), the \( W^Q \), \( W^K \) and \( W^V \) matrices are initialised with random values (e.g. using Glorot/Xavier or Kaiming initialisation).

    Weight matrices (simple illustrative values standing in for a random initialisation):

    $$ W^Q = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.5 & 0.6 & 0.7 & 0.8 \\ 0.9 & 1.0 & 1.1 & 1.2 \\ 1.3 & 1.4 & 1.5 & 1.6 \end{bmatrix} $$ $$ W^K = \begin{bmatrix} 0.2 & 0.1 & 0.4 & 0.3 \\ 0.6 & 0.5 & 0.8 & 0.7 \\ 1.0 & 0.9 & 1.2 & 1.1 \\ 1.4 & 1.3 & 1.6 & 1.5 \end{bmatrix} $$ $$ W^V = \begin{bmatrix} 0.3 & 0.4 & 0.1 & 0.2 \\ 0.7 & 0.8 & 0.5 & 0.6 \\ 1.1 & 1.2 & 0.9 & 1.0 \\ 1.5 & 1.6 & 1.3 & 1.4 \end{bmatrix} $$
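    • Each weight matrix is multiplied with the input X. Because the elements of X are its columns, the weight matrix goes on the left; \( X W^Q \) would not even be defined here, since a 4 × 3 matrix cannot be multiplied by a 4 × 4 one. (When elements are stored as rows, as in most Transformer code, the same projection is written \( Q = X W^Q \).)
    $$ Q = W^Q X, \quad K = W^K X, \quad V = W^V X $$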

    This gives us one query, key and value vector per input element (the columns of Q, K and V):

    • $$ Q = [q_1, q_2, q_3] $$
    • $$ K = [k_1, k_2, k_3] $$
    • $$ V = [v_1, v_2, v_3] $$

    Since this involves several mathematical steps, I'll walk through the calculation for Q only; K and V are computed in exactly the same way. I recommend following along with pen and paper for clarity!

    $Q = W^Q \cdot X = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.5 & 0.6 & 0.7 & 0.8 \\ 0.9 & 1.0 & 1.1 & 1.2 \\ 1.3 & 1.4 & 1.5 & 1.6 \end{bmatrix} \times \begin{bmatrix} 2 & 0 & 0 \\ 0 & 1 & 0 \\ 2 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix}$

    $Q = \begin{bmatrix} 0.8 & 0.5 & 0.4 \\ 2.4 & 1.3 & 0.8 \\ 4.0 & 2.1 & 1.2 \\ 5.6 & 2.9 & 1.6 \end{bmatrix}$

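    $Q = \begin{bmatrix} Q_{11} & Q_{12} & Q_{13} \\ Q_{21} & Q_{22} & Q_{23} \\ Q_{31} & Q_{32} & Q_{33} \\ Q_{41} & Q_{42} & Q_{43} \end{bmatrix}$

    Each element \( Q_{ij} \) is the dot product of row \( i \) of \( W^Q \) with column \( j \) of \( X \):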

    • Element $Q_{11}$ :$(0.1 \cdot 2) + (0.2 \cdot 0) + (0.3 \cdot 2) + (0.4 \cdot 0) = 0.2 + 0 + 0.6 + 0 = 0.8$
    • Element $Q_{12}$ :$(0.1 \cdot 0) + (0.2 \cdot 1) + (0.3 \cdot 1) + (0.4 \cdot 0) = 0 + 0.2 + 0.3 + 0 = 0.5$
    • Element $Q_{13}$ : $(0.1 \cdot 0) + (0.2 \cdot 0) + (0.3 \cdot 0) + (0.4 \cdot 1) = 0 + 0 + 0 + 0.4 = 0.4$
    • Element $Q_{21}$ : $(0.5 \cdot 2) + (0.6 \cdot 0) + (0.7 \cdot 2) + (0.8 \cdot 0) = 1.0 + 0 + 1.4 + 0 = 2.4$
    • Element $Q_{22}$ : $(0.5 \cdot 0) + (0.6 \cdot 1) + (0.7 \cdot 1) + (0.8 \cdot 0) = 0 + 0.6 + 0.7 + 0 = 1.3$
    • Element $Q_{23}$ : $(0.5 \cdot 0) + (0.6 \cdot 0) + (0.7 \cdot 0) + (0.8 \cdot 1) = 0 + 0 + 0 + 0.8 = 0.8$
    • Element $Q_{31}$ : $(0.9 \cdot 2) + (1.0 \cdot 0) + (1.1 \cdot 2) + (1.2 \cdot 0) = 1.8 + 0 + 2.2 + 0 = 4.0$
    • Element $Q_{32}$ : $(0.9 \cdot 0) + (1.0 \cdot 1) + (1.1 \cdot 1) + (1.2 \cdot 0) = 0 + 1.0 + 1.1 + 0 = 2.1$
    • Element $Q_{33}$ : $(0.9 \cdot 0) + (1.0 \cdot 0) + (1.1 \cdot 0) + (1.2 \cdot 1) = 0 + 0 + 0 + 1.2 = 1.2$
    • Element $Q_{41}$ : $(1.3 \cdot 2) + (1.4 \cdot 0) + (1.5 \cdot 2) + (1.6 \cdot 0) = 2.6 + 0 + 3.0 + 0 = 5.6$
    • Element $Q_{42}$ : $(1.3 \cdot 0) + (1.4 \cdot 1) + (1.5 \cdot 1) + (1.6 \cdot 0) = 0 + 1.4 + 1.5 + 0 = 2.9$
    • Element $Q_{43}$ : $(1.3 \cdot 0) + (1.4 \cdot 0) + (1.5 \cdot 0) + (1.6 \cdot 1) = 0 + 0 + 0 + 1.6 = 1.6$

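    The columns of Q are the query vectors, one per input element:

    $$ Q = [q_1, q_2, q_3] $$

    $$ q_1 = [0.8, 2.4, 4.0, 5.6] $$ $$ q_2 = [0.5, 1.3, 2.1, 2.9] $$ $$ q_3 = [0.4, 0.8, 1.2, 1.6] $$

    Computing \( K = W^K X \) and \( V = W^V X \) in the same way gives: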

    $K = \begin{bmatrix} 1.2 & 0.5 & 0.3 \\ 2.8 & 1.3 & 0.7 \\ 4.4 & 2.1 & 1.1 \\ 6.0 & 2.9 & 1.5 \end{bmatrix}$

    $V = \begin{bmatrix} 0.8 & 0.5 & 0.2 \\ 2.4 & 1.3 & 0.6 \\ 4.0 & 2.1 & 1.0 \\ 5.6 & 2.9 & 1.4 \end{bmatrix}$
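The same products can be checked in a few lines of plain Python (no NumPy assumed). This sketch follows the column convention of step 2, so each weight matrix multiplies X from the left; the matmul helper is written here for illustration and raises an error when the inner dimensions do not match.

```python
# Input X (4 x 3): each column is one input element.
X = [[2, 0, 0],
     [0, 1, 0],
     [2, 1, 0],
     [0, 0, 1]]

# Illustrative weight matrices from step 2 (4 x 4 each).
W_Q = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]]
W_K = [[0.2, 0.1, 0.4, 0.3], [0.6, 0.5, 0.8, 0.7], [1.0, 0.9, 1.2, 1.1], [1.4, 1.3, 1.6, 1.5]]
W_V = [[0.3, 0.4, 0.1, 0.2], [0.7, 0.8, 0.5, 0.6], [1.1, 1.2, 0.9, 1.0], [1.5, 1.6, 1.3, 1.4]]

def matmul(a, b):
    """Multiply an (n x m) matrix by an (m x p) matrix, both given as lists of rows."""
    if len(a[0]) != len(b):
        raise ValueError(f"inner dimensions differ: {len(a[0])} vs {len(b)}")
    # Element (i, j) is the dot product of row i of `a` with column j of `b`.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

# Column convention: the weight matrix goes on the left.
Q = matmul(W_Q, X)
K = matmul(W_K, X)
V = matmul(W_V, X)

for name, m in (("Q", Q), ("K", K), ("V", V)):
    print(name, [[round(v, 2) for v in row] for row in m])
```

Calling matmul(X, W_Q) instead raises a ValueError, because a 4 × 3 matrix cannot be multiplied by a 4 × 4 one.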

    3. Attention Scores (Dot Product)

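    For a specific query vector (say \( q_i \) for the \( i \)-th element), an attention score is calculated by taking the dot product between \( q_i \) and every key vector \( k_j \) in the sequence. The dot product measures the similarity or compatibility between the query and each key; a higher dot product indicates a stronger relationship or relevance.

    We compute dot products between queries and keys to measure attention:

    $$ \text{score}(q_i, k_j) = q_i \cdot k_j^T $$

    All these scores go into a matrix:

    $$ \text{Attention Scores Matrix} = Q K^T $$

    Note that from this step on, the example reads each row of Q, K and V as one element. This is the convention most Transformer implementations use (tokens as rows, \( Q = X W^Q \)), so the 4 × 3 matrices above stand for 4 elements with 3-dimensional vectors, and \( Q K^T \) is a 4 × 4 matrix. Under the column convention of step 2, the equivalent scores would instead be \( Q^T K \), a 3 × 3 matrix with \( d_k = 4 \).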

    $K^T = \begin{bmatrix} 1.2 & 2.8 & 4.4 & 6.0 \\ 0.5 & 1.3 & 2.1 & 2.9 \\ 0.3 & 0.7 & 1.1 & 1.5 \end{bmatrix}$

    $\text{Attention Scores Matrix} = Q \cdot K^T = \begin{bmatrix} 1.33 & 3.17 & 5.01 & 6.85 \\ 3.77 & 8.97 & 14.17 & 19.37 \\ 6.21 & 14.77 & 23.33 & 31.89 \\ 8.65 & 20.57 & 32.49 & 44.41 \end{bmatrix}$

    4. Scale

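    Divide each score by \( \sqrt{d_k} \), where \( d_k \) is the dimension of the key vectors. Without this step, dot products grow with \( d_k \), and very large scores push the softmax towards a one-hot output with tiny gradients.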

    This matrix K has 4 rows (representing 4 tokens/elements in the sequence) and 3 columns.

    Each column represents a dimension of the key vectors. Therefore, the dimension of the key vectors, dk, is 3.

    \[ \frac{Q \cdot K^T}{\sqrt{d_k}} \]

    Here \( d_k = 3 \), so the scaling factor is \( \sqrt{3} \approx 1.732 \).

    The Scaled Attention Scores Matrix (after dividing by \( \sqrt{d_k} \approx 1.732 \)) is:

    $\text{Scaled Attention Scores Matrix} = \begin{bmatrix} 0.7679 & 1.8303 & 2.8926 & 3.9550 \\ 2.1767 & 5.1790 & 8.1813 & 11.1836 \\ 3.5855 & 8.5277 & 13.4700 & 18.4122 \\ 4.9942 & 11.8764 & 18.7587 & 25.6409 \end{bmatrix}$
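Here is a plain-Python sketch of steps 3 and 4, reading Q and K row by row as above. It divides by the exact \( \sqrt{3} \) rather than 1.732, so a few entries differ from the table in the fourth decimal place (for example about 3.9548 instead of 3.9550).

```python
import math

# Q and K from step 2, read row by row: 4 elements with 3 dimensions each.
Q = [[0.8, 0.5, 0.4], [2.4, 1.3, 0.8], [4.0, 2.1, 1.2], [5.6, 2.9, 1.6]]
K = [[1.2, 0.5, 0.3], [2.8, 1.3, 0.7], [4.4, 2.1, 1.1], [6.0, 2.9, 1.5]]

d_k = len(K[0])  # dimension of each key vector (3 here)

# scores[i][j] = q_i . k_j, i.e. the matrix Q K^T.
scores = [[sum(q * k for q, k in zip(q_row, k_row)) for k_row in K] for q_row in Q]

# Divide every score by sqrt(d_k).
scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]

for row in scaled:
    print([round(s, 4) for s in row])
```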

    5. Softmax

    In the context of attention mechanisms (specifically "Scaled Dot-Product Attention"), after computing the raw scores \( Q \cdot K^T \) and scaling them by \( \sqrt{d_k} \), these values are still arbitrary real numbers.

    They could be negative, positive, very large, or very small.

    Applying softmax to these scaled scores does the following:

    • Converts to Probabilities/Weights: It turns the similarity scores into a set of weights where higher original scores result in higher weights, and these weights represent how much focus (or "attention") the model should put on each corresponding Value vector.
    • Normalisation: It ensures that for any given query, the sum of attention weights across all Key-Value pairs is exactly 1. This means the attention is distributed like a probability, clearly showing the relative importance of each part of the input sequence.

    Apply softmax across each row to normalise scores into probabilities:

    $$ \text{Attention Weights} = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) $$

    Each row now shows how strongly one token attends to every token in the sequence (including itself).

    Applying Softmax to the First Row:

    The first row of the Scaled Attention Scores Matrix is:

    \[ \mathbf{z} = [0.7679, 1.8303, 2.8926, 3.9550] \]

    Step 1: Calculate the exponential of each element ($e^{z_i}$)

    • For $z_1 = 0.7679$: \( e^{0.7679} \approx 2.155 \)
    • For $z_2 = 1.8303$: \( e^{1.8303} \approx 6.236 \)
    • For $z_3 = 2.8926$: \( e^{2.8926} \approx 18.040 \)
    • For $z_4 = 3.9550$: \( e^{3.9550} \approx 52.196 \)

    Step 2: Calculate the sum of all exponentials ($\sum e^{z_j}$) for this row

    Sum \( \approx 2.155 + 6.236 + 18.040 + 52.196 \approx 78.627 \)

    Step 3: Divide each exponential by the sum of exponentials

    • For \( \text{softmax}(z_1) \): \( \frac{2.155}{78.627} \approx 0.0274 \)
    • For \( \text{softmax}(z_2) \): \( \frac{6.236}{78.627} \approx 0.0793 \)
    • For \( \text{softmax}(z_3) \): \( \frac{18.040}{78.627} \approx 0.2294 \)
    • For \( \text{softmax}(z_4) \): \( \frac{52.196}{78.627} \approx 0.6638 \)

    Result for the First Row

    The first row of the Attention Weights matrix is approximately:

    \[ [0.0274, 0.0793, 0.2294, 0.6638] \]

    These values sum to 0.9999, which is 1 up to rounding.
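The three softmax steps fit into one small function. This is a plain-Python sketch; it subtracts the row maximum before exponentiating, which leaves the result unchanged (the common factor cancels in the division) but keeps exp() from overflowing when scores are large.

```python
import math

def softmax(z):
    """Turn a list of real-valued scores into positive weights that sum to 1."""
    m = max(z)  # shifting by the max does not change the result, but avoids overflow
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

# First row of the scaled attention scores, as used in the hand calculation.
z = [0.7679, 1.8303, 2.8926, 3.9550]
weights = softmax(z)
print([round(w, 4) for w in weights])
print(sum(weights))
```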

    6. Weighted Sum of Values

    Multiply the attention weights by the value vectors V:

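    $$ Z = \text{Attention Weights} \cdot V $$

    This gives new representations \( z_1, z_2, z_3, z_4 \) (the rows of \( Z \)), which are the final attention outputs. Each one is a weighted average of the value vectors, using that token's attention weights.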
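Putting steps 3 to 6 together, the following sketch computes \( \text{softmax}(Q K^T / \sqrt{d_k}) V \) for the whole example, with one element per row. The function names are illustrative rather than taken from any library; framework implementations perform the same computation with batched tensor operations.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention with one element per row: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    weights = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights.append(softmax(scores))
    # Each output row is a weighted sum of the rows of V.
    Z = [[sum(w * v_row[c] for w, v_row in zip(w_row, V)) for c in range(len(V[0]))]
         for w_row in weights]
    return weights, Z

Q = [[0.8, 0.5, 0.4], [2.4, 1.3, 0.8], [4.0, 2.1, 1.2], [5.6, 2.9, 1.6]]
K = [[1.2, 0.5, 0.3], [2.8, 1.3, 0.7], [4.4, 2.1, 1.1], [6.0, 2.9, 1.5]]
V = [[0.8, 0.5, 0.2], [2.4, 1.3, 0.6], [4.0, 2.1, 1.0], [5.6, 2.9, 1.4]]

weights, Z = attention(Q, K, V)
for row in Z:
    print([round(x, 4) for x in row])
```

Because every row of weights sums to 1, each output row lies between the smallest and largest value in each column of V.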

    LoRA (Low-Rank Adaptation)

    LoRA (Low-Rank Adaptation) doesn't change the fundamental mathematical operations of the attention mechanism itself (like dot product, scaling, or softmax). Instead, it modifies how the weight matrices within the attention mechanism's linear projection layers are adapted during fine-tuning.

    In a Transformer's attention block, Query (Q), Key (K), and Value (V) matrices are derived from the input (\(X\)) by multiplying it with specific weight matrices:

    • \( Q = X W^Q \)
    • \( K = X W^K \)
    • \( V = X W^V \)

    Where \( W^Q \), \( W^K \), and \( W^V \) are the projection weight matrices for Queries, Keys, and Values, respectively. There's also typically an output projection matrix \( W^O \).

    The Math of LoRA Applied to Attention Projections:

    Let's consider one of these weight matrices, for example, the Query projection matrix \( W^Q \).

    • Let the original pre-trained weight matrix be \( W^Q_0 \in \mathbb{R}^{d_{model} \times d_k} \). Here, \( d_{model} \) is the dimension of the input embeddings from \( X \), and \( d_k \) is the dimension of the query vectors.

    When applying LoRA to this layer, the core idea is to freeze the original pre-trained \( W^Q_0 \) and introduce a small, trainable low-rank decomposition \( \Delta W^Q \) such that:

    \[ \Delta W^Q = B^Q A^Q \]

    where \( B^Q \in \mathbb{R}^{d_{model} \times r} \) and \( A^Q \in \mathbb{R}^{r \times d_k} \). The parameter \( r \) is the rank of the update, and it is chosen to be much smaller than both \( d_{model} \) and \( d_k \) (\( r \ll \min(d_{model}, d_k) \)).

    The effective weight matrix for queries during fine-tuning then becomes:

    \[ W^Q = W^Q_0 + B^Q A^Q \]

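    In the original LoRA formulation, the product is also multiplied by a constant scaling factor \( \alpha / r \) (the lora_alpha setting in the code below), and \( B^Q \) is initialised to zeros so that \( \Delta W^Q = 0 \) and the model behaves exactly like the pre-trained one at the start of fine-tuning. This same approach is applied to \( W^K \), \( W^V \), and potentially \( W^O \) as well.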
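To make \( W^Q = W^Q_0 + B^Q A^Q \) concrete, this sketch uses toy sizes (\( d_{model} = 6 \), \( d_k = 4 \), \( r = 2 \)) and random numbers in place of real weights. It checks that running the frozen path and the low-rank path side by side gives the same result as folding \( B A \) into a single matrix, which is why adapters can be merged after training. Two simplifications are assumptions of the sketch: B gets small random values (instead of the zeros LoRA starts from) so that the update is non-zero, and the \( \alpha / r \) scaling is left out.

```python
import random

def matmul(a, b):
    """Multiply an (n x m) matrix by an (m x p) matrix given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def add(a, b):
    """Element-wise sum of two matrices with the same shape."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

random.seed(0)
d_model, d_k, r = 6, 4, 2  # toy sizes chosen for illustration

# Frozen pre-trained projection W0 (d_model x d_k); random numbers stand in for real weights.
W0 = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(d_model)]
# Trainable low-rank factors B (d_model x r) and A (r x d_k).
B = [[random.gauss(0, 0.1) for _ in range(r)] for _ in range(d_model)]
A = [[random.gauss(0, 0.1) for _ in range(d_k)] for _ in range(r)]

X = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(3)]  # 3 tokens, one per row

# During fine-tuning: frozen path plus low-rank path; Delta W = B A is never built.
Q_adapted = add(matmul(X, W0), matmul(matmul(X, B), A))

# After fine-tuning: fold the update into one matrix, W = W0 + B A.
W_merged = add(W0, matmul(B, A))
Q_merged = matmul(X, W_merged)

max_diff = max(abs(a - b) for ra, rb in zip(Q_adapted, Q_merged) for a, b in zip(ra, rb))
print(f"Largest difference between the two forms: {max_diff:.2e}")
```

Only B and A (here 6 × 2 + 2 × 4 = 20 values) would receive gradient updates; W0 stays fixed throughout.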

    The main benefit of this mathematical setup is a drastic reduction in trainable parameters for adaptation:

    • Full Fine-tuning: For a single matrix like \( W^Q \), you would train \( d_{model} \times d_k \) parameters.
    • LoRA Fine-tuning: You only train the parameters in \( B^Q \) and \( A^Q \), which amount to \( d_{model} \times r + r \times d_k \) parameters. Since the rank \( r \) is typically very small (e.g., 4, 8, 16), this is a significant reduction.
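A quick way to see the reduction is to plug numbers into the two formulas. The sizes here are assumptions for illustration: a square 768 × 768 projection (GPT-2 small's hidden size) and the example ranks from the text.

```python
def full_params(d_model: int, d_k: int) -> int:
    """Trainable parameters when fine-tuning one d_model x d_k projection directly."""
    return d_model * d_k

def lora_params(d_model: int, d_k: int, r: int) -> int:
    """Trainable parameters in B (d_model x r) plus A (r x d_k)."""
    return d_model * r + r * d_k

# Assumed sizes: a 768 x 768 projection and a few common ranks.
d_model = d_k = 768
full = full_params(d_model, d_k)
for r in (4, 8, 16):
    lora = lora_params(d_model, d_k, r)
    print(f"r={r:>2}: LoRA {lora:,} vs full {full:,} ({100 * lora / full:.2f}% of the weights)")
```

For r = 8 that is 12,288 trainable values instead of 589,824 for this one matrix, about 2%.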

    This means that while the attention mechanism still performs its standard operations (dot product, scaling, softmax, multiplication with V), the specific way the input data is projected into the Q, K, and V spaces (and how the final output is projected) is efficiently modified by LoRA's low-rank updates. The modification happens at the level of the projection matrices that feed into the attention computation, not the attention computation steps themselves.

    Some Python Code for LoRA

    The code was initially generated using OpenAI and Gemini, and I subsequently reviewed and corrected it.

    
    # Notebook cell: drop the leading "!" to run the install command in a terminal
    !pip install transformers sentence-transformers torch accelerate peft bitsandbytes scikit-learn
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
    from sentence_transformers import SentenceTransformer
    from peft import LoraConfig, PeftModel, PeftConfig, TaskType
    import warnings
    from sklearn.metrics.pairwise import cosine_similarity  # cosine similarity between sentence embeddings
    
    # Suppress some common warnings for cleaner output
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    
    # The sentence we'll be working with
    sentence = "The cat is black."
    print(f"Original Sentence: '{sentence}'\n")
    
    print("--- Step 1: Tokenization & Word Embeddings (from a general LLM) ---\n")
    
    # We'll use a smaller pre-trained LLM called 'gpt2' for this demonstration.
    # For more complex or production tasks, you'd typically use larger models like Llama-2 or Mistral.
    model_name_llm = "gpt2"
    tokenizer_llm = AutoTokenizer.from_pretrained(model_name_llm)
    # AutoModelForCausalLM is used for text generation tasks
    model_llm = AutoModelForCausalLM.from_pretrained(model_name_llm)
    
    # Add a pad token if the tokenizer doesn't have one. This is common for GPT-like models
    # and helps when batching sentences of different lengths.
    if tokenizer_llm.pad_token is None:
        tokenizer_llm.pad_token = tokenizer_llm.eos_token # The End-Of-Sequence token is often used as a pad token
        model_llm.config.pad_token_id = tokenizer_llm.eos_token_id
    
    print(f"**Loaded Tokenizer and Model:** {model_name_llm}")
    
    # Tokenize the sentence:
    # `return_tensors="pt"` ensures the output is a PyTorch tensor.
    # `padding=True` and `truncation=True` handle variable sentence lengths (not critical for one sentence).
    inputs = tokenizer_llm(sentence, return_tensors="pt", padding=True, truncation=True)
    print(f"**Token IDs:** {inputs['input_ids'].tolist()}")
    print(f"**Attention Mask:** {inputs['attention_mask'].tolist()}") # 1s indicate actual tokens, 0s for padding
    
    # Decode the token IDs back to human-readable tokens to see the tokenizer's split
    tokens = tokenizer_llm.convert_ids_to_tokens(inputs['input_ids'][0])
    print(f"**Tokens:** {tokens}")
    
    # Get word embeddings from the model:
    # We tell the model to output its hidden states (internal representations).
    # `torch.no_grad()` is used during inference to save memory and computation by not calculating gradients.
    with torch.no_grad():
        outputs = model_llm(**inputs, output_hidden_states=True)
        # The last layer's hidden states are typically considered the contextualized word embeddings
        word_embeddings = outputs.hidden_states[-1]
    
    print(f"**Word Embeddings Shape:** {word_embeddings.shape} (meaning {word_embeddings.shape[1]} tokens, each with a {word_embeddings.shape[2]}-dimensional embedding)")
    # Print the first few dimensions of embeddings for 'cat' and 'black'.
    # GPT-2's tokenizer marks a preceding space with 'Ġ', so these tokens are 'Ġcat' and 'Ġblack'.
    print(f"**Embedding for 'cat' (first 5 dims):** {word_embeddings[0, tokens.index('Ġcat')].cpu().numpy()[:5]}...")
    print(f"**Embedding for 'black' (first 5 dims):** {word_embeddings[0, tokens.index('Ġblack')].cpu().numpy()[:5]}...\n")
    
    print("--- Step 2: Sentence Embeddings (using Sentence Transformers) ---\n")
    
    # Sentence Transformers are specially designed to create high-quality sentence-level embeddings.
    # 'all-MiniLM-L6-v2' is a popular choice for its balance of performance and efficiency.
    model_name_st = "all-MiniLM-L6-v2"
    model_st = SentenceTransformer(model_name_st)
    
    print(f"**Loaded Sentence Embedder:** {model_name_st}")
    
    # Encode the sentence into a single vector
    sentence_embedding = model_st.encode(sentence, convert_to_tensor=True)
    
    print(f"**Sentence Embedding Shape:** {sentence_embedding.shape} (a single vector for the entire sentence)")
    print(f"**Sentence Embedding (first 5 dims):** {sentence_embedding.cpu().numpy()[:5]}...\n")
    
    # Let's see how sentence embeddings can be used for semantic similarity:
    sentence2 = "A dark feline is present." # Semantically similar
    sentence3 = "The car is red."         # Semantically dissimilar
    
    embedding2 = model_st.encode(sentence2, convert_to_tensor=True)
    embedding3 = model_st.encode(sentence3, convert_to_tensor=True)
    
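    # Cosine similarity measures the angle between two vectors. Closer to 1 means more similar.
    # Move the tensors to CPU NumPy arrays first so scikit-learn also works when the model runs on a GPU.
    similarity_cat_feline = cosine_similarity(sentence_embedding.cpu().numpy().reshape(1, -1), embedding2.cpu().numpy().reshape(1, -1))[0][0]
    similarity_cat_car = cosine_similarity(sentence_embedding.cpu().numpy().reshape(1, -1), embedding3.cpu().numpy().reshape(1, -1))[0][0]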
    
    print(f"**Similarity ('{sentence}' vs '{sentence2}'):** {similarity_cat_feline:.4f} (expected to be the higher score)")
    print(f"**Similarity ('{sentence}' vs '{sentence3}'):** {similarity_cat_car:.4f} (expected to be the lower score)\n")
    
    print("--- Step 3: Prompting an LLM (Inference) ---\n")
    
    # We use the 'gpt2' model for text generation.
    # A prompt guides the LLM on what kind of text to generate.
    prompt = f"Given the sentence '{sentence}', complete the following story: The cat"
    
    # Tokenize the prompt for the LLM
    inputs_llm_gen = tokenizer_llm(prompt, return_tensors="pt")
    
    # Generate text:
    # `max_new_tokens` controls how long the generated text can be.
    # `num_return_sequences` specifies how many different outputs to generate.
    # `pad_token_id=tokenizer_llm.eos_token_id` is crucial for GPT-like models.
    # `do_sample=True` enables sampling, allowing for more creative (less deterministic) output.
    # `top_k` and `temperature` control the randomness and diversity of the sampling process.
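    generated_ids = model_llm.generate(
        **inputs_llm_gen,  # input_ids and attention_mask
        max_new_tokens=20,  # without this, the default max_length=20 includes the prompt and leaves almost no room
        num_return_sequences=1,
        pad_token_id=tokenizer_llm.eos_token_id,
        do_sample=True,
        top_k=50,
        temperature=0.7
    )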
    
    # Decode the generated token IDs back into human-readable text
    generated_text = tokenizer_llm.decode(generated_ids[0], skip_special_tokens=True)
    print(f"Prompt: '{prompt}'")
    print(f"Generated Text: '{generated_text}'\n")
    
    
    print("--- Step 4: LoRA (Low-Rank Adaptation) - Conceptual Understanding ---\n")
    
    
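    # Creating a LoraConfig does not modify any model; it only takes effect when passed to get_peft_model().
    # "q_proj"/"v_proj" are the attention module names in LLaMA-style models; GPT-2 names them differently (see the next example).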
    lora_config_example = LoraConfig(
        r=8,              # The rank of the update matrices (smaller `r` means fewer parameters to train)
        lora_alpha=16,    # A scaling factor for the LoRA weights
        target_modules=["q_proj", "v_proj"], # Common layers to apply LoRA in attention mechanisms
        lora_dropout=0.1, # Dropout applied to the LoRA layers
        bias="none",      # How bias parameters are handled
        task_type=TaskType.CAUSAL_LM # Specifies the type of task (e.g., text generation)
    )
    print(lora_config_example)
    
    
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import LoraConfig, PeftModel, get_peft_model, TaskType  # PeftModel is used in Step 6
    import warnings
    
    # Suppress some common warnings for cleaner output
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    
    # --- 0. Setup ---
    sentence = "The cat is black."
    print(f"Original Sentence: '{sentence}'\n")
    
    # Choose a small pre-trained model for demonstration. 'gpt2' is good for causal language modeling.
    base_model_name = "gpt2"
    
    # --- 1. Load the Base LLM and Tokenizer ---
    print(f"--- Step 1: Loading Base LLM ({base_model_name}) ---")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # AutoModelForCausalLM is for text generation
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
    
    # Ensure the tokenizer has a pad token, which is often needed for batching and generation
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        base_model.config.pad_token_id = tokenizer.eos_token_id
    
    print(f"Base Model Parameters (total): {base_model.num_parameters()}")
    print(f"Example: First few weights of a linear layer in base model (e.g., first attention query layer):")
    # Accessing an example weight matrix (for visualization of LoRA's effect)
    # For GPT-2, the attention layers are in model.transformer.h[layer_idx].attn.c_attn
    # c_attn is a Conv1D layer that projects to Q, K, V. Let's pick one part.
    # We'll just take the first layer for simplicity.
    # The weights are usually stored in 'weight' or 'W' depending on the layer type
    # For Conv1D, it's typically 'weight'
    if hasattr(base_model.transformer.h[0].attn.c_attn, 'weight'):
        original_weight_qkv = base_model.transformer.h[0].attn.c_attn.weight.data.clone()
        print(f"Original c_attn.weight shape: {original_weight_qkv.shape}")
        print(f"Original c_attn.weight (first row, first 5 cols): {original_weight_qkv[0, :5].tolist()}")
    else:
        print("Could not directly access 'weight' attribute of c_attn. Skipping detailed weight inspection.")
    print("\n" + "="*80 + "\n")
    
    
    # --- 2. Configure LoRA ---
    print("--- Step 2: Configuring LoRA ---")
    # LoRA configuration parameters:
    # r (rank): The dimension of the low-rank matrices. A smaller 'r' means fewer trainable parameters.
    # lora_alpha: A scaling factor for the LoRA weights.
    # target_modules: The names of the modules (linear layers) in the base model to apply LoRA to.
    #                 Commonly "q_proj", "v_proj" for attention mechanisms. For GPT2, it's often within c_attn.
    #                 PEFT can sometimes infer this, but explicit is clearer.
    # bias: 'none', 'all', or 'lora_only'. Controls how bias parameters are handled.
    # task_type: Specifies the type of task (e.g., CAUSAL_LM for text generation).
    lora_config = LoraConfig(
        r=8, # Rank
        lora_alpha=16, # Scaling factor
        # For GPT-2, 'c_attn' is a Conv1D that projects to Q, K, V. PEFT can automatically handle this.
        # Alternatively, you might target specific sub-modules if they were explicitly named q_proj, v_proj.
        # For GPT2, it's common to target the c_attn and c_proj layers.
        target_modules=["c_attn", "c_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    print("LoRA Configuration:")
    print(lora_config)
    print("\n" + "="*80 + "\n")
    
    
    # --- 3. Apply LoRA to the Base Model ---
    print("--- Step 3: Applying LoRA to the Base Model ---")
    # `get_peft_model` wraps the base model, injecting LoRA adapters.
    # It makes only the LoRA adapter parameters trainable.
    lora_model = get_peft_model(base_model, lora_config)
    
    print("LoRA-enabled Model Structure and Trainable Parameters:")
    lora_model.print_trainable_parameters()
    # This output shows that only a small fraction of parameters are now trainable.
    print("\n" + "="*80 + "\n")
    
    
    # --- 4. Simulate a Small "Update" (Conceptual Training Step) ---
    # In a real scenario, you'd have a training loop, a dataset, and an optimizer.
    # Here, we'll manually change a small LoRA parameter to show it's trainable.
    # This simulates how a LoRA adapter would learn during fine-tuning.
    
    print("--- Step 4: Simulating a Small 'Update' to LoRA Adapters ---")
    # Find a LoRA adapter parameter and change it slightly.
    # PEFT names these parameters like "...attn.c_attn.lora_B.default.weight".
    # lora_B is initialised to zeros (so B @ A = 0 and the model is unchanged);
    # changing lora_B makes the simulated update actually affect the layer's output.
    found_lora_param = False
    for name, param in lora_model.named_parameters():
        if "lora_B" in name:  # Target the first 'lora_B' matrix
            print(f"Before 'update': first value of '{name}': {param.data.flatten()[0]:.6f}")
            param.data.fill_(0.01) # Set all values to 0.01 (simulating a learning update)
            print(f"After 'update': first value of '{name}': {param.data.flatten()[0]:.6f}")
            found_lora_param = True
            break # Just update one for demonstration
    if not found_lora_param:
        print("Could not find a 'lora_B' parameter to simulate update.")
    
    # Verify that base model parameters are unchanged
    print("\nVerifying Base Model Parameters are Frozen:")
    if hasattr(base_model.transformer.h[0].attn.c_attn, 'weight'):
        current_weight_qkv = base_model.transformer.h[0].attn.c_attn.weight.data
        # Note: `lora_model.base_model.transformer...` accesses the original frozen weights
        # We compare the original cloned weights to the weights inside the LoRA model's base
        # They should be identical if the base model is frozen.
        are_weights_same = torch.equal(original_weight_qkv, current_weight_qkv)
        print(f"Are original base model weights (c_attn.weight) unchanged? {are_weights_same}")
        print(f"Current c_attn.weight (first row, first 5 cols): {current_weight_qkv[0, :5].tolist()}")
    else:
        print("Skipping base weight verification due to direct access issue.")
    print("\n" + "="*80 + "\n")
    
    
    # --- 5. Perform Inference with the LoRA-enabled Model ---
    print("--- Step 5: Performing Inference with LoRA-enabled Model ---")
    
    prompt = "The cat is black. It then"
    inputs = tokenizer(prompt, return_tensors="pt")
    
    # Generate text using the LoRA-enabled model.
    # `no_grad()` ensures no gradients are computed during inference.
    # The fixed seed makes the two sampled generations below directly comparable.
    torch.manual_seed(0)
    with torch.no_grad():
        lora_generated_ids = lora_model.generate(
            **inputs,  # input_ids and attention_mask
            max_new_tokens=20,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            temperature=0.7
        )
    lora_generated_text = tokenizer.decode(lora_generated_ids[0], skip_special_tokens=True)
    print(f"Prompt: '{prompt}'")
    print(f"Generated by LoRA-enabled model: '{lora_generated_text}'")
    
    # For comparison, temporarily disable the adapters. get_peft_model() injected the LoRA
    # layers into `base_model` in place, so calling base_model.generate() would still use them.
    # With the same seed, any difference in the output comes from the LoRA update.
    print("\n--- Comparing with Base Model (LoRA adapters disabled) ---")
    torch.manual_seed(0)
    with torch.no_grad(), lora_model.disable_adapter():
        base_generated_ids = lora_model.generate(
            **inputs,
            max_new_tokens=20,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            temperature=0.7
        )
    base_generated_text = tokenizer.decode(base_generated_ids[0], skip_special_tokens=True)
    print(f"Generated by original base model: '{base_generated_text}'")
    print("\n" + "="*80 + "\n")
    
    
    # --- 6. Merging LoRA Weights (for efficient inference) ---
    print("--- Step 6: Merging LoRA Weights for Efficient Inference ---")
    
    # After training, you can merge the LoRA adapters into the base model's weights.
    # This results in a single model that no longer requires the PEFT structure,
    # and performs inference at the same speed as a fully fine-tuned model.
    # The `merge_and_unload()` method detaches the LoRA adapters and updates the base model's weights.
    # It returns the base model with the merged weights.
    try:
        merged_model = lora_model.merge_and_unload()
        print("LoRA adapters merged successfully into the base model.")
    
        # Verify that the model is no longer a PeftModel
        print(f"Is the merged model still a PeftModel? {isinstance(merged_model, PeftModel)}")
    
        # Accessing the same weight matrix again to show it has changed after merge
        if hasattr(merged_model.transformer.h[0].attn.c_attn, 'weight'):
            merged_weight_qkv = merged_model.transformer.h[0].attn.c_attn.weight.data
            # Compare with original weights again
            are_weights_still_same_after_merge = torch.equal(original_weight_qkv, merged_weight_qkv)
            print(f"Are original base model weights (c_attn.weight) the same AFTER MERGE? {are_weights_still_same_after_merge}")
            print(f"Merged c_attn.weight (first row, first 5 cols): {merged_weight_qkv[0, :5].tolist()}")
            # You should see that `are_weights_still_same_after_merge` is False,
            # and `merged_weight_qkv` is different from `original_weight_qkv`.
        else:
            print("Skipping merged weight verification due to direct access issue.")
    
        # You can now save this merged model just like any other Hugging Face model
        # merged_model.save_pretrained("./my_merged_lora_model")
        # tokenizer.save_pretrained("./my_merged_lora_model")
    
    except Exception as e:
        print(f"Error merging LoRA adapters: {e}")
        print("This can happen, for example, if the adapters have already been merged and unloaded.")
    
    print("\n")
    print("--- Summary ---")
    print("This example demonstrated: ")
    print("1. Loading a base LLM.")
    print("2. Configuring LoRA parameters for efficient fine-tuning.")
    print("3. Transforming the base model into a LoRA-enabled model, showing drastically fewer trainable parameters.")
    print("4. Conceptually altering a LoRA parameter to show its trainability, while the base model remains frozen.")
    print("5. Performing inference to show how the LoRA adapters (even with simulated changes) can affect output.")
    print("6. Merging LoRA adapters back into the base model for deployment-ready inference without overhead.")
      

    And a bit more...

    
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
    from sentence_transformers import SentenceTransformer
    from peft import PeftModel, PeftConfig # For LoRA conceptual understanding
    import warnings
    
    # Suppress some common warnings for cleaner output
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    
    # --- 1. The Sentence ---
    sentence = "The cat is black."
    print(f"Original Sentence: '{sentence}'\n")
    
    # --- 2. Tokenization & Word Embeddings (from a general LLM) ---
    print("--- Tokenization & Word Embeddings (from a general LLM) ---")
    
    # We'll use a smaller pre-trained LLM like 'gpt2' for demonstration.
    # For actual complex tasks, you'd use models like Llama-2, Mistral, etc.
    # Note: 'gpt2' is a decoder-only model, good for text generation.
    model_name_llm = "gpt2"
    tokenizer_llm = AutoTokenizer.from_pretrained(model_name_llm)
    model_llm = AutoModelForCausalLM.from_pretrained(model_name_llm)
    
    # Add a pad token if the tokenizer doesn't have one (common for GPT-like models)
    if tokenizer_llm.pad_token is None:
        tokenizer_llm.pad_token = tokenizer_llm.eos_token # End-of-sequence token often used as pad
        model_llm.config.pad_token_id = tokenizer_llm.eos_token_id
    
    print(f"Tokenizer: {model_name_llm}")
    
    # Tokenize the sentence
    # `return_tensors="pt"` returns PyTorch tensors
    inputs = tokenizer_llm(sentence, return_tensors="pt", padding=True, truncation=True)
    print(f"Token IDs: {inputs['input_ids'].tolist()}")
    print(f"Attention Mask: {inputs['attention_mask'].tolist()}")
    
    # Decode token IDs back to token strings to see what the tokenizer did.
    # GPT-2's byte-level BPE marks a leading space with 'Ġ', so " cat" becomes 'Ġcat'.
    tokens = tokenizer_llm.convert_ids_to_tokens(inputs['input_ids'][0])
    print(f"Tokens: {tokens}")
    
    # Get contextual token embeddings from the model.
    # hidden_states[0] is the embedding-layer output; hidden_states[-1] is the final layer.
    with torch.no_grad(): # Disable gradient calculation for inference
        outputs = model_llm(**inputs, output_hidden_states=True)
        word_embeddings = outputs.hidden_states[-1] # Final-layer (contextual) embedding for each token
    
    print(f"Word Embeddings Shape (batch, tokens, embedding_dim): {word_embeddings.shape}")
    cat_idx = tokens.index('Ġcat')     # 'cat' preceded by a space
    black_idx = tokens.index('Ġblack') # 'black' preceded by a space
    print(f"Embedding for 'cat': {word_embeddings[0, cat_idx].cpu().numpy()[:5]}...") # First 5 dims
    print(f"Embedding for 'black': {word_embeddings[0, black_idx].cpu().numpy()[:5]}...\n")
    
    
    # --- 3. Sentence Embeddings (using Sentence Transformers) ---
    print("--- Sentence Embeddings (using Sentence Transformers) ---")
    
    # Sentence Transformers are specifically designed to produce good sentence-level embeddings.
    # 'all-MiniLM-L6-v2' is a popular and efficient choice.
    model_name_st = "all-MiniLM-L6-v2"
    model_st = SentenceTransformer(model_name_st)
    
    print(f"Sentence Embedder: {model_name_st}")
    
    # Encode the sentence
    sentence_embedding = model_st.encode(sentence, convert_to_tensor=True)
    
    print(f"Sentence Embedding Shape: {sentence_embedding.shape}")
    print(f"Sentence Embedding (first 5 dims): {sentence_embedding.cpu().numpy()[:5]}\n")
    
    # Example of comparing sentence embeddings (semantic similarity)
    sentence2 = "A dark feline is present."
    sentence3 = "The car is red."
    
    embedding2 = model_st.encode(sentence2, convert_to_tensor=True)
    embedding3 = model_st.encode(sentence3, convert_to_tensor=True)
    
    # Cosine similarity between two 1-D embedding vectors (works for CPU or GPU tensors)
    similarity_cat_feline = torch.nn.functional.cosine_similarity(sentence_embedding, embedding2, dim=0).item()
    similarity_cat_car = torch.nn.functional.cosine_similarity(sentence_embedding, embedding3, dim=0).item()
    
    print(f"Similarity ('{sentence}' vs '{sentence2}'): {similarity_cat_feline:.4f}")
    print(f"Similarity ('{sentence}' vs '{sentence3}'): {similarity_cat_car:.4f}\n")
    
    
    # --- 4. Prompting an LLM (Inference) ---
    print("--- Prompting an LLM (Inference) ---")
    
    # The LLM we loaded earlier (gpt2) can be used for text generation.
    # Let's create a simple prompt.
    prompt = f"Given the sentence '{sentence}', complete the following story: The cat"
    
    # Tokenize the prompt
    inputs_llm_gen = tokenizer_llm(prompt, return_tensors="pt")
    
    # Generate text
    # `max_new_tokens` limits the length of the generated output.
    # `do_sample=True` samples instead of decoding greedily; `top_k=50` keeps only the 50 most likely
    # next tokens, and `temperature=0.7` (< 1) sharpens their distribution, reducing randomness.
    generated_ids = model_llm.generate(
        **inputs_llm_gen, # passes both input_ids and attention_mask
        max_new_tokens=20,
        num_return_sequences=1,
        pad_token_id=tokenizer_llm.eos_token_id, # GPT-2 has no pad token; reuse EOS to avoid a warning
        do_sample=True,
        top_k=50,
        temperature=0.7
    )
    
    generated_text = tokenizer_llm.decode(generated_ids[0], skip_special_tokens=True)
    print(f"Prompt: '{prompt}'")
    print(f"Generated Text: '{generated_text}'\n")
    
    
    # --- 5. LoRA (Low-Rank Adaptation) - Conceptual Understanding ---
    print("--- LoRA (Low-Rank Adaptation) ---")
    
    print("LoRA is a parameter-efficient fine-tuning (PEFT) technique.")
    print("It works by injecting small, trainable low-rank matrices into the existing layers of a pre-trained LLM.")
    print("Instead of fine-tuning *all* the millions/billions of parameters of the base LLM, LoRA only updates these much smaller new matrices.")
    print("This significantly reduces the number of trainable parameters, memory footprint, and training time, while still achieving good performance.")
    print("\nKey benefits:")
    print("- Much faster training.")
    print("- Significantly less VRAM usage.")
    print("- Smaller checkpoint sizes (only the LoRA adapter weights are saved).")
    print("- Easier to swap and combine different LoRA adapters for a single base model.")
    
    print("\nConceptual Steps for using LoRA (not run here, since they need a dataset and a training loop):")
    print("1. Load a pre-trained base LLM (e.g., Llama-2, Mistral).")
    print("2. Import `LoraConfig` and `get_peft_model` from `peft` library.")
    print("3. Define `LoraConfig`: specifying which layers to target (e.g., query, value matrices), rank, alpha, dropout.")
    print("4. Wrap the base model with `get_peft_model(base_model, lora_config)`. This creates a `PeftModel`.")
    print("5. The `PeftModel` will now only train the newly injected LoRA adapter weights.")
    print("6. Train this `PeftModel` on your specific dataset (e.g., for sentiment analysis, summarization, etc.) using a `Trainer`.")
    print("7. Save only the LoRA adapters: `peft_model.save_pretrained('my_lora_adapters')`.")
    print("8. To use the fine-tuned model for inference: Load the base model, load the LoRA adapters, and then merge them or use the `PeftModel` directly.")
    
    print("\nExample LoRA config (conceptual):")
    from peft import LoraConfig, TaskType
    lora_config_example = LoraConfig(
        r=8, # Rank of the update matrices
        lora_alpha=16, # The low-rank update is scaled by lora_alpha / r (here 16 / 8 = 2)
        target_modules=["q_proj", "v_proj"], # Llama/Mistral-style names; GPT-2's fused attention projection is "c_attn"
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.CAUSAL_LM # Or TaskType.SEQ_CLS, etc.
    )
    print(lora_config_example)
    
    # This part is conceptual for merging LoRA adapters, not runnable without prior training:
    # try:
    #     # This assumes you have trained and saved LoRA adapters
    #     lora_model_path = "./my_lora_adapters"
    #     peft_config = PeftConfig.from_pretrained(lora_model_path)
    #     base_model_for_merge = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path)
    #     lora_model_for_merge = PeftModel.from_pretrained(base_model_for_merge, lora_model_path)
    #     merged_model = lora_model_for_merge.merge_and_unload()
    #     print(f"\nLoaded and merged LoRA adapters from {lora_model_path}")
    # except Exception as e:
    #     print(f"\n(Skipping LoRA loading/merging: No pre-trained LoRA adapters found for this example. Error: {e})")
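The `do_sample=True`, `top_k=50`, and `temperature=0.7` arguments passed to `generate` in the script above decide how each next token is drawn. Below is a simplified, pure-Python sketch of temperature scaling followed by top-k filtering over a made-up list of logits. The function names are hypothetical, and Hugging Face's real implementation (logits warpers on tensors, tie handling, batching) differs in detail.

```python
# Simplified sketch of temperature + top-k sampling over one step's logits.
# Hypothetical helpers for illustration; not the Hugging Face implementation.
import math
import random


def top_k_temperature_probs(logits, top_k, temperature):
    """Return {token_id: probability} over the top_k tokens after temperature scaling."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [l / temperature for l in logits]  # T < 1 sharpens, T > 1 flattens
    k = min(top_k, len(scaled))
    keep = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    m = max(scaled[i] for i in keep)  # subtract the max for numerical stability
    exps = {i: math.exp(scaled[i] - m) for i in keep}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}  # softmax renormalized over the kept tokens


def sample_next_token(logits, top_k=50, temperature=0.7, rng=random):
    """Draw one token id from the filtered, renormalized distribution."""
    probs = top_k_temperature_probs(logits, top_k, temperature)
    ids = list(probs)
    return rng.choices(ids, weights=[probs[i] for i in ids], k=1)[0]


toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # made-up scores for a 5-token vocabulary
print(top_k_temperature_probs(toy_logits, top_k=3, temperature=0.7))
print(sample_next_token(toy_logits, top_k=3, temperature=0.7, rng=random.Random(42)))
```

Since temperature divides every logit by the same positive number, it never changes which tokens survive the top-k cut; it only reshapes their relative probabilities.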
    

     

    Resources that will help you understand LoRA better:

    • LoRA: Low-Rank Adaptation of Large Language Models
      Edward J. Hu*, Yelong Shen*, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen
      Paper Link: https://arxiv.org/abs/2106.09685
    • Video explainer: https://www.youtube.com/watch?v=DhRoTONcyZE
    • GitHub repo: https://github.com/microsoft/LoRA?tab=readme-ov-file