Challenge 1: Poison [1/2]

Author

Le magicien quantique

Published

May 12, 2024

!pip install -r "../requirements.txt";
from fl.preprocessing import load_mnist, data_to_client
from fl.model import NN, train_and_test
from fl.utils import plot_train_and_test, weights_to_json
from fl.federated_learning import federated;

These two challenges aim to introduce the technique of federated learning and the potential dangers to consider.

(Figure: du_poison.jpg)

This series of challenges is accompanied by utility functions from the fl module. Everything is in the provided .zip file mentioned in the statement.

1 Federated Learning

Sometimes, instead of performing all the training at once from a single database, it’s preferable to train multiple versions of the model on varied, potentially decentralized data. This is the case, for example, with recommendation algorithms that are trained directly on users’ machines and then aggregated on a common server.

1.1 Example of the scenario we are considering

The central server creates a base model, saves its weights, and sends them to all clients (the base weights are available here: weights/base_fl.weights.h5). In our case there are five clients, and you are one of them. Each client trains the model on their side with their own data (represented in our simulation by x_clients, y_clients), then sends the resulting weights back to the server. The server aggregates the weights by averaging across all clients. This produces a new version of the common model, which it can then redistribute, and so on.

Let’s imagine that the base model has the weights: \[ M_1 = \{W_1, b_1, W_2, b_2, W_3, b_3, W_4, b_4\} \] (here I use the standard notation: \(W\) for weights and \(b\) for biases; when coding, everything is treated as a “weight”)

Next, each client trains the model on their own dataset, which updates the local weights. For all \(i \in [1, \mathrm{nb\_clients}]\), client \(i\)’s model is defined by: \[ M_1^{(i)} = \{W_1 + \delta W_1^{(i)}, b_1 + \delta b_1^{(i)}, \ldots, W_4 + \delta W_4^{(i)}, b_4 + \delta b_4^{(i)}\} = \{W_1^{(i)}, b_1^{(i)}, \ldots, W_4^{(i)}, b_4^{(i)}\} \]

The clients send their weights to the server, which aggregates them by averaging to create a new version of the common model: \[ M_2 = \left\{\frac{1}{N}\sum_{i=1}^{N} W_1^{(i)}, \frac{1}{N}\sum_{i=1}^{N} b_1^{(i)}, \ldots, \frac{1}{N}\sum_{i=1}^{N} W_4^{(i)}, \frac{1}{N}\sum_{i=1}^{N} b_4^{(i)}\right\} \quad \text{with } N = \mathrm{nb\_clients}, \] and so on.
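To make the aggregation step concrete, here is a minimal sketch of the averaging (my own illustration, not the actual code of the federated function), assuming each client’s weights come as a list of NumPy arrays ordered \(W_1, b_1, \ldots, W_4, b_4\):

import numpy as np

def average_weights(client_weights):
    # client_weights: list of per-client weight lists, each ordered
    # [W1, b1, W2, b2, W3, b3, W4, b4]
    return [
        np.mean([weights[k] for weights in client_weights], axis=0)
        for k in range(len(client_weights[0]))
    ]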

To visualize all this and potentially run your tests, I’ve provided the server-side function that handles weight aggregation. It’s called federated, and in fact, it performs the entire simulation. It trains the five clients on their respective sides and then aggregates their weights. Here’s an example:

We start by retrieving the data:

fl_iterations = 5
client_epochs = 1
nb_clients = 5

x_train, y_train, x_test, y_test = load_mnist()
x_clients, y_clients = data_to_client(x_train, y_train, nb_clients=nb_clients)      # Simulates the fact that the clients have different datasets

The entire federated learning process is contained within the federated function:

federated_learning = federated(
    x_clients, 
    y_clients, 
    x_test,                             # The server validates the results on a single and consistent test set
    y_test, 
    fl_iterations=fl_iterations,        # Number of federated learning rounds (M_1 -> M_2 -> ...)
    client_epochs=client_epochs                 
)

We can then display the results:

history = federated_learning["history_acc"]
plot_train_and_test([history], ["FL"], fl_iterations)

2 On Your Side

On your side, you don’t need to simulate the other 4 clients; you only need to handle your version of the common model.

First, you need to retrieve the weights and re-establish the structure of the common model \(\left(W_1, b_1, W_2, b_2, W_3, b_3, W_4, b_4\right)\):

model_base = NN()
model_base.load_weights("../weights/base_fl.weights.h5")

Next, you can improve the model locally:

local_epochs = 5

local_results = train_and_test(
    model_base, 
    x_train,        # You can train your local model on all the data, or on whatever you choose; this is precisely the principle.
    y_train, 
    x_test, 
    y_test, 
    epochs=local_epochs
)
plot_train_and_test([local_results["history"].history["val_accuracy"]], ["Local training"], epochs=local_epochs)
63/63 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step
Accuracy of the model: 0.942

Your trained local model is available here:

type(local_results["model"])
fl.model.NN

The corresponding weights are also available:

print(f"""
Number of layers: {len(local_results["weights"])}
Size of W1: {local_results["weights"][0].shape}
Size of b1: {local_results["weights"][1].shape}
Size of W2: {local_results["weights"][2].shape}
Size of b2: {local_results["weights"][3].shape}
Size of W3: {local_results["weights"][4].shape}
Size of b3: {local_results["weights"][5].shape}
Size of W4: {local_results["weights"][6].shape}
Size of b4: {local_results["weights"][7].shape}
""")

Number of layers: 8
Size of W1: (784, 1000)
Size of b1: (1000,)
Size of W2: (1000, 700)
Size of b2: (700,)
Size of W3: (700, 500)
Size of b3: (500,)
Size of W4: (500, 10)
Size of b4: (10,)

You then simply need to send your weights back to the server:

import requests as rq

URL = "https://du-poison.challenges.404ctf.fr"
rq.get(URL + "/healthcheck").json()
{'message': 'Statut : en pleine forme !'}
d = weights_to_json(local_results["weights"])
# d = weights_to_json(model.get_weights())
rq.post(URL + "/challenges/1", json=d).json()
{'message': "Raté ! Le score de l'apprentissage fédéré est de 0.946. Il faut l'empoisonner pour qu'il passe en dessous de 0.5"}

3 Your Turn!

You are one of the five clients. Find a way to poison the model and bring its accuracy down as low as possible!

The model structure must remain the same, and you don’t have access to the epochs or fl_iterations parameters, which are set directly on the server side, so you must act on your only point of influence: the weights!
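One simple line of attack (a sketch of my own, not necessarily the intended solution, with an arbitrarily chosen scaling factor): since the server averages the five clients’ weights without any sanity check, scaling your local weights by a large factor makes your contribution dominate the average and degrades the common model:

poisoned_weights = [w * -100 for w in local_results["weights"]]   # -100 is an arbitrary large factor

d = weights_to_json(poisoned_weights)
rq.post(URL + "/challenges/1", json=d).json()

Other perturbations (large random noise, sign flips, etc.) should work just as well, as long as the aggregated model’s accuracy drops below 0.5.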