ML 6: Characterizations of Federated Learning
Previously, this post was entitled 'I <3 nips', but it was brought to my attention by someone who has influence over my paycheck that this was 'borderline'.
Introduction
I went to NIPS and all I got was this indomitable desire for knowledge. In this post I wanna dive into this paper1 which aims to characterize the tradeoffs between communication and sample complexity in Federated Reinforcement Learning settings. For additional background about Reinforcement Learning, see any of my other numerous posts about RL.2
RL Summary for the zillionth time
Reinforcement learning describes the decision-making paradigm of an agent taking actions in a (partially) unknown environment. Agents learn from feedback meted out by the environment in the form of a reward $r_t$ received for taking an action $a_t$ from a state $s_t$ of the environment at a discrete time step $t$. Usually, rewards are some real-valued quantity on the interval $[0, 1]$, but that's just convention, and doesn't contribute to complexity.
Challenges
Two of the main challenges towards effective reinforcement learning are:
- Data collection – it is often time consuming or expensive due to high-dimensional decision space, hence the sparse knowability of the environment.
- Expensive – running/training reinforcement learning agents on non-trivial problems is computationally/temporally expensive.
One common solution to these problems is Federation, which aims to distribute the computational load of reinforcement learning across several decentralized agents working on disjoint instances of the same problem, sharing policy updates across different segments of training data. It's analogous to a divide & conquer multi-threading strategy.
Salgia attempts to quantify the upper and lower bounds on the benefits of Federated Q-Learning, which are governed by a tradeoff: communication becomes an integral part of algorithm design. The general principle is that increased communication between agents improves each agent's individual learning efficiency and decreases sample complexity.
To characterize this curve, Salgia poses two research questions:
- lower bound: What is the minimum amount of communication required of a Federated Reinforcement Learning algorithm in order to realize a statistical benefit from collaboration?3
- optimality: How can we design an algorithm that simultaneously offers order-optimal sample and communication complexities?
MDPs
We typically model RL problems as Markov Decision Processes which are useful for describing sequential interactions between agents and environments.
The process of Markov chain iteration is as follows:
- The environment provides a state $s_t$ to the agent.
- The agent chooses an action $a_t$ according to some policy $\pi$ (a probability distribution over exploratory/exploitative actions).
- The environment rewards the agent with $r_t$.4
- Finally, the environment provides the successive state $s_{t+1}$ to the agent according to some stochastic transition dynamics $P(\cdot \mid s_t, a_t)$ (environments are not necessarily deterministic, so taking the same action from the same state is not guaranteed to produce the same successive state).
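The four steps above can be sketched as a tiny simulation loop. This is only an illustration: the two-state MDP, its numbers, and the names `P`, `R`, `uniform_policy`, `step`, and `rollout` are all made up for this sketch and don't come from the paper.

```python
import random

# Hypothetical two-state, two-action MDP used only for illustration.
# P[s][a] is a list of (next_state, probability); R[s][a] is a reward in [0, 1].
P = {
    0: {0: [(0, 0.9), (1, 0.1)], 1: [(0, 0.2), (1, 0.8)]},
    1: {0: [(0, 0.5), (1, 0.5)], 1: [(1, 1.0)]},
}
R = {0: {0: 0.0, 1: 0.1}, 1: {0: 0.5, 1: 1.0}}


def uniform_policy(state, rng):
    """Pick an action uniformly at random (a stand-in for any policy)."""
    return rng.choice(sorted(P[state]))


def step(state, action, rng):
    """Return (reward, next_state), sampling next_state from the dynamics."""
    next_states = [s for s, _ in P[state][action]]
    probs = [p for _, p in P[state][action]]
    next_state = rng.choices(next_states, weights=probs, k=1)[0]
    return R[state][action], next_state


def rollout(num_steps, seed=0):
    """Run the interaction loop and record (s, a, r, s') tuples."""
    rng = random.Random(seed)
    state, trajectory = 0, []
    for _ in range(num_steps):
        action = uniform_policy(state, rng)             # agent chooses an action
        reward, next_state = step(state, action, rng)   # env rewards, then transitions
        trajectory.append((state, action, reward, next_state))
        state = next_state
    return trajectory
```

Calling `rollout(5)` returns five `(state, action, reward, next_state)` tuples, where each tuple's `next_state` is the following tuple's `state`.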
Policy Optimization
Policies governing which actions the agent takes at any given state can be optimized in a number of ways, perhaps the most fundamental of which is via Q-functions, given by:

$$Q^\pi(s, a) = \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t r_t \,\middle|\, s_0 = s,\ a_0 = a\right]$$

Here, $Q^\pi$ measures the long-term reward garnered by a policy $\pi$ operating under the environment's transition dynamics. The term $\gamma \in (0, 1)$ is a discount used to decrease the weight of rewards earned in the distant future, since we're likely to perform frequent policy updates which might lead us to take different actions than those we predicted for distant $t$. In other words, we're more confident about the expected value of rewards to be earned in the near future rather than those rewards earned in states which we might never even reach.
Since each term in the infinite summation shrinks by a factor of $\gamma$, we can express the effective horizon of a reward as the closed form of the sum of an infinite geometric series:

$$\sum_{t=0}^{\infty} \gamma^t = \frac{1}{1 - \gamma}$$

where the process has a probability $1 - \gamma$ of terminating at each time step; an agent can expect to survive $\frac{1}{1 - \gamma}$ steps before termination, so rewards past this point contribute comparatively little.
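As a quick sanity check on the geometric series, here is a small sketch comparing the closed form against a truncated sum. The function names are mine, chosen for illustration.

```python
def effective_horizon(gamma):
    """Closed form of sum_{t>=0} gamma^t for 0 <= gamma < 1."""
    return 1.0 / (1.0 - gamma)


def truncated_discount_sum(gamma, num_steps):
    """Partial sum gamma^0 + gamma^1 + ... + gamma^(num_steps - 1)."""
    return sum(gamma ** t for t in range(num_steps))


def weight_within(gamma, num_steps):
    """Fraction of the total discounted weight carried by the first num_steps steps."""
    return 1.0 - gamma ** num_steps
```

For example, with `gamma = 0.99` the effective horizon is about 100 steps, and `weight_within(0.99, 100)` shows that those first 100 steps carry a bit under two thirds of the total discounted weight.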
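The general goal of Q-Learning is to find an optimal policy $\pi^*$ which maximizes $Q^\pi$. That is, we learn $\pi^*$ by learning $Q^*$ (it's just function approximations all the way down):

$$\pi^*(s) = \arg\max_a Q^*(s, a)$$

where $Q^*$ is the unique solution to the Bellman equation:

$$Q^*(s, a) = r(s, a) + \gamma \, \mathbb{E}_{s' \sim P(\cdot \mid s, a)}\left[\max_{a'} Q^*(s', a')\right]$$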
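A key characteristic of the Bellman operator $\mathcal{T}$ (the map that sends any $Q$ to the right-hand side of the Bellman equation evaluated at $Q$) is that it's $\gamma$-contractive in the $\ell_\infty$ norm, meaning that

$$\|\mathcal{T}(Q_1) - \mathcal{T}(Q_2)\|_\infty \le \gamma \, \|Q_1 - Q_2\|_\infty$$

so we can find $Q^*$ by finding the fixed point of a contractive operator via fixed point iteration, which is precisely what Q-Learning accomplishes:

$$Q_{t+1} = (1 - \eta_t) \, Q_t + \eta_t \, \mathcal{T}(Q_t)$$

where $\eta_t$ is an algorithm's step size at any given time step $t$. This value can be understood as the learning rate.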
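To make the fixed point iteration concrete, here is a minimal sketch on a made-up two-state MDP where the transition model is fully known, so the Bellman operator can be applied exactly. The MDP, `GAMMA`, and all function names are assumptions for illustration only.

```python
GAMMA = 0.9
# Hypothetical 2-state, 2-action MDP: P[s][a] lists next-state probabilities,
# R[s][a] is the reward for taking action a in state s.
P = [[[0.9, 0.1], [0.2, 0.8]], [[0.5, 0.5], [0.0, 1.0]]]
R = [[0.0, 0.1], [0.5, 1.0]]


def bellman(Q):
    """Apply the exact Bellman optimality operator to a Q table."""
    return [
        [R[s][a] + GAMMA * sum(P[s][a][s2] * max(Q[s2]) for s2 in range(2))
         for a in range(2)]
        for s in range(2)
    ]


def sup_dist(Q1, Q2):
    """Sup-norm (largest entrywise) distance between two Q tables."""
    return max(abs(Q1[s][a] - Q2[s][a]) for s in range(2) for a in range(2))


def fixed_point_iteration(num_iters, step_size=1.0):
    """Iterate Q <- (1 - step_size) * Q + step_size * T(Q), starting from zeros."""
    Q = [[0.0, 0.0], [0.0, 0.0]]
    for _ in range(num_iters):
        TQ = bellman(Q)
        Q = [[(1 - step_size) * Q[s][a] + step_size * TQ[s][a] for a in range(2)]
             for s in range(2)]
    return Q
```

In this toy MDP, state 1 with action 1 loops forever with reward 1, so its optimal value is $1 / (1 - 0.9) = 10$, which is what the iteration converges to.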
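In practice, however, the exact Bellman update, written $\mathcal{T}(Q_t)$ here, is unknowable (it takes an expectation over transition dynamics $P$ that the agent doesn't know), so we must instead use a stochastic approximation $\hat{\mathcal{T}}$ built from samples:

$$\hat{\mathcal{T}}(Q)(s, a) = r(s, a) + \frac{\gamma}{B} \sum_{b=1}^{B} \max_{a'} Q(s'_b, a'), \qquad s'_b \sim P(\cdot \mid s, a)$$

We denote the fidelity of the approximation as $B$, defined by the number of samples drawn from $P(\cdot \mid s, a)$.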
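Here's a rough sketch of what such a sampled update could look like for a single state-action pair, next to the exact one it approximates. The toy MDP and the names `exact_backup`, `sampled_backup`, and `batch_size` (the number of samples drawn per update) are my own assumptions for illustration.

```python
import random

GAMMA = 0.9
# Hypothetical 2-state, 2-action MDP; the learner only gets to sample from P.
P = [[[0.9, 0.1], [0.2, 0.8]], [[0.5, 0.5], [0.0, 1.0]]]
R = [[0.0, 0.1], [0.5, 1.0]]


def exact_backup(Q, s, a):
    """Bellman update for (s, a) using the true transition probabilities."""
    return R[s][a] + GAMMA * sum(P[s][a][s2] * max(Q[s2]) for s2 in range(2))


def sampled_backup(Q, s, a, batch_size, rng):
    """Bellman update for (s, a) estimated from batch_size sampled next states."""
    next_states = rng.choices([0, 1], weights=P[s][a], k=batch_size)
    return R[s][a] + GAMMA * sum(max(Q[s2]) for s2 in next_states) / batch_size
```

With a larger `batch_size` the sampled update concentrates around the exact one; with `batch_size=1` it is a single noisy draw.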
Federated Reinforcement Learning
The general setup involves a single server connecting $M$ agents solving a shared problem5 expressed as an MDP. Each agent $m$ performs $\tau$ rounds of local Q-Learning updates and periodically sends these back to the server:

$$Q^m_{t+1} = (1 - \eta_t) \, Q^m_t + \eta_t \, \hat{\mathcal{T}}^m_t(Q^m_t)$$

where $\hat{\mathcal{T}}^m_t$ is agent $m$'s sampled Bellman update at step $t$. The server then averages these updates and sends the aggregate policy update(s) back to the agents:

$$Q^m_t \leftarrow \frac{1}{M} \sum_{j=1}^{M} Q^j_t$$

This process is repeated for $R$ distributed policy synchronizations. In real life, we might use more sophisticated methods of aggregation so as to discount stupid agents stuck in the trenches of reward minima, as well as executing synchronization aperiodically, but the general form of Federated Reinforcement Learning is as above.
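The local-update-then-average loop described above can be sketched in a few lines. To be clear, this is a plain periodic-averaging toy on a made-up MDP, not the algorithm from the paper; the parameter names (`num_agents`, `local_steps`, `num_rounds`, `step_size`, `batch_size`) and default values are assumptions for illustration.

```python
import random

GAMMA = 0.9
# Hypothetical 2-state, 2-action MDP shared by every agent.
P = [[[0.9, 0.1], [0.2, 0.8]], [[0.5, 0.5], [0.0, 1.0]]]
R = [[0.0, 0.1], [0.5, 1.0]]


def sampled_backup(Q, s, a, batch_size, rng):
    """Noisy Bellman update for (s, a) from batch_size sampled next states."""
    next_states = rng.choices([0, 1], weights=P[s][a], k=batch_size)
    return R[s][a] + GAMMA * sum(max(Q[s2]) for s2 in next_states) / batch_size


def local_update(Q, step_size, batch_size, rng):
    """One Q-learning step over every (s, a) pair; returns a new table."""
    return [[(1 - step_size) * Q[s][a]
             + step_size * sampled_backup(Q, s, a, batch_size, rng)
             for a in range(2)] for s in range(2)]


def average(tables):
    """Server-side aggregation: entrywise mean of the agents' Q tables."""
    return [[sum(Q[s][a] for Q in tables) / len(tables) for a in range(2)]
            for s in range(2)]


def federated_q_learning(num_agents=4, local_steps=5, num_rounds=200,
                         step_size=0.1, batch_size=8, seed=0):
    rngs = [random.Random(seed + m) for m in range(num_agents)]
    global_Q = [[0.0, 0.0], [0.0, 0.0]]
    for _round in range(num_rounds):
        local_tables = []
        for rng in rngs:                       # each agent starts from the shared table
            Q = global_Q
            for _step in range(local_steps):   # local Q-learning updates
                Q = local_update(Q, step_size, batch_size, rng)
            local_tables.append(Q)
        global_Q = average(local_tables)       # synchronize via the server
    return global_Q
```

Increasing `local_steps` while holding the total number of updates fixed means fewer synchronizations, which is exactly the communication knob the rest of the post is about.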
FRL Metrics
Under such a construction, we can define the following metrics to quantify statistical benefits of collaboration vs. complexity:
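- Error rate: $\text{ER}(N, M) = \sup \mathbb{E} \, \|\hat{Q} - Q^*\|_\infty$
  - This measures the worst case error (over problem instances) of the estimate $\hat{Q}$ relative to the optimal $Q^*$, given $N$ samples per agent and $M$ agents
- Sample complexity: $\text{SC}(\varepsilon, M) = \min \{N : \text{ER}(N, M) \le \varepsilon\}$
  - The smallest number of samples from each agent to achieve a sufficiently small error
- Communication complexity:
  - Round trip cost: $\text{CC}_{\text{round}}$
    - Measures the number of times that an agent sends a message to the server
  - Bit cost: $\text{CC}_{\text{bit}}$
    - Measures the number of bits sent by an agent to the server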
With these metrics established, we can return to the research questions introduced earlier.
Lower Bound
The first theorem introduced by Salgia states that, for any intermittent communication algorithm with constant or linearly-rescaled step size, if6

$$\text{CC}_{\text{round}} \lesssim \frac{1}{1 - \gamma}$$

then, for all communication schedule, batch size, and step size hyperparameters, we have an error rate lower bounded by some constant parameterized by $\gamma$ over the square root of the number of samples $N$:

$$\text{ER} \gtrsim \frac{C_\gamma}{\sqrt{N}}$$
Irrespective of the number of agents participating in the federated effort, the error rate will not decrease, meaning that we do not gain any statistical benefit from adding more agents. In other words, without sufficient communication, it's as if we're just using a single agent.
This result is derived from two decompositions. The first breaks down the error of our Q function with respect to the optimum into two terms measuring bias and variance:

$$Q_t - Q^* = \underbrace{\mathbb{E}[Q_t] - Q^*}_{\text{bias}} + \underbrace{Q_t - \mathbb{E}[Q_t]}_{\text{variance}}$$

It is a classical result that the variance term shrinks linearly in the number of agents when we average across them, but the bias term does not. So, if our bias term dominates the variance term, we gain none of the collaborative benefit offered by the variance term.
The second decomposition we're concerned with illustrates that Q-Learning induces a positive bias.
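One standard intuition for where a positive bias can come from (this is my own illustration, not necessarily the decomposition used in the paper) is that Q-Learning takes a max over noisy estimates, and the expected max of noisy values is at least the max of their expectations. A small simulation with hypothetical names shows the effect:

```python
import random


def mean_of_max(true_values, noise_std, num_trials, seed=0):
    """Average of max_i(true_values[i] + Gaussian noise) over num_trials draws."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_trials):
        # Each action value is observed with independent zero-mean noise.
        total += max(v + rng.gauss(0.0, noise_std) for v in true_values)
    return total / num_trials
```

With three actions that are all truly worth `1.0` and unit noise, `mean_of_max([1.0, 1.0, 1.0], 1.0, 5000)` lands well above `1.0`, even though every individual estimate is unbiased. Averaging many agents' tables shrinks the noise but can't undo an overestimate that every agent shares.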
Scheduling Update Synchronizations
These two expressions underscore the importance of understanding the evolution of bias throughout the process of federated learning.
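Communication schedules between agents can be denoted

$$\mathcal{C} = \{t_r\}_{r=1}^{R}$$

consisting of $R$ rounds of updates where:

$$\tau = t_r - t_{r-1}$$

is the number of local updates performed between consecutive synchronizations.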
If we plot bias over time, we can see that within each communication round $r$, bias increases.
However, since we're averaging updates across all agents, overall bias is independent of the number of agents $M$. For smaller values of $\tau$ (the number of local updates between synchronizations), corresponding to more frequently scheduled update communications between agents, overall bias grows more slowly.
Similarly, larger values of $\tau$ result in larger growth in bias.
Thus, bias is directly proportional to $\tau$, which is inversely proportional to the communication complexity:

$$\text{bias} \propto \tau \propto \frac{1}{\text{CC}_{\text{round}}}$$
So, if $\text{CC}_{\text{round}} \lesssim \frac{1}{1 - \gamma}$ (the number of communication rounds is smaller than the effective horizon), then the bias term dominates the variance term and we'll achieve zero collaborative gains.
Now, with a tight understanding of the lower bound, we can approach the second research question.
Achievability
How do we design an algorithm that offers order-optimal sample and communication complexities? Salgia introduces a new algorithm called Federated Doubly Variance Reduced Q-Learning (Fed-DVR-Q).
Notably, their algorithm operates at the optimal frontier of the sample-communication complexity boundary. It's the first algorithm to achieve optimal order sample complexity w.r.t. all salient hyperparameters, namely the effective horizon, and it also employs quantization of policy updates to achieve optimal communication costs.
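The post doesn't spell out how Fed-DVR-Q quantizes its updates, so the following is not the paper's scheme. It's a generic stochastic-rounding quantizer, sketched under my own assumptions, just to show how a real-valued Q entry can be sent with a fixed bit budget while staying correct on average. The function name and parameters are hypothetical.

```python
import random


def stochastic_quantize(value, low, high, num_bits, rng):
    """Round value onto a uniform grid of 2**num_bits points in [low, high].

    Rounds up with probability equal to the fractional position between the
    two neighbouring grid points, so the result is unbiased for in-range values.
    """
    levels = 2 ** num_bits - 1
    step = (high - low) / levels
    clipped = min(max(value, low), high)
    position = (clipped - low) / step
    lower = int(position)            # floor, since position >= 0
    frac = position - lower
    index = lower + (1 if rng.random() < frac else 0)
    return low + min(index, levels) * step
```

For rewards in $[0, 1]$ and a discount of `0.9`, Q values live in `[0, 10]`, so something like `stochastic_quantize(q, 0.0, 10.0, 4, rng)` would send each entry using 4 bits instead of a full float.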
The algorithm is relatively straightforward, operating in terms of epochs. At each epoch $k$, the agents collaboratively update the policy estimate $Q^{(k-1)}$ to $Q^{(k)}$ such that the error between $Q^{(k)}$ and the optimal estimate $Q^*$ decreases by a factor of 2:

$$\|Q^{(k)} - Q^*\|_\infty \le \frac{1}{2} \, \|Q^{(k-1)} - Q^*\|_\infty$$

If we repeat this for $K \gtrsim \log(1 / \varepsilon)$ epochs, we can guarantee an $\varepsilon$-optimal Q function. The novelty of the algorithm lies in the scheduling strategy. Most existing approaches opt for a larger number of local updates and smaller sample sizes (taking a lot of local, albeit noisy, steps). This results in a lot of bias accumulation, which must be accounted for via frequent averaging during the update aggregation.
Salgia inverts this relation, instead preferring a larger batch size $B$, which results in higher-fidelity samples and allows for less frequent communication. This is the key to decreasing communication complexity.
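The epoch count implied by the halving argument is just a logarithm. As a small sketch (the function name and the example numbers are mine, not the paper's exact constants):

```python
import math


def epochs_needed(initial_error, target_error):
    """Smallest K such that initial_error / 2**K <= target_error."""
    if initial_error <= target_error:
        return 0
    return math.ceil(math.log2(initial_error / target_error))
```

For instance, starting from an error of `10.0` (the largest possible Q value when rewards are in $[0, 1]$ and the discount is `0.9`) and targeting `0.1` takes 7 halvings, since $10 / 2^7 \approx 0.078$.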
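The second theorem introduced in the paper expresses the benefits of combining this approach with variance-reduced updates, stating that, for an implementation of the algorithm with the prescribed choice of parameters, and for some $\varepsilon \in (0, 1]$ (ignoring logarithmic factors and the dependence on the size of the state-action space):

$$\text{SC} \lesssim \frac{1}{M \varepsilon^2 (1 - \gamma)^3}, \qquad \text{CC}_{\text{round}} \lesssim \frac{1}{1 - \gamma}$$

Note that the sample complexity now depends on $M$ and exhibits the optimal cubic dependence w.r.t. the effective horizon. Additionally, the communication complexity matches the lower bound derived above.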
Footnotes
1. Salgia, Sudeep and Yuejie Chi. "The Sample-Communication Complexity Trade-off in Federated Q-Learning." arXiv, 2024.
2. https://www.murphyandhislaw.com/tags/reinforcement-learning
3. Where statistical benefit would be a measurable decrease in sample complexity as the number of agents working on a problem increased.
4. Alternatively, reward functions can be computed by the agent according to the favorability of the subsequent state based on the agent's context, i.e. the reward can be agent-dependent rather than environment-dependent.
5. Meaning that the transition kernel and reward functions are the same for each agent.
6. Here and throughout, I opt to use a loose inequality rather than introduce constants to make the inequalities proper. This is the same as just referring to the order of the complexity quantities.