Posted on Tue 22 December 2020

MuZero Intuition

To celebrate the publication of our MuZero paper in Nature (full-text), I've written a high-level description of the MuZero algorithm. My focus here is to give you an intuitive understanding and general overview of the algorithm; for the full details please read the paper. Please also see our official DeepMind blog post; it has great animated versions of the figures!

MuZero is a very exciting step forward - it requires no special knowledge of game rules or environment dynamics, instead learning a model of the environment for itself and using this model to plan. Even though it uses such a learned model, MuZero preserves the full planning performance of AlphaZero - opening the door to applying it to many real world problems!

It's all just statistics

MuZero is a machine learning algorithm, so naturally the first thing to understand is how it uses neural networks. From AlphaGo and AlphaZero, it inherited the use of policy and value networks1:

Schematic illustration of value and policy network mapping from a Go board to a value resp. policy estimate

Both the policy and the value have a very intuitive meaning:

  • The policy, written p(s, a), is a probability distribution over all actions a that can be taken in state s. It estimates which action is likely to be the optimal action. The policy is similar to the first guess for a good move that a human player has when quickly glancing at a game.

  • The value v(s) estimates the probability of winning from the current state s: averaging over all possible futures, weighted by how likely they are, in what fraction of them would the current player win?

Each of these networks on their own is already very powerful: If you only have a policy network, you could simply always play the move it predicts as most likely and end up with a very decent player. Similarly, given only a value network, you could always choose the move with the highest value. However, combining both estimates leads to even better results.
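To make this concrete, here is a small sketch of the two strategies in Python. The network stubs are purely illustrative stand-ins that return random numbers, not MuZero's actual interfaces, and apply_action is a hypothetical helper:

```python
import numpy as np

# Illustrative stand-ins for trained networks (NOT MuZero's actual interfaces):
# the policy returns a distribution over actions, the value an estimated
# probability of winning. Random outputs keep the sketch self-contained.
def policy_net(state, num_actions=362):           # 362 = 19*19 points + pass
    logits = np.random.randn(num_actions)
    return np.exp(logits) / np.exp(logits).sum()

def value_net(state):
    return np.random.rand()

def move_from_policy_only(state):
    # Always play the move the policy considers most likely to be best.
    return int(np.argmax(policy_net(state)))

def move_from_value_only(state, legal_actions, apply_action):
    # Evaluate every child position and pick the action whose resulting
    # state has the highest value; `apply_action` is a hypothetical helper.
    return max(legal_actions, key=lambda a: value_net(apply_action(state, a)))
```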

Planning to Win

Similar to AlphaGo and AlphaZero before it, MuZero uses Monte Carlo Tree Search2 (MCTS for short) to aggregate neural network predictions and choose actions to apply to the environment.

MCTS is an iterative, best-first tree search procedure. Best-first means expansion of the search tree is guided by the value estimates in the search tree. Compared to classic methods such as breadth-first (expand the entire tree up to a fixed depth before searching deeper) or depth-first (consecutively expand each possible path until the end of the game before trying the next), best-first search can take advantage of heuristic estimates (such as neural networks) to find promising solutions even in very large search spaces.

MCTS has three main phases: simulation, expansion and backpropagation. By repeatedly executing these phases, MCTS incrementally builds a search tree over future action sequences one node at a time. In this tree, each node is a future state, while the edges between nodes represent actions leading from one state to the next.

Before we dive into the details, let me introduce a schematic representation of such a search tree, including the neural network predictions made by MuZero:

diagram of the muzero search tree, and the use of representation, dynamics and prediction function

Circles represent nodes of the tree, which correspond to states in the environment. Lines represent actions, leading from one state to the next. The tree is rooted at the top, at the current state of the environment - represented by a schematic Go board. We will cover the details of representation, prediction and dynamics functions in a later section.

Simulation always starts at the root of the tree (light blue circle at the top of the figure), the current position in the environment or game. At each node (state s), it uses a scoring function U(s, a) to compare different actions a and choose the most promising one. The scoring function used in MuZero combines the prior estimate p(s, a) with the value estimate v(s, a):

U(s, a) = v(s, a) + c \cdot p(s, a)

where c is a scaling factor3 that ensures that the influence of the prior diminishes as our value estimate becomes more accurate.

Each time an action is selected, we increment its associated visit count n(s, a), for use in the UCB scaling factor c and for later action selection.
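As a small sketch in Python, passing the per-edge statistics as plain numbers and using the exact scaling factor from footnote 3:

```python
import math

def ucb_score(prior, value_estimate, visits, total_visits,
              c1=1.25, c2=19652.0):
    """U(s, a) for a single edge: value estimate plus prior, where the
    prior's weight c shrinks as the edge accumulates visits (footnote 3)."""
    c = (math.sqrt(total_visits) / (1 + visits)) * \
        (c1 + math.log((total_visits + c2 + 1) / c2))
    return value_estimate + c * prior
```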

Simulation proceeds down the tree until it reaches a leaf that has not yet been expanded; at this point the neural network is used to evaluate the node. Evaluation results (prior and value estimates) are stored in the node.

Expansion: Once a node has reached a certain number of evaluations, it is marked as "expanded". Being expanded means that children can be added to a node; this allows the search to proceed deeper. In MuZero, the expansion threshold is 1, i.e. every node is expanded immediately after it is evaluated for the first time. Higher expansion thresholds can be useful to collect more reliable statistics4 before searching deeper.

Backpropagation: Finally, the value estimate from the neural network evaluation is propagated back up the search tree; each node keeps a running mean of all value estimates below it. This averaging process is what allows the UCB formula to make increasingly accurate decisions over time, and so ensures that the MCTS will eventually converge to the best move.
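Putting the three phases together, a bare-bones sketch of the search loop could look as follows. It assumes the root has already been expanded, ignores intermediate rewards (see the next section), and bundles the network calls into a hypothetical recurrent_inference(state, action) that returns a value, a reward, a policy (a mapping from actions to priors) and the next hidden state; ucb_score is the helper sketched above:

```python
class Node:
    """One node of the search tree; children are indexed by action."""
    def __init__(self, prior):
        self.prior = prior        # p(s, a) assigned when the parent was expanded
        self.children = {}        # action -> Node
        self.visit_count = 0      # n(s, a) of the edge leading into this node
        self.value_sum = 0.0      # sum of all backed-up value estimates
        self.hidden_state = None  # set when this node is evaluated

    def value(self):
        # Running mean of the value estimates propagated through this node.
        return self.value_sum / self.visit_count if self.visit_count else 0.0

def select_child(node):
    # Pick the child with the highest U(s, a).
    total_visits = sum(c.visit_count for c in node.children.values())
    return max(node.children.items(),
               key=lambda kv: ucb_score(kv[1].prior, kv[1].value(),
                                        kv[1].visit_count, total_visits))

def run_mcts(root, network, num_simulations):
    for _ in range(num_simulations):
        node, search_path = root, [root]

        # Simulation: walk down the tree, always following the best edge,
        # until reaching a node that has not yet been expanded.
        while node.children:
            action, node = select_child(node)
            search_path.append(node)

        # Expansion: evaluate the leaf with the network, store its hidden
        # state and create one child per action with the predicted prior.
        # (The reward is ignored in this board-game-style sketch.)
        parent = search_path[-2]
        value, reward, policy, hidden_state = network.recurrent_inference(
            parent.hidden_state, action)
        node.hidden_state = hidden_state
        for a, p in policy.items():
            node.children[a] = Node(prior=p)

        # Backpropagation: update the running value mean along the path.
        for n in search_path:
            n.visit_count += 1
            n.value_sum += value
```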

Intermediate Rewards

The astute reader may have noticed that the figure above also includes the prediction of a quantity r. Some domains, such as board games, only provide feedback at the end of an episode (e.g. a win/loss result); they can be modeled purely through value estimates. Other domains however provide more frequent feedback; in the general case, a reward r is observed after every transition from one state to the next.

Directly modeling this reward through a neural network prediction and using it in the search is advantageous. It only requires a slight modification to the UCB formula:

U(s, a) = r(s, a) + \gamma \cdot v(s') + c \cdot p(s, a)

where r(s, a) is the reward observed in transitioning from state s by choosing action a, and γ is a discount factor that describes how much we care about future rewards.

Since in general rewards can have arbitrary scale, we further normalize the combined reward/value estimate to lie in the interval [0, 1] before combining it with the prior:

U(s, a) = \frac{r(s, a) + \gamma \cdot v(s') - q_{min}}{q_{max} - q_{min}} + c \cdot p(s, a)

where q_min and q_max are the minimum and maximum r(s, a) + γ · v(s') estimates observed across the search tree.
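In code, this is easily tracked with running min/max statistics that are updated throughout the search; a small sketch (class and helper names are illustrative):

```python
class MinMaxStats:
    """Tracks the minimum and maximum r(s, a) + gamma * v(s') seen so far."""
    def __init__(self):
        self.minimum, self.maximum = float('inf'), float('-inf')

    def update(self, q):
        self.minimum = min(self.minimum, q)
        self.maximum = max(self.maximum, q)

    def normalize(self, q):
        if self.maximum > self.minimum:
            return (q - self.minimum) / (self.maximum - self.minimum)
        return q  # no spread observed yet; leave the estimate unchanged

def normalized_ucb(reward, child_value, prior, c, gamma, stats):
    # Rescale the combined reward/value estimate to [0, 1] before adding
    # the prior term; `c` is the scaling factor from the earlier sketch.
    q = reward + gamma * child_value
    stats.update(q)
    return stats.normalize(q) + c * prior
```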

Episode Generation

The MCTS procedure described above can be applied repeatedly to play entire episodes:

  • Run a search in the current state s_t of the environment.
  • Select an action a_{t+1} according to the statistics π_t of the search.
  • Apply the action to the environment to advance to the next state s_{t+1} and observe reward u_{t+1}.
  • Repeat until the environment terminates.

generation of episodes by running MCTS in each state, selecting an action and advancing the environment
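Here is a sketch of this loop in Python, reusing run_mcts from above; the environment interface, the expand_root helper (which would apply the representation and prediction functions at the root) and select_action (sketched just below) are illustrative assumptions:

```python
def play_episode(environment, network, num_simulations=50):
    """Play one episode, running a full MCTS at every state and storing the
    data needed later for training."""
    trajectory = []
    observation = environment.reset()
    while not environment.terminal():
        # 1. Run a search rooted at the current state of the environment.
        root = expand_root(observation, network)
        run_mcts(root, network, num_simulations)

        # 2. Select an action from the visit-count statistics of the search.
        visit_counts = {a: child.visit_count
                        for a, child in root.children.items()}
        action = select_action(visit_counts, temperature=1.0)

        # 3. Apply the action and observe the next state and reward.
        observation, reward = environment.step(action)
        trajectory.append((observation, action, reward,
                           visit_counts, root.value()))
    return trajectory
```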

Action selection can either be greedy - select the action with the most visits - or exploratory: sample action a proportional to its visit count n(s, a), potentially after applying some temperature t to control the degree of exploration:

p(a) = \left( \frac{n(s, a)}{\sum_b n(s, b)} \right)^{1/t}

For t = 0, we recover greedy action selection; t = ∞ is equivalent to sampling actions uniformly.
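A possible implementation of this sampling rule - the exponentiated counts are renormalized so they form a proper distribution, and t = 0 is treated as a special case to avoid division by zero:

```python
import numpy as np

def select_action(visit_counts, temperature=1.0):
    """Sample an action from the root's visit counts, sharpened (small t)
    or flattened (large t) by the temperature."""
    actions, counts = zip(*visit_counts.items())
    counts = np.asarray(counts, dtype=np.float64)
    if temperature == 0.0:
        return actions[int(np.argmax(counts))]   # greedy: most-visited action
    probs = counts ** (1.0 / temperature)
    probs /= probs.sum()
    return actions[np.random.choice(len(actions), p=probs)]
```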

Training

Now that we know how to run MCTS to select actions, interact with the environment and generate episodes, we can turn towards training the MuZero model.

We start by sampling a trajectory and a position within it from our dataset, then we unroll the MuZero model alongside the trajectory:

training unrolls the muzero model along the trajectory

You can see the three parts of the MuZero algorithm in action:

  • the representation function h maps from a set of observations (the schematic Go board) to the hidden state s used by the neural network
  • the dynamics function g maps from a state s_t to the next state s_{t+1} based on an action a_{t+1}. It also estimates the reward r_t observed in this transition. This is what allows the learned model to be rolled forward inside the search.
  • the prediction function f makes estimates for the policy p_t and value v_t based on a state s_t. These are the estimates used by the UCB formula and aggregated in the MCTS.
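Written as an interface sketch (the types are deliberately loose; the real functions are deep neural networks whose architectures are described in the paper):

```python
from typing import Any, Dict, Tuple

class MuZeroFunctions:
    """Structural sketch of the three learned functions; the real
    implementations are neural networks and are omitted here."""

    def representation(self, observations: Any) -> Any:
        """h: map a stack of raw observations to the initial hidden state s_0."""
        raise NotImplementedError

    def dynamics(self, state: Any, action: int) -> Tuple[Any, float]:
        """g: map state s_t and action a_{t+1} to state s_{t+1}, also
        predicting the reward r_t of the transition."""
        raise NotImplementedError

    def prediction(self, state: Any) -> Tuple[Dict[int, float], float]:
        """f: estimate the policy p_t and value v_t for a hidden state."""
        raise NotImplementedError
```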

The observations and actions used as input to the network are taken from this trajectory; similarly the prediction targets for policy, value and reward are the ones stored with the trajectory when it was generated.

You can see this alignment between episode generation (B) and training (C) even more clearly in the full figure:

the three previous figures in a single picture

Specifically, the training losses for the three quantities estimated by MuZero are:

  • policy: cross-entropy between MCTS visit count statistics and policy logits from the prediction function.
  • value: cross-entropy or mean squared error between the n-step return target (the discounted sum of N rewards plus the stored search value or target network estimate) and the value from the prediction function5.
  • reward: cross-entropy between reward observed in the trajectory and dynamics function estimate.
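Putting unrolling and losses together, a single training step on a sampled position might look roughly like this; the helpers cross_entropy and scalar_loss as well as the exact target indexing are illustrative, and gradient bookkeeping is omitted:

```python
def muzero_loss(net, observations, actions, targets):
    """Unroll the model along len(actions) stored steps; `targets[k]` holds
    (target_value, target_reward, target_policy) for unroll step k, computed
    when the trajectory was generated (or reanalysed)."""
    state = net.representation(observations)
    policy_logits, value = net.prediction(state)
    loss = step_loss(policy_logits, value, None, targets[0])   # no reward at the root
    for k, action in enumerate(actions, start=1):
        state, reward = net.dynamics(state, action)
        policy_logits, value = net.prediction(state)
        loss += step_loss(policy_logits, value, reward, targets[k])
    return loss

def step_loss(policy_logits, value, reward, target):
    target_value, target_reward, target_policy = target
    total = cross_entropy(policy_logits, target_policy)    # policy loss
    total += scalar_loss(value, target_value)               # value loss
    if reward is not None:
        total += scalar_loss(reward, target_reward)         # reward loss
    return total
```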

Reanalyse

Having examined the core MuZero training, we are ready to take a look at the technique that allows us to leverage the search to achieve massive data-efficiency improvements: Reanalyse.

In the course of normal training, we generate many trajectories (interactions with the environment) and store them in our replay buffer for training. Can we get more mileage out of this data?

sequence of states representing an episode

Unfortunately, since this is stored data, we cannot change the states, actions or received rewards - this would require resetting the environment to an arbitrary state and continuing from there. Possible in The Matrix, but not in the real world.

Luckily, it turns out that we don't need to - using existing inputs with fresh, improved labels is enough for continued learning. Thanks to MuZero's learned model and the MCTS, this is exactly what we can do:

sequence of states with new MCTS trees at each state

We keep the saved trajectory (observations, actions and rewards) as is and instead only re-run the MCTS. This generates fresh search statistics, providing us with new targets for the policy and value prediction.

In the same way that searching with an improved network results in better search statistics when interacting with the environment directly, re-running the search with an improved network on saved trajectories also results in better search statistics, allowing for repeated improvements using the same trajectory data.

Reanalyse fits naturally into the MuZero training loop. Let's start with the normal training loop:

diagram of actors and learners exchanging data during training

We have two sets of jobs that communicate with each other asynchronously:

  • a learner that receives the latest trajectories, keeps the most recent of these in a replay buffer and uses them to perform the training algorithm described above.
  • multiple actors which periodically fetch the latest network checkpoint from the learner, use the network in MCTS to select actions and interact with the environment to generate trajectories.

To implement reanalyse, we introduce two jobs:

previous diagram extended with reanalyse actors

  • a reanalyse buffer that receives all trajectories generated by the actors and keeps the most recent ones.
  • multiple reanalyse actors6 that sample stored trajectories from the reanalyse buffer, re-run MCTS using the latest network checkpoints from the learner and send the resulting trajectories with updated search statistics to the learner.

For the learner, "fresh" and reanalysed trajectories are indistinguishable; this makes it very simple to vary the proportion of fresh vs reanalysed trajectories.
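Here is a sketch of the actor side, following the single-set-of-actors variant mentioned in footnote 6; rerun_mcts as well as the buffer and learner interfaces are illustrative placeholders, and play_episode is the sketch from the Episode Generation section:

```python
import random

def actor_loop(environment, learner, reanalyse_buffer, reanalyse_fraction=0.5):
    """At the start of each episode, either generate a fresh trajectory or
    reanalyse a stored one; the learner treats both identically."""
    while True:
        network = learner.latest_checkpoint()
        if reanalyse_buffer and random.random() < reanalyse_fraction:
            # Reanalyse: keep the stored observations, actions and rewards,
            # only re-run MCTS to refresh the policy/value targets.
            trajectory = rerun_mcts(reanalyse_buffer.sample(), network)
        else:
            # Fresh data: interact with the real environment.
            trajectory = play_episode(environment, network)
            reanalyse_buffer.add(trajectory)
        learner.send(trajectory)
```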

What's in a name?

MuZero's name is of course based on AlphaZero - keeping the Zero to indicate that it was trained without imitating human data, and replacing Alpha with Mu to signify that it now uses a learned model to plan.

Digging a little deeper, we find that Mu is rich in meaning:

  • 夢, which can be read as mu in Japanese, means 'dream' - just like MuZero uses the learned model to imagine future scenarios.
  • the Greek letter μ, pronounced mu, can also stand for the learned model.
  • 無, pronounced mu in Japanese, means 'nothing' - doubling down on the notion of learning from scratch: not just no human data to imitate, but not even provided with the rules.

Final Words

I hope this summary of MuZero was useful!

If you are interested in more details, start with the full paper. I also gave talks about MuZero at NeurIPS (poster) and most recently at ICAPS.

Let me finish by linking some articles, blog posts and GitHub projects from other researchers that I found interesting:


  1. For simplicity, in MuZero both of these predictions are made by a single network, the prediction function. 

  2. Introduced by Rémi Coulom in Efficient Selectivity and Backup Operators in Monte-Carlo Tree Search, 2006, MCTS led to a major improvement in the playing strength of all Go playing programs. "Monte Carlo" in MCTS refers to the random playouts used in Go playing programs at the time, estimating the chance of winning in a particular position by playing random moves until the end of the game. 

  3. The exact scaling used in MuZero is \frac{\sqrt{\sum_b n(s, b)}}{1 + n(s, a)} \cdot \left(c_1 + \log\left(\frac{\sum_b n(s, b) + c_2 + 1}{c_2}\right)\right), where n(s, a) is the number of visits for action a from state s, and c_1 = 1.25 and c_2 = 19652 are constants that control the importance of the prior relative to the value estimate. Note that for c_2 \gg n, the exact value of c_2 is not important and the \log term becomes 0. In this case, the formula simplifies to c_1 \cdot \frac{\sqrt{\sum_b n(s, b)}}{1 + n(s, a)}. 

  4. This is most useful when using stochastic evaluation functions such as random rollouts as used by many Go programs before AlphaGo. If the evaluation function is deterministic (such as a standard neural network), evaluating the same nodes multiple times is less useful. 

  5. For board games, the discount γ is 1 and the number of TD steps is infinite, so this is just prediction of the Monte Carlo return (the winner of the game). 

  6. In our implementation of MuZero, there is no separate set of actors for reanalyse: We have a single set of actors which decide at the start of each episode whether to start a fresh trajectory interacting with the environment or to reanalyse a stored trajectory. 

Tags: ai, muzero, rl

pnorridge
1 point
10 months ago

Hi,

You don’t mention it here but I’ve been trying to understand why a categorical representation of the value function helps. Is there some theory for that?

Thanks! Paul

Julian Schrittwieser
0 points
10 months ago

The categorical representation, or more specifically the cross-entropy loss, ensures that the scale of the gradients is independent of the magnitude of the value to be predicted. This is especially important when you have values varying by many orders of magnitude, e.g. a max episode return of 21 in Pong and over a million in Atlantis, and makes learning much more robust.

See also the C51 (https://arxiv.org/abs/1707.06887) and IQN papers (https://arxiv.org/abs/1806.06923).

Ferran Alet
0 points
9 months ago

Hi Julian,

Congrats on the very nice paper! At NeurIPS someone from DeepMind said they reproduced MuZero in JAX and you also recommend it somewhere else in your website. Do you know if this implementation is open-sourced or whether there are plans to do so?

Thanks!

Julian Schrittwieser
0 points
9 months ago

Yeah, we are using JAX for our current MuZero implementation and research - we've been using JAX for about a year now and it's very pleasant to use :)

I'm not aware of an open source implementation yet, but you should be able to use some of the existing codebases and replace the learning part with JAX.

Ferran Alet
0 points
9 months ago

The NeurIPS talk mentioned something about using JAX to parallelize (some parts of?) MCTS as well. Did I understand this incorrectly, or did you mean that replacing only the learning part with JAX would be good enough and correct, even though slower?

Thanks!

Ferran

Eljas Hyyrynen
0 points
9 months ago

Hi Julian!

Congratulations for your good work! I am lucky to have found your amazing blog.

I am implementing MuZero for the game of Go. I couldn't find some details about the representation network in the paper. To be more precise:

  • What size is the hidden state outputted by the representation network? Is it the same as the input or is the third dimension shrunk? Meaning, it gets 19x19x17 in and outputs maybe 19x19x1? It must be smaller as we compress the essential information.

These are kind of minor questions as the beef of the research is the MCTS and using the three networks jointly. But nonetheless this minor detail kind of blocks me from going further in my implementation.

Actually I am not able to train the network for 19x19 board at my home so I'm trying to train it for 9x9 board instead. I don't know if I will ever succeed but at least I will learn a lot in the process :)

Danke Schön, Eljas Hyyrynen

Julian Schrittwieser
0 points
9 months ago

Hi Eljas,

We use 256 hidden planes, you can find more on this in the methods section of the paper. For your experiments at home I'd recommend starting with something smaller so it's faster to train, maybe 128 or 64 hidden planes.

Viel Spaß! Julian

Mbwenga Maliti
0 points
6 months ago

Hi Julian,

Thanks for the contribution. It’s results like these that made me excited about getting into AI. I have a few questions if you have the time to answer them:

  1. What is the intuition behind the policy loss? Aren’t both terms in one way or another generated by the network and not the environment?

  2. Is MuZero only able to select among moves in a game that are legal (e.g. it can’t place a pawn outside the board)? If so, do you have any ideas on how to tackle games where the legal moves can’t be described neatly, such as NLP?

  3. Have you seen any use of MuZero in NLP?

Thanks for reading, Mbwenga

Julian Schrittwieser
1 point
6 months ago

The intuition behind the policy loss is that the combination of many value estimates inside MCTS (guided by the policy) leads to an improved policy, which we can then try to train towards. You are right that this involves the network on both sides!

MuZero can work with or without illegal moves; if it is permitted to perform illegal moves then it will learn that such moves lead to a low value and avoid them.

Christoph Heindl
0 points
10 months ago

Hey Julian,

great post - thanks! In the simulation paragraph, there is a typo (I believe) in that it should read v(s) not v(s').

Best, Christoph

Julian Schrittwieser
0 points
10 months ago

Hi Christoph,

thanks for pointing this out. What I wanted to indicate with v(s') was the value of the child that is reached from state s when applying the action a under consideration; I've now changed this to v(s, a), which is hopefully clearer.
