Component attribution

This notebook illustrates the following techniques:

Requirements:

Code
from ssr.lens import Lens

lens = Lens.from_preset("llama3.2_1b")
Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer
Code
import transformer_lens as tl
from accelerate.utils.memory import release_memory
import torch as t
from rich import print
from jaxtyping import Float
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import einops

from ssr import VIOLET
from ssr.files import load_dataset
from ssr.types import HookList
from reproduce_experiments.plot import imshow

1. Load the mod dataset and compute the logit difference: “I” - “Here”

To compare how much each component’s output contributes to the final prediction on harmful versus harmless sentences, we:

  • Run a forward pass on the harmless sentences and record the most frequent predicted token; for Llama 3.2 this is “Here”
  • Run a forward pass on the harmful sentences and record the most frequent predicted token; for Llama 3.2 this is “I” (as in “I cannot…”)
  • Build the sub-dataset of contrastive sentences for which the model actually predicts “Here” on the harmless prompt → here_mask
  • Compute the logit difference direction between “I” and “Here” → logit_diff_directions

(A hedged sanity check for the “Here”/“I” token choice is sketched after the code cell below.)
Code
neg_idx = int(lens.model.to_single_token("I"))
pos_idx = int(lens.model.to_single_token("Here"))
logit_diff_directions = lens.model.tokens_to_residual_directions(
    neg_idx
) - lens.model.tokens_to_residual_directions(pos_idx)

n_heads = lens.model.cfg.n_heads
n_layers = lens.model.cfg.n_layers
lens.tokenizer.padding_side = "left"

hf_raw, hl_raw = load_dataset("mod")
hf, hl = lens.process_dataset(hf_raw, hl_raw)

# Create the `here_mask`:
with t.no_grad():
    logits_cuda = lens.model.forward(hl, return_type="logits", prepend_bos=False)

    logits_ = logits_cuda.cpu()
    logits_cuda = release_memory(logits_cuda)

here_mask = logits_[:, -1].argmax(-1) == pos_idx
here_idx = here_mask.long().nonzero().squeeze(-1)

hf_logits, hf_cache = lens.auto_scan(hf[here_idx], pattern=None)
hl_logits, hl_cache = lens.auto_scan(hl[here_idx], pattern=None)
100%|██████████| 1/1 [00:01<00:00,  1.03s/it]
100%|██████████| 1/1 [00:01<00:00,  1.01s/it]
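Before relying on these two tokens, we can double-check that “Here” and “I” are indeed the most frequent first predicted tokens on the processed prompts. The following is a hedged sketch (not part of the original pipeline), assuming the whole processed dataset fits in a single forward pass; the check_* variable names are mine.

Code
# Hedged sanity check (not part of the original pipeline): confirm that "Here"
# and "I" are the most frequent first predicted tokens on the harmless and
# harmful prompts respectively. Assumes the processed dataset fits in one pass.
with t.no_grad():
    check_hl = lens.model.forward(hl, return_type="logits", prepend_bos=False)
    check_hf = lens.model.forward(hf, return_type="logits", prepend_bos=False)

hl_preds = check_hl[:, -1].argmax(-1).cpu()  # first predicted token per harmless prompt
hf_preds = check_hf[:, -1].argmax(-1).cpu()  # first predicted token per harmful prompt
check_hl, check_hf = release_memory(check_hl, check_hf)

print(lens.tokenizer.decode([int(hl_preds.mode().values)]))  # expected: "Here"
print(lens.tokenizer.decode([int(hf_preds.mode().values)]))  # expected: "I"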

2. Define the refusal metric

The refusal metric is defined such that:

  • If the model’s prediction is “Here” (acceptance), the score is 0
  • If the model’s prediction is “I” (refusal), the score is 1
  • The score is continuous and represents where the prediction sits between the “Here” and “I” baselines

(A toy worked example of this normalization follows the function definition below.)
Code
def refusal_metric(
    logits: Float[t.Tensor, "batch seq d_vocab"],
    init_logits_diffs: Float[t.Tensor, "2 batch"],
):
    # init_logits_diffs[0]: baseline logit diff on the harmful prompt (refusal, "I")
    # init_logits_diffs[1]: baseline logit diff on the harmless prompt (acceptance, "Here")
    patched_logits_diff = logits[:, -1, neg_idx] - logits[:, -1, pos_idx]
    # 0 at the harmless baseline, 1 at the harmful baseline, linear in between
    return (patched_logits_diff - init_logits_diffs[1]) / (
        init_logits_diffs[0] - init_logits_diffs[1]
    )
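
To make the normalization concrete, here is a toy check with made-up baseline values (hypothetical numbers, not taken from the notebook’s runs): fake logits that reproduce the harmless baseline score 0, and fake logits that reproduce the harmful baseline score 1.

Code
# Toy check with hypothetical baselines: harmful logit diff 4.0, harmless -3.0.
toy_init = t.tensor([[4.0], [-3.0]])  # shape (2, batch=1): [harmful, harmless]

def fake_logits(diff: float) -> t.Tensor:
    # Build logits whose ("I" - "Here") difference at the last position equals diff.
    logits = t.zeros(1, 1, lens.model.cfg.d_vocab)
    logits[0, -1, neg_idx] = diff
    return logits

print(refusal_metric(fake_logits(-3.0), toy_init))  # harmless baseline -> tensor([0.])
print(refusal_metric(fake_logits(4.0), toy_init))   # harmful baseline  -> tensor([1.])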

3. Run forward passes with denoising

  • Run a forward pass on the harmful and harmless sentence pair and cache the intermediate activations
  • For each component, run a forward pass on the corrupted (harmless) sentence and replace this component’s output with the output of the same component during the harmful pass
  • Compute the logit difference at the end of the forward pass to measure the effect of denoising each component and identify which components are sufficient to trigger the refusal behavior

The forward passes are batched: each pass uses a batch of size n_heads, so a whole layer is tested at every iteration.

Hence, in the following line:

activations[head, :, head, :] = patch_activations[:, head, :]

→ the first ‘head’ corresponds to the batch_idx

→ the second ‘head’ corresponds to the head_idx
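
The indexing can be checked on a toy tensor with made-up shapes (a hedged sketch, independent of the model): after the loop, batch element h is patched only at head h.

Code
# Toy check of the batched-patching indexing (hypothetical shapes).
toy_heads, toy_seq, toy_d = 4, 3, 2
acts = t.zeros(toy_heads, toy_seq, toy_heads, toy_d)  # (batch, seq, head, d_head)
patch = t.ones(toy_seq, toy_heads, toy_d)             # (seq, head, d_head)

for head in range(toy_heads):
    acts[head, :, head, :] = patch[:, head, :]

# Exactly one head per batch element is non-zero: an identity-like boolean matrix.
print(acts.abs().sum(dim=(1, 3)) > 0)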

Code
def batch_hook_fn(
    activations, hook, patch_activations: Float[t.Tensor, "seq_len n_heads d_heads"]
):
    for head in range(n_heads):
        activations[head, :, head, :] = patch_activations[:, head, :]
    return activations


full_scores_list = []
with t.no_grad():
    for i in tqdm.tqdm(here_idx[:10]):
        prompts = [lens.apply_chat_template(p) for p in [hf_raw[i], hl_raw[i]]]
        tokens = lens.tokenizer(
            prompts, padding=True, add_special_tokens=False, return_tensors="pt"
        ).input_ids

        logits, cache = lens.auto_scan(tokens, pattern="z")
        hook_z = cache.stack_activation("z")
        init_logits_diff = (
            (logits[:, -1, neg_idx] - logits[:, -1, pos_idx])
            .unsqueeze(-1)
            .repeat(1, n_heads)
        )

        scores_list = []

        for layer in range(n_layers):
            fwd_hooks: HookList = [
                (
                    tl.utils.get_act_name("z", layer=layer),
                    partial(batch_hook_fn, patch_activations=hook_z[layer, 0]),
                )
            ]

            with lens.model.hooks(fwd_hooks=fwd_hooks) as hooked_model:
                logits = hooked_model.forward(
                    tokens[1].unsqueeze(0).repeat(n_heads, 1), return_type="logits"
                )

                scores_list.append(refusal_metric(logits.cpu(), init_logits_diff))
                logits = release_memory(logits)

        full_scores_list.append(t.vstack(scores_list))

imshow(
    t.cat([x.unsqueeze(0) for x in full_scores_list], dim=0).mean(dim=0).T,
    xaxis_title="Layers",
    yaxis_title="Heads",
    title=f"Patching Attribution",
    size=(600, 400),
    border=True,
)
100%|██████████| 10/10 [00:38<00:00,  3.82s/it]

[Plotly heatmap: Patching Attribution (refusal metric per head, Layers x Heads)]

4. DLA (Direct Logit Attribution)

The code in this part is largely adapted from the ARENA course: https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification.

In our case, the “correct” answer is the refusal (“I”), and the incorrect answer is the acceptance (“Here”).
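
For reference, here is the standard direct logit attribution identity (a sketch; the notation $\hat{x}$, $t_I$, $t_{\mathrm{Here}}$ is introduced here and not used elsewhere in the notebook):

$$
\mathrm{logit\_diff}
= \hat{x}_{\mathrm{final}} \cdot \left( W_U[:, t_I] - W_U[:, t_{\mathrm{Here}}] \right)
\approx \sum_{c} \hat{x}_{c} \cdot \left( W_U[:, t_I] - W_U[:, t_{\mathrm{Here}}] \right)
$$

where $\hat{x}$ is the residual stream at the last position after the final normalization, $t_I$ and $t_{\mathrm{Here}}$ are the two token ids, and the sum runs over the components $c$ (embeddings, attention heads, MLPs) whose outputs add into the residual stream. The approximation treats the final normalization scale as a constant, which is what apply_ln_to_stack does in the code below.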

Code
hf_logit_diffs = hf_logits[:, -1, neg_idx] - hf_logits[:, -1, pos_idx]
hf_final = hf_cache["resid_post", -1][:, -1, :]
hl_final = hl_cache["resid_post", -1][:, -1, :]
hf_final_scaled = hf_cache.apply_ln_to_stack(hf_final, layer=-1, pos_slice=-1)
hl_final_scaled = hl_cache.apply_ln_to_stack(hl_final, layer=-1, pos_slice=-1)


def residual_stack_to_logit_diff(
    residual_stack: Float[t.Tensor, "... batch d_model"],
    cache: tl.ActivationCache,
    logit_diff_directions: Float[t.Tensor, "d_model"] = logit_diff_directions,
) -> Float[t.Tensor, "..."]:
    """
    Gets the avg logit difference between the correct and incorrect answer for a given stack of components in the
    residual stream.
    """
    batch_size = residual_stack.size(-2)
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    return (
        einops.einsum(
            scaled_residual_stack.cuda().half(),
            logit_diff_directions.unsqueeze(0).repeat(batch_size, 1),
            "... batch d_model, batch d_model -> ...",
        )
        / batch_size
    )


accumulated_residual, labels = hf_cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs: Float[t.Tensor, "component"] = residual_stack_to_logit_diff(
    accumulated_residual, hf_cache
)
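
As a quick consistency check (using the variables defined above): the last entry of the accumulated curve corresponds to the full final residual stream, so it should roughly reproduce the mean logit difference read directly off the logits, up to the half-precision cast.

Code
# Hedged consistency check: the end of the "logit lens" curve should roughly
# match the mean logit difference computed from the logits themselves.
print(logit_lens_logit_diffs[-1].item())
print(hf_logit_diffs.mean().item())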
Code
plt.rcParams.update(
    {
        "font.size": 15,
        "axes.titlesize": 17,
        "axes.labelsize": 15,
        "xtick.labelsize": 13,
        "ytick.labelsize": 13,
        "legend.fontsize": 13,
        "figure.titlesize": 18,
        "lines.linewidth": 3,
    }
)

support = np.arange(logit_lens_logit_diffs.shape[0]) / 2

fig, ax = plt.subplots(figsize=(12, 6))
plt.plot(
    support,
    logit_lens_logit_diffs.detach().cpu().numpy(),
    color=VIOLET,
    label="Accumulated contribution",
)

intervention_layers = [14]
for layer in intervention_layers:
    plt.axvline(x=layer, color="gray", linestyle=":", alpha=0.7)


plt.xlabel("Layers")
plt.xticks(support, labels, rotation=45, ha="right")  # type: ignore
plt.ylabel("Logit difference")
plt.title("Layer Attribution")
plt.legend()
fig.patch.set_alpha(0)
ax.set_facecolor("none")

plt.show()
[Matplotlib figure: Layer Attribution (accumulated logit difference per layer)]

5. Compute DLA per head

Instead of computing the contribution of each layer, we can compute the contribution of each individual head and verify that the results coincide.

The next part relies heavily on the TransformerLens library, which provides utilities for this.
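
Before running the full loop, the shape bookkeeping can be verified on a single cached prompt (a hedged sketch; the check_* variable names are mine):

Code
# Hedged shape check on one prompt (sketch): per-head outputs at the last position.
with t.no_grad():
    _, check_cache = lens.auto_scan(hf[here_idx][0], pattern=None)
    check_stack, check_labels = check_cache.stack_head_results(
        layer=-1, pos_slice=-1, return_labels=True
    )
    print(check_stack.shape)  # expected: (n_layers * n_heads, batch, d_model)
    check_stack = einops.rearrange(
        check_stack, "(layer head) ... -> layer head ...", layer=n_layers
    )
    print(check_stack.shape)  # expected: (n_layers, n_heads, batch, d_model)
    check_stack, check_cache = release_memory(check_stack, check_cache)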

Code
scores = t.zeros(n_layers, n_heads)

with t.no_grad():
    for i in range(len(hf[here_idx])):
        _, hf_smal_cache = lens.auto_scan(hf[here_idx][i], pattern=None)
        hf_smal_cache = hf_smal_cache.to("cuda")
        per_head_residual, labels = hf_smal_cache.stack_head_results(
            layer=-1, pos_slice=-1, return_labels=True
        )
        per_head_residual = einops.rearrange(
            per_head_residual,
            "(layer head) ... -> layer head ...",
            layer=lens.model.cfg.n_layers,
        )
        per_head_logit_diffs = residual_stack_to_logit_diff(
            per_head_residual, hf_smal_cache
        )
        scores += per_head_logit_diffs.cpu()

        per_head_logit_diffs, per_head_residual, hf_smal_cache = release_memory(
            per_head_logit_diffs, per_head_residual, hf_smal_cache
        )

imshow(
    scores.T,
    color_continuous_midpoint=0.5,
    xaxis_title="Layers",
    yaxis_title="Heads",
    title=f"Direct Logit Attribution",
    size=(600, 400),
    border=True,
)
100%|██████████| 1/1 [00:00<00:00,  6.05it/s]
Tried to stack head results when they weren't cached. Computing head results now

[Plotly heatmap: Direct Logit Attribution (per-head logit difference, Layers x Heads)]