Federated Learning (FL) is a machine learning technique in which raw data does not need to be collected from each device into a single data center before training. Instead, training is distributed: each mobile device trains locally, and only the locally computed updates to the current global model are shared. This helps protect users’ privacy, because only the updates (and not the data) are shared, they are needed only for a short time, and the optimization process in FL does not need to know the source of each update.
Because FL takes updates from individual mobile devices, it differs from traditional machine learning and distributed computing in a few ways:
- Non-IID Data : Devices used for training can vary in hardware capabilities, sensors, and operating systems. Similarly, users vary by geography and demographics, and consequently in their usage patterns and preferences. Therefore, any single user’s local dataset will generally not be representative of the population distribution.
- Unbalanced Data : Even two similar users can have varying amounts of training data, depending on heavy or light usage.
- Massively Distributed : Data can be distributed among tens of thousands of users, and the average number of data points per user can be smaller than the total number of users.
- Limited Communication : Devices are frequently switched off or on restricted or poor internet connections. Therefore, each client tends to participate in only a small number of update rounds each day.
- Communication Cost : Unlike in data-center distributed systems, computation is essentially free in FL relative to communication, and the communication cost of each round of updates is many times higher.
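To experiment with the first two properties, one can simulate them on an ordinary centralized dataset. The following is a minimal sketch, assuming integer class labels stored in a NumPy array; `partition_non_iid` is a hypothetical helper written for illustration, not part of any FL library. Sorting by label makes each shard label-skewed (non-IID), and random cut points make the shard sizes unequal (unbalanced).

```python
import numpy as np

def partition_non_iid(labels, num_clients, seed=0):
    """Split example indices into label-skewed, unequal-sized client shards."""
    rng = np.random.default_rng(seed)
    # After sorting by label, each contiguous slice covers only a few classes.
    order = np.argsort(labels, kind="stable")
    # Distinct random cut points give every client a different amount of data.
    cuts = np.sort(
        rng.choice(np.arange(1, len(labels)), size=num_clients - 1, replace=False)
    )
    return np.split(order, cuts)

# Toy data (assumed): 1000 examples, 10 classes, 100 examples per class.
labels = np.repeat(np.arange(10), 100)
shards = partition_non_iid(labels, num_clients=20)
sizes = [len(s) for s in shards]
classes_per_client = [len(np.unique(labels[s])) for s in shards]
print(sizes)
print(classes_per_client)
```

Each client ends up holding only a narrow range of classes, so no single shard resembles the overall label distribution.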
Federated Learning Optimization
In FL, we assume an iterative process of synchronous updates, one per round of communication.
Assume there is a fixed set of \(K\) clients, each with a fixed local dataset. At the beginning of each round, a random fraction \(C\) of the clients is selected, and the server sends the current global algorithm state to each selected client. Each selected client then performs local computation based on the global state and its own dataset, and sends an update back to the server. The server then applies these updates to its global state, and the process repeats.
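The round described above can be sketched roughly as follows. This is an illustration under stated assumptions rather than a reference implementation: `run_round`, `client_update`, and `mean_delta` are hypothetical names, the global state is a NumPy vector, and the server is assumed to apply the updates as a data-size-weighted average of client deltas.

```python
import numpy as np

def run_round(w, client_data, C, client_update, rng):
    """One synchronous round: select clients, collect updates, apply them."""
    K = len(client_data)
    m = max(1, int(round(C * K)))  # number of clients selected this round
    selected = rng.choice(K, size=m, replace=False)
    # Each selected client computes an update from the global state and its own data.
    updates = [client_update(w, client_data[k]) for k in selected]
    sizes = np.array([len(client_data[k]) for k in selected], dtype=float)
    # Assumed aggregation rule: weight each client's update by its dataset size.
    return w + np.average(updates, axis=0, weights=sizes)

def mean_delta(w, data):
    """Toy client update: move the global estimate toward the local mean."""
    return data.mean() - w

rng = np.random.default_rng(0)
# Four clients with different amounts of data drawn around different means.
client_data = [rng.normal(loc=k, size=5 * (k + 1)) for k in range(4)]
w0 = np.zeros(1)
w1 = run_round(w0, client_data, C=1.0, client_update=mean_delta, rng=rng)
print(w1, np.concatenate(client_data).mean())
```

With `C=1.0` every client participates, so a single round of this toy problem recovers the mean of the pooled data without the server ever seeing it.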
The FL objective function can be stated as:
\[\underset{w \in \mathbb{R}^d}{\min} f(w) \notag\]
where,
\[f(w) \doteq \frac{1}{n} \sum\limits_{i=1}^n f_{i}(w)\]
Here, as this is inherently a machine learning problem, \(f_{i}(w) = l(x_{i}, y_{i}, w)\), i.e., the loss of the prediction on example \((x_{i}, y_{i})\) made with model parameters \(w\).
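As a small worked example of this objective, the sketch below evaluates \(f(w)\) on three data points. The choice of \(l\) (squared error of a linear model) and the toy numbers are assumptions made only for illustration.

```python
import numpy as np

def example_loss(x, y, w):
    """f_i(w) = l(x_i, y_i, w); here, squared error of a linear model (assumed)."""
    return 0.5 * (np.dot(x, w) - y) ** 2

def objective(X, Y, w):
    """f(w): the mean of f_i(w) over all n examples."""
    return sum(example_loss(x, y, w) for x, y in zip(X, Y)) / len(Y)

X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
Y = np.array([1.0, 2.0, 3.0])
w = np.array([1.0, 1.0])
# Per-example losses are 0, 0.5 and 0.5, so f(w) is their mean.
print(objective(X, Y, w))
```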
Also, since the data is partitioned over \(K\) clients, with \(P_{k}\) the set of indexes of the data points on client \(k\) and \(n_{k} = \lvert P_{k} \rvert\) the number of those points, we can rewrite eq. (1) as:
\[f(w) = \sum\limits_{k=1}^K \frac{n_{k}}{n} F_{k}(w)\notag\]
where,
\[F_{k}(w) = \frac{1}{n_{k}}\sum\limits_{i \in P_{k}} f_{i}(w)\notag\]
Federated Learning Variants
Federated SGD
In Federated Stochastic Gradient Descent (FedSGD), Stochastic Gradient Descent is used for FL optimization. FedSGD uses a random fraction \(C\) of the clients in each round, but all the local data on those clients. We first initialize the model parameters, and let \(w_{t}\) denote the global model at round \(t\). On the client side, each client \(k\) computes its local average gradient \(g_{k} = \nabla F_{k}(w_{t})\) using all its local data. The central server then receives and aggregates all the local gradients, and computes the updated model parameters \(w_{t+1}\) with learning rate \(\eta\) as:
\[w_{t+1} \leftarrow w_{t} - \eta \sum\limits_{k=1}^K \frac{n_{k}}{n} g_{k}\]
Federated Averaging
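Federated Averaging (FedAvg) is similar to FedSGD, but each client applies the gradient step to its own copy of the model and the server averages the resulting client weights, instead of the server aggregating the clients’ gradients and applying the step to the global model itself.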
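As a minimal sketch, one round of each variant can be written in NumPy as follows. It assumes that all \(K\) clients participate, and it uses a squared-error linear model as a stand-in for the loss; the function names and the `local_steps` parameter are illustrative choices, not taken from any library. In `fedsgd_round` the server takes the \(n_{k}/n\)-weighted average of the client gradients and applies one step; in `fedavg_round` each client steps locally and the server takes the same weighted average of the resulting weights.

```python
import numpy as np

def local_gradient(w, X, y):
    """g_k: gradient of F_k(w), the mean squared-error loss on one client's data."""
    return X.T @ (X @ w - y) / len(y)

def fedsgd_round(w, clients, eta):
    """Server averages the client gradients, then takes a single step."""
    n = sum(len(y) for _, y in clients)
    g = sum(len(y) / n * local_gradient(w, X, y) for X, y in clients)
    return w - eta * g

def fedavg_round(w, clients, eta, local_steps=1):
    """Each client steps locally; the server averages the resulting weights."""
    n = sum(len(y) for _, y in clients)
    new_w = np.zeros_like(w)
    for X, y in clients:
        w_k = w.copy()
        for _ in range(local_steps):
            w_k -= eta * local_gradient(w_k, X, y)
        new_w += len(y) / n * w_k
    return new_w

# Toy setup (assumed): three clients with 5, 20 and 50 examples of dimension 3.
rng = np.random.default_rng(0)
clients = [(rng.normal(size=(n_k, 3)), rng.normal(size=n_k)) for n_k in (5, 20, 50)]
w0 = np.zeros(3)

w_sgd = fedsgd_round(w0, clients, eta=0.1)
w_avg = fedavg_round(w0, clients, eta=0.1)
w_avg5 = fedavg_round(w0, clients, eta=0.1, local_steps=5)
print(np.allclose(w_sgd, w_avg))
```

With `local_steps=1` the two rounds return the same weights, because the weighted average is linear in the client updates. McMahan et al. let each client take several local steps before averaging, which is what `local_steps > 1` mimics here.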
If \(\eta\) is the learning rate, each client \(k\) computes its updated weights as \(w_{t+1}^k \leftarrow w_{t} - \eta g_{k}\), and the server then computes their weighted average as:
\[w_{t+1} \leftarrow \sum\limits_{k=1}^K \frac{n_{k}}{n} w_{t+1}^k\]
References
- McMahan, B., Moore, E., Ramage, D., Hampson, S. and Agüera y Arcas, B., 2017, April. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics (pp. 1273-1282). PMLR.