The cat

from causapscal.lens import Lens

lens = Lens.from_preset("gpt")
Loaded pretrained model gpt2 into HookedTransformer
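The preset loads GPT-2 small. Its basic dimensions can be read off the config (n_layers, n_heads and d_model are standard HookedTransformer fields; this check is not in the original notebook):

print(lens.model.cfg.n_layers, lens.model.cfg.n_heads, lens.model.cfg.d_model)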
from rich import print
import torch as t

from tp.analysis import plot_attention_patterns, filtered_to_str_tokens
PHRASE = "Quelle est la couleur du chat de Hermione Granger?"
print(lens.model.to_str_tokens(PHRASE))
[
    '<|endoftext|>',
    'Q',
    'uel',
    'le',
    ' est',
    ' la',
    ' cou',
    'le',
    'ur',
    ' du',
    ' chat',
    ' de',
    ' Hermione',
    ' Granger',
    '?'
]
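The leading '<|endoftext|>' is the BOS token that to_str_tokens prepends by default; assuming the standard TransformerLens prepend_bos argument, it can be dropped:

print(lens.model.to_str_tokens(PHRASE, prepend_bos=False))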
vocab = lens.model.to_string(t.arange(lens.model.cfg.d_vocab).unsqueeze(-1))
print(vocab[:5])
print(len(vocab))
['!', '"', '#', '$', '%']
50257
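Individual tokens can also be looked up directly. Assuming the TransformerLens helper to_single_token, ' chat' (a single token above) maps straight to its vocabulary id:

# ' chat' is a single token; its id also shows up in the token tensor below (8537).
print(lens.model.to_single_token(" chat"))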
tokens = lens.model.to_tokens(PHRASE)
print(tokens)
tensor([[50256,    48,  2731,   293,  1556,  8591,  2284,   293,   333,  7043,
          8537,   390, 19959, 46236,    30]], device='cuda:0')
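As a sanity check (assuming to_string accepts a [batch, pos] tensor, as in TransformerLens), decoding the ids reproduces the phrase with the BOS marker in front:

print(lens.model.to_string(tokens))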
logits, cache = lens.model.run_with_cache(tokens)
print(logits.shape)
torch.Size([1, 15, 50257])
predictions = logits[0, -1, :].topk(k=5)
print(lens.model.to_string(predictions.indices.unsqueeze(-1)))
['\n', ' (', '\n\n', ' I', ' A']
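To see how confident these top predictions are, one can softmax the last-position logits; a small check, not from the original:

probs = logits[0, -1, :].softmax(dim=-1)
top = probs.topk(k=5)
print(list(zip(lens.model.to_string(top.indices.unsqueeze(-1)), [round(p, 3) for p in top.values.tolist()])))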
str_tokens = filtered_to_str_tokens(lens.model, [PHRASE])

LAYER = 7
EXAMPLE = 0
for layer in range(lens.model.cfg.n_layers):
    plot_attention_patterns(
        cache["pattern", layer][EXAMPLE], str_tokens[EXAMPLE], layer=layer
    )
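Beyond the plots, the patterns can be inspected numerically. A minimal sketch at LAYER = 7, assuming cache["pattern", layer] stores [batch, n_heads, query, key] attention probabilities (as in TransformerLens): for each head, which token does the final '?' attend to most?

labels = lens.model.to_str_tokens(PHRASE)
pattern = cache["pattern", LAYER][EXAMPLE]  # [n_heads, query, key]
last_row = pattern[:, -1, :]                # attention out of the last position
for head in range(lens.model.cfg.n_heads):
    k = last_row[head].argmax().item()
    print(LAYER, head, repr(labels[k]), round(last_row[head, k].item(), 2))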

Induction

PHRASE = "Ceci est une superbe phrase qui ne sert à - Sckathapscal Gorphineus Quantifilius Artificewick des Vents. Ceci est une superbe phrase qui ne sert à - Sckathapscal Gorphineus Quantifilius Artificewick des Vents."
str_tokens = filtered_to_str_tokens(lens.model, [PHRASE])
tokens = lens.model.to_tokens(PHRASE)
logits, cache = lens.model.run_with_cache(tokens)
# " G" à 21 et 56
pred_tok_21 = logits[0, 21].topk(k=5)
pred_tok_56 = logits[0, 56].topk(k=5)
print(lens.model.to_string(pred_tok_21.indices.unsqueeze(-1)))
['.', '-', 'eb', 'aud', 'mb']
print(lens.model.to_string(pred_tok_56.indices.unsqueeze(-1)))
['orph', 'morph', 'omorph', 'obi', 'omb']
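The model only completes ' G' with 'orph' the second time around. One way to quantify this (positions 21 and 56 taken from the comment above; a check not in the original) is to compare the probability assigned to the actual next token at each occurrence:

probs = logits[0].softmax(dim=-1)
for pos in (21, 56):
    next_tok = tokens[0, pos + 1]
    print(pos, repr(lens.model.to_string(next_tok.unsqueeze(-1))), round(probs[pos, next_tok].item(), 3))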
EXAMPLE = 0
for layer in range(lens.model.cfg.n_layers):
    plot_attention_patterns(
        cache["pattern", layer][EXAMPLE], str_tokens[EXAMPLE], layer=layer
    )
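
Finally, a rough per-head induction score (a sketch, assuming the two copies of the sentence tokenize to the same length, so the repeat period is (tokens.shape[1] - 1) // 2 once the BOS token is excluded): an induction head should put most of its attention on the token just after the previous occurrence of the current token.

period = (tokens.shape[1] - 1) // 2
for layer in range(lens.model.cfg.n_layers):
    pattern = cache["pattern", layer][EXAMPLE]  # [n_heads, query, key]
    # The induction "stripe": attention from position q back to q - period + 1.
    stripe = pattern.diagonal(offset=-(period - 1), dim1=-2, dim2=-1)
    # Keep only the queries that fall inside the second copy of the sentence.
    score = stripe[:, 2:].mean(dim=-1)
    print(layer, [round(s, 2) for s in score.tolist()])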