A paper presentation I virtually gave to Kilian Weinberger’s group at Cornell about FAIR’s recent Training Large Language Models to Reason in a Continuous Latent Space paper.
Some points from the paper and the discussion:
This really seems more like a distillation technique for a very targeted domain where a “step” is a very concrete idea.
- This would explain the GSM8k result, since that is a much more open-ended reasoning dataset with more emphasis on natural language understanding, and has less of a “nicely constructed graph” feel.
The pretrained model they start with has only ever been trained on text, so at every inference step it has had to compress all of its internal embedding manipulations into a single token, and then reason over those tokens in context. That mechanism likely still lets it learn to think in embedding space, but there is complexity and subtlety in the token -> context step (a rough sketch of the mechanism follows the list below).
- Now through this continuous thought token training, it has a view into the full embedding space. This isn’t something it normally sees, and is potentially a lot more information rich.
- This might explain why they have to do such an extensive multi-stage, many-epoch training routine; the model needs to be annealed into understanding the continuous thought “tokens” in the same domain as the text tokens.
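To make the mechanism concrete, here is a minimal sketch of continuous-thought generation, not the paper’s actual implementation: instead of sampling a discrete token and re-embedding it, the last hidden state is appended directly to the input sequence as the next “embedding.” The `model`, `embed`, and `lm_head` interfaces below are assumptions (a decoder-only transformer that returns last-layer hidden states, its token embedding table, and its output projection).

```python
import torch

# Minimal sketch of the continuous-thought mechanism (not the paper's code).
# Assumptions: `model` maps a (1, seq, d_model) tensor of input embeddings to
# last-layer hidden states of the same shape; `embed` is the token embedding
# table; `lm_head` projects a hidden state onto the vocabulary.
def generate_with_continuous_thoughts(model, embed, lm_head, prompt_ids,
                                      num_thoughts=4, max_new_tokens=32):
    inputs = embed(prompt_ids)                       # (1, prompt_len, d_model)

    # Latent phase: the last hidden state is appended directly as the next
    # input embedding, so the "thought" is never collapsed onto a single token.
    for _ in range(num_thoughts):
        hidden = model(inputs)                       # (1, seq, d_model)
        thought = hidden[:, -1:, :]                  # continuous thought vector
        inputs = torch.cat([inputs, thought], dim=1)

    # Text phase: ordinary greedy decoding over the vocabulary.
    out_ids = []
    for _ in range(max_new_tokens):
        hidden = model(inputs)
        next_id = lm_head(hidden[:, -1, :]).argmax(dim=-1, keepdim=True)
        out_ids.append(next_id)
        inputs = torch.cat([inputs, embed(next_id)], dim=1)
    return torch.cat(out_ids, dim=1)
```

The point of the sketch is that the latent phase never projects onto the vocabulary, which is exactly the “view into the full embedding space” described above.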
Someone in the discussion questioned the effectiveness of merging the token embedding space with the “reasoning” embedding space, but I actually don’t think this is a problem. The model is already doing some variation of these manipulations anyway; this methodology just removes an artificial constraint.
- I could see a future where this type of model is a lot more efficient, not only because it doesn’t have to output grammatically correct sequences of tokens to explain simple transformations, but also because removing that constraint makes the modeling problem itself easier.
I really wish they had further explored the binary classifier for switching between continuous thoughts and text outputs. This seems like a natural way for the model to decide when to display outputs (a rough sketch follows below).
- Potentially in a continuous, constantly streaming way, where the model takes in new information from the outside world through something like function-calling, then does some continuous thinking, and then outputs intermediate results through text or other function-calls.
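Since the paper doesn’t take this idea very far, here is a hedged sketch of what such a switch could look like: a small binary head over the last hidden state that decides, at each step, whether to keep thinking in latent space or to emit a text token. The class name, threshold, and the `model`/`embed`/`lm_head` interfaces are all hypothetical, not the paper’s implementation.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of a latent-vs-text switch: a tiny binary head over the
# last hidden state decides, step by step, whether to keep thinking in latent
# space or to emit a text token. Names and interfaces here are assumptions.
class LatentOrTextGate(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)

    def forward(self, hidden_state):
        # Probability that the model should emit text at this step.
        return torch.sigmoid(self.proj(hidden_state))

def reasoning_step(model, gate, embed, lm_head, inputs, threshold=0.5):
    hidden = model(inputs)                           # (1, seq, d_model)
    last = hidden[:, -1:, :]
    if gate(last).item() < threshold:
        # Stay latent: feed the hidden state back as the next input embedding.
        return torch.cat([inputs, last], dim=1), None
    # Switch to text: decode a token and re-embed it as usual.
    next_id = lm_head(last[:, -1, :]).argmax(dim=-1, keepdim=True)
    return torch.cat([inputs, embed(next_id)], dim=1), next_id
```

In the streaming picture above, a gate like this would be what decides when to drop back into text (or a function call) and surface intermediate results.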
The paper has some good ideas, but it is more like a distillation method than outright training. No attempt is made to show how this generalizes, likely because it doesn’t at all.
However, there is a kernel of something really good in there. The continuous thoughts are probably a step in the right direction and are closer to where this field will likely end up.