ML 6: Characterizations of Federated Learning


Previously, this post was entitled 'I <3 nips', but it was brought to my attention by someone who has influence over my paycheck that this was 'borderline'.

Introduction

I went to NIPS and all I got was this indomitable desire for knowledge. In this post I wanna dive into this paper,[1] which aims to characterize the tradeoffs between communication and sample complexity in Federated Reinforcement Learning settings. For additional background about Reinforcement Learning, see any of my other numerous posts about RL.[2]

RL Summary for the zillionth time

Reinforcement learning describes the decision-making paradigm of an agent taking actions in a (partially) unknown environment. Agents learn based on feedback meted out by the environment in the form of a reward $r$ received for taking an action $a \in \mathcal A$ from a state $s \in \mathcal S$ of the environment at a discrete time step $t$. Usually, rewards are some real-valued quantity on the interval $[0, 1]$, but that's just convention and doesn't contribute to complexity.

Challenges

Two of the main challenges towards effective reinforcement learning are:

  1. Data collection – gathering experience is often time-consuming or expensive due to the high-dimensional decision space, and hence the sparse knowability of the environment.
  2. Computational cost – running/training reinforcement learning agents on non-trivial problems is computationally/temporally expensive.

One common solution to these problems is Federation, which aims to distribute the computational load of reinforcement learning across several decentralized agents working on disjoint sub-instances of the same problem, sharing policy updates learned from different segments of training data. It's analogous to a divide & conquer multi-threading strategy.

Salgia attempts to quantify upper and lower bounds on the benefits of Federated Q-Learning, which is governed by a tradeoff in which communication becomes an integral part of algorithm design. The general principle is that increased communication between agents improves each agent's individual learning efficiency and decreases sample complexity.

To characterize this curve, Salgia poses two research questions:

  1. Lower bound: What is the minimum amount of communication required of a Federated Reinforcement Learning algorithm in order to realize a statistical benefit from collaboration?[3]
  2. Optimality: How can we design an algorithm that simultaneously offers optimal-order sample and communication complexities?

MDPs

We typically model RL problems as Markov Decision Processes which are useful for describing sequential interactions between agents and environments.

The MDP interaction loop proceeds as follows (a minimal code sketch follows the list):

  1. The environment provides a state $s_t$ to the agent.
  2. The agent chooses an action $a_t$ according to some policy $\pi$ (a probability distribution over exploratory/exploitative actions).
  3. The environment rewards the agent with $r$.[4]
  4. Finally, the environment provides the successive state $s_{t+1}$ to the agent according to some stochastic transition dynamics $\tau$ (environments are not necessarily deterministic, so taking the same action from the same state is not guaranteed to produce the same successive state):

$$\tau : (\mathcal S \times \mathcal A) \times \mathcal S \rightarrow [0,1]$$
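To make the loop concrete, here's a minimal Python sketch of the interaction loop against a hypothetical Gym-style environment – the `env` object, its `reset`/`step` methods, and the `policy` callable are illustrative assumptions, not anything specified in the paper:

```python
import numpy as np

def rollout(env, policy, T=100, seed=0):
    """Run one episode of the agent-environment loop described above."""
    rng = np.random.default_rng(seed)
    s = env.reset()                      # 1. environment provides a state s_t
    trajectory = []
    for t in range(T):
        a = policy(s, rng)               # 2. agent samples an action a_t ~ pi(. | s_t)
        s_next, r, done = env.step(a)    # 3./4. environment returns reward r and s_{t+1} ~ tau(. | s_t, a_t)
        trajectory.append((s, a, r, s_next))
        s = s_next
        if done:
            break
    return trajectory
```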

Policy Optimization

Policies governing which actions the agent takes at any given state can be optimized in a number of ways, perhaps the most fundamental of which is via Q-functions, given by:

$$\forall (s,a) \in \mathcal S \times \mathcal A: \quad Q^\pi(s,a) = \mathbb E \Big[\sum_{t=0}^{\infty}\gamma^t r(s_t,a_t) \,\Big|\, s_0 = s, a_0 = a \Big]$$

Here, $Q$ measures the long-term reward garnered by a policy operating under the environment's transition dynamics. The $\gamma$ term is a discount used to decrease the weight of rewards earned in the distant future, since we're likely to perform frequent policy updates which might lead us to take different actions than those which predicted $a_t$ for distant $t$. In other words, we're more confident about the expected value of rewards to be earned in the near future rather than those rewards earned in states which we might never even reach.

Since each term in the infinite summation shrinks by a factor of $\gamma$, we can express the effective horizon of a reward as the closed form of the sum of an infinite geometric series:

$$t_f = \frac{1}{1 - \gamma}$$

Equivalently, if the process has a $(1-\gamma)$ probability of terminating at each time step, an agent can expect to survive $\frac{1}{1-\gamma}$ steps before termination, so rewards beyond this horizon are negligible.
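Plugging in a typical discount factor makes the scale concrete (the $\gamma = 0.99$ value below is just an illustrative choice):

$$\sum_{t=0}^{\infty} \gamma^t = \frac{1}{1-\gamma}, \qquad \text{e.g. } \gamma = 0.99 \;\Rightarrow\; t_f = \frac{1}{1-0.99} = 100 \text{ steps}$$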

The general goal of Q-Learning is to find an optimal policy $\pi^*$ which maximizes $Q^\pi(s, a)$. That is, we learn $\pi^*$ by learning $Q^*$ (it's just function approximations all the way down):

$$\pi^*(s) = \arg \max_{a\in\mathcal A} Q^*(s,a)$$

where $Q^*$ is the unique solution to the Bellman equation:

$$\begin{aligned} V(Q^*) &= Q^*; \\ V(Q)(s,a) &= r(s,a) + \gamma \cdot \mathbb E_{s' \sim \tau(\cdot \mid s,a)}\Big[ \max_{a'\in\mathcal A} Q(s', a') \Big] \end{aligned}$$
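For reference, here's a minimal tabular sketch of the operator $V$ when the dynamics are known; the array shapes (`R` of shape `[S, A]`, `P` of shape `[S, A, S]`) are my own assumptions for illustration:

```python
import numpy as np

def bellman_operator(Q, R, P, gamma):
    """V(Q)(s,a) = r(s,a) + gamma * E_{s' ~ tau(.|s,a)}[ max_a' Q(s', a') ].

    Q: [S, A] current Q-estimate
    R: [S, A] expected rewards r(s, a)
    P: [S, A, S] transition kernel, P[s, a, s'] = tau(s' | s, a)
    """
    next_value = Q.max(axis=1)           # max_a' Q(s', a') for every s'
    return R + gamma * (P @ next_value)  # expectation over s' via the kernel
```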

A key characteristic of the Bellman operator $V$ is that it's $\gamma$-contractive with respect to the infinity norm, meaning that

$$\| V(Q_1) - V(Q_2)\|_\infty \leq \gamma\,\|Q_1 - Q_2\|_\infty$$

so we can find $Q^*$ by finding the fixed point of a contractive operator via fixed-point iteration, which is precisely what Q-Learning accomplishes:

$$Q_{t+1} \leftarrow (1-\eta_t)Q_t + \eta_t\,{\color{red}V(Q_t)}$$

where $\eta_t$ is the algorithm's step size at any given time step $t$. This value can be understood as the learning rate.

In practice, however, the value given by the Bellman operator ${\color{red}V(Q_t)}$ is unknowable (since evaluating the expectation requires exact knowledge of the transition dynamics $\tau$), so we must instead use a stochastic approximation of the value function, $\tilde{V}_b$:

$$Q_{t+1} \leftarrow (1-\eta_t)Q_t + \eta_t\tilde{V}_b(Q_t)$$

We denote the fidelity of the approximation by $b$, the number of samples drawn to form $\tilde V_b$.
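A minimal sketch of this sampled update in the tabular setting; `sample_next_states` (returning $b$ next-state draws for a given $(s, a)$) and the synchronous full-table sweep are assumptions made for illustration:

```python
import numpy as np

def q_learning_step(Q, R, sample_next_states, gamma, eta, b, rng):
    """One update Q_{t+1} <- (1 - eta) Q_t + eta * V_tilde_b(Q_t),
    where V_tilde_b estimates the Bellman operator from b samples per (s, a)."""
    S, A = Q.shape
    V_tilde = np.empty_like(Q)
    for s in range(S):
        for a in range(A):
            s_next = sample_next_states(s, a, b, rng)   # b draws of s' ~ tau(. | s, a)
            V_tilde[s, a] = R[s, a] + gamma * Q[s_next].max(axis=1).mean()
    return (1.0 - eta) * Q + eta * V_tilde
```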

Federated Reinforcement Learning

The general setup involves a single server connecting $M$ agents solving a shared problem[5] expressed as an MDP. Each agent $m$ performs $k$ rounds of local Q-Learning updates and periodically sends these back to the server:

$$Q^m_{t+1} \leftarrow (1 - \eta_t)Q^m_t + \eta_t \tilde V^m_b(Q^m_t)$$

The server then averages these updates and sends the aggregate policy update(s) back to the agents:

$$Q_t = \frac{1}{M}\sum_{m=1}^M Q_t^m$$

This process is repeated for $T$ distributed policy synchronizations. In real life, we might use more sophisticated methods of aggregation so as to discount stupid agents stuck in the trenches of reward minima, or synchronize aperiodically – but the general form of Federated Reinforcement Learning is as above.
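Putting those pieces together, here's a deliberately naive sketch of the federated loop: each of the $M$ agents runs $\zeta$ local steps of the sampled update from the previous sketch (it reuses the hypothetical `q_learning_step` and per-agent samplers defined above), then the server averages. Real systems would layer on the refinements mentioned above:

```python
import numpy as np

def federated_q_learning(agent_samplers, R, gamma, eta, b, zeta, rounds, S, A, seed=0):
    """Naive federated Q-learning over M agents sharing the same MDP:
    take zeta local updates per agent, then synchronize by averaging."""
    rng = np.random.default_rng(seed)
    Q_global = np.zeros((S, A))
    for _ in range(rounds):
        local_Qs = []
        for sampler in agent_samplers:                  # one sampler per agent m
            Q_m = Q_global.copy()
            for _ in range(zeta):                       # zeta local updates between syncs
                Q_m = q_learning_step(Q_m, R, sampler, gamma, eta, b, rng)
            local_Qs.append(Q_m)
        Q_global = np.mean(local_Qs, axis=0)            # server averages local estimates
    return Q_global
```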

FRL Metrics

Under such a construction, we can define the following metrics to quantify statistical benefits of collaboration vs. complexity:

  • Number of samples: $N = bT$, the per-agent sample budget.
  • Error rate: $E(N, M) = \sup_{\tau, r} \mathbb E\big[ \|\hat Q_N - Q^* \|_\infty\big]$
    • This measures the worst-case error of our estimate relative to the optimal Q-function $Q^*$.
  • Sample complexity: $\mathcal C_\thicksim(\varepsilon, M) = \inf \{ N \mid E(N, M) \leq \varepsilon \}$
    • The smallest number of samples from each agent needed to achieve a sufficiently small error $\varepsilon$.
  • Communication complexity:
    • Round-trip cost: $\mathcal C_{\circlearrowleft} = \frac{1}{M}\sum_{m=1}^M\mathcal C_{\circlearrowleft}^m$
      • Measures the number of times that an agent sends a message to the server.
    • Bit cost: $\mathcal C_{\vec{b}} = \frac{1}{M}\sum_{m=1}^M\mathcal C_{\vec{b}}^m$
      • Measures the number of bits sent by agent $m$ to the server.

With these metrics established, we can return to the research questions introduced earlier.

Lower Bound

The first theorem introduced by Salgia states that, for any intermittent-communication algorithm with constant or linearly-rescaled step size, if[6]

$$\mathcal C_{\circlearrowleft} \lesssim \frac{1}{1-\gamma} \quad \text{or} \quad \mathcal C_{\vec{b}} \lesssim \overbrace{\frac{|\mathcal S||\mathcal A|}{1-\gamma}}^{\text{size of decision space}}$$

then, for all communication schedules, batch sizes, and step-size hyperparameters, the error rate is lower bounded by some constant parameterized by $\gamma$ over the square root of the number of samples:

$$E(N, M) \gtrsim \frac{c_\gamma}{\sqrt{N}}$$

Irrespective of the number of agents participating in the federated effort, the error rate will not decrease, meaning that we do not gain any statistical benefit from adding more agents. In other words, without sufficient communication, it's as if we're just using a single agent.

This result is derived via two decompositions. The first breaks down the error of our Q-function estimate with respect to the optimum into two terms measuring bias and variance:

$$\mathbb E \big[(\hat Q - Q^*)^2\big] = \underbrace{\mathbb E\big[(\mathbb E[\hat Q] - Q^*)^2\big]}_{\text{bias}} + \underbrace{\mathbb E\big[(\hat Q - \mathbb E[\hat Q])^2\big]}_{\text{variance}}$$

It is a classical result that the variance term decreases linearly (by a factor of $M$) when averaging across multiple agents, but that the bias term does not. So, if our bias term dominates the variance term, we gain none of the collaborative benefit offered by the variance reduction.
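A quick synthetic illustration of that point (toy numbers of my own, not anything from the paper): averaging $M$ independent estimators shrinks the variance roughly by a factor of $M$, while a shared bias is untouched:

```python
import numpy as np

rng = np.random.default_rng(0)
true_value, bias, noise_std = 1.0, 0.3, 1.0

for M in (1, 10, 100):
    # Each of M agents produces an estimate with the same bias but independent noise.
    estimates = true_value + bias + noise_std * rng.standard_normal((100_000, M))
    averaged = estimates.mean(axis=1)           # the server-side average across agents
    print(f"M={M:>3}  mean error={averaged.mean() - true_value:+.3f}  "
          f"variance={averaged.var():.4f}")     # error stays ~ +0.3, variance ~ 1/M
```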

The second decomposition we're concerned with illustrates that Q-Learning induces a positive bias:

$$\mathbb E[Q_{t+1}] = \underbrace{(1-\eta_t)\mathbb E[Q_t] + \eta_t V(\mathbb E[Q_t])}_{\text{unbiased update}} + \underbrace{\eta_t\,\mathbb E\big[\tilde V(Q_t) - V(\mathbb E[Q_t])\big]}_{\text{bias } \geq 0}$$
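The source of that non-negative bias is the max inside the Bellman operator: a max over noisy estimates is optimistic, since $\mathbb E[\max] \geq \max \mathbb E$. A tiny synthetic demonstration (again, my own toy numbers):

```python
import numpy as np

rng = np.random.default_rng(1)
true_q = np.array([0.5, 0.5, 0.5])                      # three actions with identical true values
noisy_q = true_q + 0.2 * rng.standard_normal((100_000, 3))

print("max of true values:     ", true_q.max())                # 0.5
print("mean of max of noisy Qs:", noisy_q.max(axis=1).mean())  # noticeably > 0.5, i.e. positive bias
```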

Scheduling Update Synchronizations

These two expressions underscore the importance of understanding the evolution of bias throughout the process of federated learning.

Communication schedules between agents can be denoted

$\sigma = \{t_r\}_{r=1}^R$, consisting of $R$ rounds of updates, where:

$$Q^m_t = \begin{cases} \frac{1}{M}\sum^M_{j=1} Q^j_{t-\frac{1}{2}} & \text{if } t \in \sigma\\ Q^m_{t-\frac{1}{2}} & \text{otherwise} \end{cases}$$

If we consider the following plot of bias over time, we can see that within each communication round of length $\zeta$ (the interval between synchronizations), bias increases:

However, since we're averaging updates across all agents, overall bias is independent of $M$. For smaller values of $\zeta$, corresponding to more frequently scheduled update communications between agents, we can see that overall bias grows more slowly:

Similarly, larger values of $\zeta$ result in faster growth in bias:

Thus, bias is directly proportional to $\zeta$, which is inversely proportional to the communication complexity:

$$\text{bias} \propto \zeta \propto \frac{1}{\text{communication complexity}}$$

So, if $\zeta = \Omega((1-\gamma)T)$, or equivalently $\mathcal C_\circlearrowleft \lesssim \frac{1}{1-\gamma}$ (the number of communication rounds does not exceed the effective horizon), then the bias term dominates the variance term and we achieve zero collaborative gains.

Now, with a tight understanding of the lower bound, we can approach the second research question.

Achievability

How do we design an algorithm that offers optimal-order sample and communication complexities? Salgia introduces a new algorithm called Federated Doubly Variance Reduced Q-Learning (Fed-DVR-Q).

Notably, their algorithm operates on the optimal frontier of the sample–communication complexity tradeoff. It's the first algorithm to achieve optimal-order sample complexity w.r.t. all salient problem parameters, including the effective horizon, and it also employs quantization of policy updates to achieve optimal communication costs.

The algorithm is relatively straightforward, operating in terms of epochs. At each epoch $k$, the agents collaboratively update the policy estimate $Q^{k-1}$ to $Q^{k}$ such that the error between $Q^k$ and the optimal estimate $Q^*$ decreases by a factor of 2:

$$\|Q^k - Q^*\|_{\infty} \leq \frac{1}{2} \|Q^{k-1} - Q^*\|_{\infty}$$

If we repeat this for $\log \frac{1}{\varepsilon}$ epochs, we can guarantee an $\varepsilon$-optimal Q-function. The novelty of the algorithm lies in its scheduling strategy: most existing approaches opt for a small per-update sample size $b$ for the stochastic operator $\tilde V$ and take many local, albeit noisy, steps. This results in a lot of bias accumulation, which must be accounted for via frequent averaging during update aggregation.

Salgia inverts this relation, instead preferring a larger value of $b$, resulting in higher-fidelity estimates and allowing for less frequent communication. This is the key to decreasing communication complexity.
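A heavily simplified schematic of that epoch structure, under my own assumptions and omitting the paper's actual variance-reduction and quantization machinery (so this is not Fed-DVR-Q itself): each epoch uses a larger batch $b$, a single synchronization, and aims to halve the error. It reuses the hypothetical `q_learning_step` sketch from earlier:

```python
import numpy as np

def epoch_style_federated_q(agent_samplers, R, gamma, eta, S, A, eps, b0=64, seed=0):
    """Epoch-structured federated Q-learning schematic: roughly log2(1/eps) epochs,
    doubling the per-update batch size each epoch and syncing once per epoch."""
    rng = np.random.default_rng(seed)
    Q = np.zeros((S, A))
    n_epochs = int(np.ceil(np.log2(1.0 / eps)))
    for k in range(n_epochs):
        b = b0 * 2 ** k                                 # higher-fidelity V-tilde each epoch
        local_Qs = [q_learning_step(Q.copy(), R, sampler, gamma, eta, b, rng)
                    for sampler in agent_samplers]
        Q = np.mean(local_Qs, axis=0)                   # one synchronization per epoch
    return Q
```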

The second theorem introduced in the paper expresses the benefits of combining this approach with variance-reduced updates, stating that, for an implementation of the algorithm with the prescribed choice of parameters, and for any $\varepsilon \in (0, 1]$:

$$\begin{aligned} \mathcal C_\thicksim(\varepsilon, M) &\lesssim \frac{1}{M\varepsilon^2(1-\gamma)^3} \\ \mathcal C_{\circlearrowleft} &\lesssim \frac{1}{1-\gamma} \\ \mathcal C_{\vec{b}} &\lesssim \frac{|\mathcal S||\mathcal A|}{1-\gamma} \end{aligned}$$

Note that the sample complexity now depends on $M$ and exhibits the optimal cubic dependence w.r.t. the effective horizon. Additionally, the communication complexity matches the lower bound derived above.

Footnotes

  1. Salgia, Sudeep and Yuejie Chi. "The Sample-Communication Complexity Trade-off in Federated Q-Learning." arXiv, 2024.

  2. https://www.murphyandhislaw.com/tags/reinforcement-learning

  3. where statistical benefit would be a measurable decrease in sample complexity as the number of agents working on a problem increased.

  4. alternatively, reward functions can be computed by the agent according to the favorability of the subsequent state based on the agent's context. I.e. the reward can be agent-dependent rather than environment-dependent.

  5. meaning that the transition kernel and reward functions are the same for each agent.

  6. Here and throughout, I opt to use a loose inequality $\lesssim$ rather than introduce constants to make the inequalities proper. This is the same as just referring to the order of the complexity quantities.