Coconut

Training Large Language Models to Reason in a Continuous Latent Space

Andrew Dinhobl

General Concept


- 2 modes: language mode and latent mode
- use <bot> and <eot> tokens to demarcate thoughts



Equations

  • \(x = (x_1, ..., x_T)\) is the input sequence of tokens up to time \(T\)
  • \(e(\cdot)\) is the token embedding function
  • \(E_t = [e(x_1), e(x_2), ..., e(x_t)]\) sequence of token embeddings up to position \(t\)
  • \(H_t = \text{Transformer}(E_t)\); \(H_t \in \mathbb{R}^{t \times d}\)
  • \(h_t\) is last hidden state at position \(t\); \(h_t=H_t[t, :]\)
  • Transformer model \(M(x_{t+1}\mid x_{\leq t}) = \text{softmax}(Wh_t)\)
  • \(W\) is the parameter matrix of the language model head

In latent mode (between <bot> at position \(i\) and <eot> at position \(j\)), the last hidden states replace token embeddings as inputs:

\(E_t=[e(x_1), e(x_2), ..., e(<bot>), h_i, h_{i+1}, ..., h_{j-1}, e(<eot>), ..., e(x_t)]\)
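A minimal sketch of this latent-mode update, assuming a HuggingFace GPT-2 (the paper's base model) and using `inputs_embeds` to feed hidden states back in; the prompt and `num_thoughts` are illustrative, not from the paper's code:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tok = GPT2Tokenizer.from_pretrained("gpt2")

# E_t = [e(x_1), ..., e(x_t)]: ordinary token embeddings for the prompt
# (<bot>/<eot> would be new special tokens; omitted here for brevity)
ids = tok("2 + 3 * 4 =", return_tensors="pt")["input_ids"]
embeds = model.transformer.wte(ids)

num_thoughts = 3  # illustrative; k * c in the paper's notation
with torch.no_grad():
    for _ in range(num_thoughts):
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        h_t = out.hidden_states[-1][:, -1:, :]    # last hidden state h_t
        embeds = torch.cat([embeds, h_t], dim=1)  # fed back as the next "embedding"
```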


Interpretability

  • Interpretability 🐘: Is this good?
    • Performance ↑, interpretability ↓
    • CoT / R1: “Wow! I can see what the model is thinking!”
    • Interpretability was already questionable

Datasets

  • 3 datasets: augmented GSM8k + 2 constructed reasoning datasets (ProntoQA and ProsQA)
  • Note that a step is a reasoning step composed of many tokens


  • GSM8k is augmented with additional synthetic data and NL instructions


How do they train it?

  • pre-trained GPT-2 models
  • “multi-stage training curriculum”


- \(n + 1\) sequential forward passes per example at each stage with \(n\) latent thoughts, since each thought's input is the previous pass's hidden state
- no training loss on the latent thoughts; the sequential passes make training hard to parallelize (sketch below)
- decomposes training into easier objectives
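A sketch of one such training step, under the same HuggingFace GPT-2 assumptions as the sketch above; the function and variable names are illustrative, not from the paper's code:

```python
import torch
import torch.nn.functional as F

def coconut_loss(model, wte, prefix_embeds, target_ids, n_thoughts):
    """prefix_embeds: embeddings of question + <bot> (loss-masked);
    target_ids: <eot> + remaining text steps + answer (supervised)."""
    embeds = prefix_embeds
    for _ in range(n_thoughts):  # n sequential passes, one per thought
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        h_t = out.hidden_states[-1][:, -1:, :]
        embeds = torch.cat([embeds, h_t], dim=1)
    # pass n + 1: run over the full sequence and score the text tokens
    embeds = torch.cat([embeds, wte(target_ids)], dim=1)
    logits = model(inputs_embeds=embeds).logits
    pred = logits[:, -target_ids.size(1) - 1:-1, :]  # logits at p predict token p+1
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)),
                           target_ids.reshape(-1))
```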


Steps

  1. Data: (question, CoT for \(k\) steps, answer)
  2. In the initial stage, the model is trained on all steps of the data, like normal
  3. At the \(k\)th stage, the first \(k\) steps are replaced with \(k \times c\) continuous thoughts (see the construction sketch after this list)
    • \(c\) is a hyperparameter controlling the continuous-thoughts-to-text-steps ratio
    • they insert the <bot> and <eot> tokens around the thoughts
  4. reset the optimizer state between training stages
  5. Optimize the negative log-likelihood loss; mask the question and latent thoughts out of the loss
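A pure-Python sketch of this stage-\(k\) example construction, my reading of the recipe above; `<LATENT>` is a placeholder the training loop fills with hidden states, and all names are illustrative:

```python
def build_stage_example(question, cot_steps, answer, k, c):
    """question/answer: token lists; cot_steps: list of token lists."""
    thoughts = ["<bot>"] + ["<LATENT>"] * (k * c)  # filled with h_t at train time
    remaining = [t for step in cot_steps[k:] for t in step]  # steps k.. kept as text
    tokens = question + thoughts + ["<eot>"] + remaining + answer
    # NLL loss only on <eot>, the remaining text steps, and the answer;
    # the question and the latent thoughts are masked out
    n_masked = len(question) + len(thoughts)
    loss_mask = [0] * n_masked + [1] * (len(tokens) - n_masked)
    return tokens, loss_mask
```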


Inference

  • essentially the normal inference procedure, slightly modified
  • What triggers the <eot> token? Two options
    • train a binary classifier on the latent thoughts to let the model decide when to emit <eot>
    • “pad latent thoughts to constant length” (==constant number of thoughts?)
  • They use the latter after finding the two are comparable (decoding sketch below)
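A sketch of that inference loop with a fixed number of thoughts, reusing the latent-mode pattern from the earlier sketch; `wte`, `eot_id`, and `max_new` are illustrative names (EOS stopping omitted):

```python
import torch

@torch.no_grad()
def generate(model, wte, prompt_embeds, eot_id, num_thoughts, max_new=32):
    embeds = prompt_embeds                    # question + e(<bot>)
    for _ in range(num_thoughts):             # latent mode: fixed length
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        h_t = out.hidden_states[-1][:, -1:, :]
        embeds = torch.cat([embeds, h_t], dim=1)
    eot = wte(torch.tensor([[eot_id]]))       # e(<eot>) switches back to language mode
    embeds = torch.cat([embeds, eot], dim=1)
    out_ids = []
    for _ in range(max_new):                  # language mode: greedy decoding
        logits = model(inputs_embeds=embeds).logits[:, -1, :]
        nxt = logits.argmax(dim=-1, keepdim=True)   # argmax of softmax(W h_t)
        out_ids.append(nxt.item())
        embeds = torch.cat([embeds, wte(nxt)], dim=1)
    return out_ids
```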

Results

Math Reasoning
  • \(c = 2\)
  • They train for stages 0, 1, 2, 3 and an additional stage
    • additional stage uses \(3 \times c\) continuous thoughts, but removes rest of language steps
    • stage 0 for 6 epochs, 3 epochs for other stages
Logical Reasoning
  • \(c = 1\)
  • They train for stages 0 - 6, 5 epochs per
Both
  • Train the final stage for 50 epochs; use validation loss to select the checkpoint

  • iCoT (internalized CoT)?

    • somewhat similar; almost like ablating out the continuous thoughts by gradually removing earlier reasoning tokens during training, then predicting just the answer at inference
  • What are pause tokens?

    • additional filler tokens, like <pause> or ..., that improve performance on some tasks
  • the authors don’t really speculate on the GSM8k result

  • on the “highly branching”, randomly generated ProsQA, CoT doesn’t really improve over No-CoT \(\rightarrow\) latent reasoning helps



  • batch size 1
  • they note that wall-clock time \(\propto\) the number of newly generated tokens
  • used the HuggingFace `transformers` library for inference

Understanding Distributed Reasoning

  • “tree” vs “chain”
  • they train a probe to decode the latent thought in the \(c = 1\) setting (sketch below)
  • e.g. \(3 \times 3 = 9\) and \(3 \times 60 = 180\) together receive \(\approx 55\%\) of the probability mass
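One zero-training stand-in for that kind of probing (the note above says they train a probe; this sketch instead decodes directly with the LM head, i.e. \(\text{softmax}(Wh)\)), assuming `thought` is a latent hidden state from the sketches above:

```python
# decode a continuous thought with the LM head: softmax(W h)
probs = torch.softmax(model.lm_head(thought), dim=-1).squeeze()
top = probs.topk(5)
for p, i in zip(top.values, top.indices):
    print(f"{tok.decode([i.item()])!r}  {p.item():.3f}")  # e.g. '9', '180', ...
```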



- you’ll notice there are many dead ends



  • Test-time inference with \(k \in \{0, 1, 2, 3, 4, 5, 6\}\) thoughts, outputting the rest in NL
  • the Correct Label and Correct Path rates depend on how much of the reasoning is in continuous vs. NL thoughts
  • Training for this stage mixes data from other stages
  • “As more reasoning in continuous thoughts, Correct Path and Correct Label increase” \(\rightarrow\) better latent space reasoning
    • allows the model to delay hard decisions and “eliminate incorrect options”


  • after \(k\) latent thoughts, they analyze the next-step probabilities of child and grandchild nodes
  • a node's probability is computed from the conditional token probabilities of its name (sketch after this list)
  • they interpret this as the model’s “implicit value function”
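A sketch of that estimate, where a node's score is the product of the per-token conditionals of its name; `prefix_ids`/`node_ids` are illustrative tensors of token ids (with latent thoughts in the prefix you'd pass `inputs_embeds` instead of `input_ids`):

```python
import math
import torch

@torch.no_grad()
def node_prob(model, prefix_ids, node_ids):
    """P(node | prefix) as the product of per-token conditionals."""
    ids = torch.cat([prefix_ids, node_ids], dim=1)
    logp = torch.log_softmax(model(input_ids=ids).logits, dim=-1)
    total = 0.0
    for j in range(node_ids.size(1)):
        pos = prefix_ids.size(1) + j - 1  # logits at pos predict token pos + 1
        total += logp[0, pos, node_ids[0, j]].item()
    return math.exp(total)
```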



- x-axis is percentile of test cases
- “can encode several potential reasoning steps simultaneously and progressively eliminate incorrect paths”
- they refer to another paper laying the groundwork for distributional reasoning

Thoughts

  • Lots of room to improve training procedure. RL?
  • Still requires natural language reasoning chains
  • Thought granularity: training is based on NL steps
  • how well does this training procedure generalize to new problem spaces?