Transformers are Graph Neural Networks

1 Transformer

1.1 Representation Learning for NLP

  1. NNs build representations of input data as vectors/embeddings, which encode useful statistical and semantic information about the data.
  2. For NLP:
    1. Recurrent Neural Networks (RNNs) build representations sequentially, i.e., one word at a time.
    2. Transformers use an attention mechanism to figure out how important every other word in the sentence is w.r.t. the word being updated.
      • The updated word feature is a weighted sum of linear transformations of the features of all the words (a minimal contrast of the two update styles is sketched below).
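A minimal PyTorch sketch contrasting the two update styles; the tensor sizes and the use of raw dot products as stand-in attention scores are illustrative assumptions, not the full models.

```python
import torch

n, d = 5, 8                       # 5 words with 8-dimensional features (illustrative)
words = torch.randn(n, d)         # stand-in word embeddings

# RNN-style: build the representation one word at a time, carrying a hidden state.
W_h, W_x = torch.randn(d, d), torch.randn(d, d)
h = torch.zeros(d)
for x in words:                   # strictly sequential: word t waits for word t-1
    h = torch.tanh(W_h @ h + W_x @ x)

# Transformer-style: update every word in parallel as a weighted sum of
# linear transformations of the features of all the words.
V = torch.randn(d, d)
weights = torch.softmax(words @ words.T, dim=-1)   # how important is word j to word i
updated = weights @ (words @ V.T)                  # all n words updated at once
```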

1.2 Breaking down the Transformer

Figure: The Transformer architecture.

  1. The hidden feature $h_i^{\ell}$ of the $i$-th word in a sentence is updated from layer $\ell$ to layer $\ell+1$ as:
     $$h_i^{\ell+1} = \sum_{j \in \mathcal{S}} w_{ij} \left( V^{\ell} h_j^{\ell} \right), \qquad w_{ij} = \text{softmax}_j \left( Q^{\ell} h_i^{\ell} \cdot K^{\ell} h_j^{\ell} \right)$$
    • $j \in \mathcal{S}$ is the set of words in the sentence; $Q^{\ell}, K^{\ell}, V^{\ell}$ are learnable linear weights (denoting the Query, Key and Value for the attention computation, respectively).
    • Figure: the attention mechanism pipeline.
  2. Multi-head attention mechanism: run $K$ attention heads in parallel and concatenate their outputs:
     $$h_i^{\ell+1} = \text{Concat}\left( \text{head}_1, \ldots, \text{head}_K \right) O^{\ell}, \qquad \text{head}_k = \text{Attention}\left( Q^{k,\ell} h_i^{\ell},\, K^{k,\ell} h_j^{\ell},\, V^{k,\ell} h_j^{\ell} \right)$$
    • $O^{\ell}$ is a down-projection to match the dimensions of $h_i^{\ell+1}$ and $h_i^{\ell}$ across layers.
    • Motivation: bad random initializations for the dot-product attention mechanism can de-stabilize the learning process (a minimal PyTorch sketch of multi-head attention follows this list).
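A minimal PyTorch sketch of the multi-head attention update above; fusing the per-head weights into single linear layers and omitting masking/dropout are simplifying assumptions.

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_head = d_model // n_heads
        self.n_heads = n_heads
        # Q^l, K^l, V^l: learnable linear weights (one set per head, fused here)
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        # O^l: down-projection back to d_model after concatenating the heads
        self.o = nn.Linear(d_model, d_model)

    def forward(self, h):                      # h: (n_words, d_model)
        n, _ = h.shape
        # split into heads: (n_heads, n_words, d_head)
        q = self.q(h).view(n, self.n_heads, self.d_head).transpose(0, 1)
        k = self.k(h).view(n, self.n_heads, self.d_head).transpose(0, 1)
        v = self.v(h).view(n, self.n_heads, self.d_head).transpose(0, 1)
        # w_ij = softmax_j(Q h_i . K h_j / sqrt(d))
        w = torch.softmax(q @ k.transpose(-2, -1) / self.d_head ** 0.5, dim=-1)
        out = (w @ v).transpose(0, 1).reshape(n, -1)   # concatenate the heads
        return self.o(out)                             # apply O^l

h = torch.randn(6, 32)                      # 6 words, d_model = 32
print(MultiHeadAttention(32, 4)(h).shape)   # torch.Size([6, 32])
```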

1.3 Scale Issues and the Feed-forward Sub-layer

  1. Motivation: the features for words after the attention mechanism might be at different scales or magnitudes.
    1. Issue 1: some words may have very sharp or very distributed attention weights.
      • Scaling the dot-product attention by the square root of the feature dimension: $w_{ij} = \text{softmax}_j \left( \frac{Q^{\ell} h_i^{\ell} \cdot K^{\ell} h_j^{\ell}}{\sqrt{d}} \right)$.
    2. Issue 2: each attention head may output values at different scales.
      • LayerNorm: normalizes and learns an affine transformation at the feature level.
  2. Another 'trick' to control the scale issue: a position-wise 2-layer MLP applied to each word's features, wrapped in LayerNorm:
     $$h_i^{\ell+1} = \text{LN}\left( \text{MLP}\left( \text{LN}\left( h_i^{\ell+1} \right) \right) \right)$$
  3. Stacking layers: residual connections between the inputs and outputs of each multi-head attention sub-layer and the feed-forward sub-layer (a minimal Transformer block combining these pieces is sketched below).
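A minimal PyTorch sketch wiring these pieces into one Transformer block (self-attention, residual connections, LayerNorm, and the position-wise MLP), roughly following the post-LN layout implied by the formula above; `nn.MultiheadAttention` stands in for the attention sub-layer and the hidden sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)   # normalize + learned affine transform
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(            # position-wise 2-layer MLP
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model),
        )

    def forward(self, h):                    # h: (batch, n_words, d_model)
        attn_out, _ = self.attn(h, h, h)     # self-attention over the sentence
        h = self.norm1(h + attn_out)         # residual connection + LayerNorm
        h = self.norm2(h + self.mlp(h))      # residual connection + LayerNorm
        return h

h = torch.randn(2, 6, 32)                    # 2 sentences, 6 words, d_model = 32
print(TransformerBlock(32, 4, 64)(h).shape)  # torch.Size([2, 6, 32])
```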

2 GNNs

  1. Neighborhood aggregation (or [[message passing]])
    • Each node $i$ gathers features from its neighbors $j \in \mathcal{N}(i)$ to update its representation of the local graph structure around it (a minimal message-passing layer is sketched after this list):
      $$h_i^{\ell+1} = \sigma\left( U^{\ell} h_i^{\ell} + \sum_{j \in \mathcal{N}(i)} V^{\ell} h_j^{\ell} \right)$$
      1. $U^{\ell}, V^{\ell}$ are learnable weight matrices of the GNN layer.
      2. $\sigma$ is a non-linearity such as ReLU.
      3. The summation over the neighborhood nodes $j \in \mathcal{N}(i)$ can be replaced by other input size-invariant aggregation functions:
        • Mean
        • Max
        • Weighted sum via an attention mechanism
  2. Stacking several GNN layers enables the model to propagate each node's features over the entire graph.
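A minimal PyTorch sketch of this neighborhood-aggregation update with sum aggregation; the dense adjacency matrix is an illustrative simplification (GNN libraries use sparse message passing instead).

```python
import torch
import torch.nn as nn

class GNNLayer(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.U = nn.Linear(d, d, bias=False)   # U^l: transform the node itself
        self.V = nn.Linear(d, d, bias=False)   # V^l: transform the neighbours

    def forward(self, h, adj):
        # h: (n_nodes, d); adj: (n_nodes, n_nodes) with adj[i, j] = 1 if j in N(i)
        neighbour_sum = adj @ self.V(h)              # sum_{j in N(i)} V^l h_j
        return torch.relu(self.U(h) + neighbour_sum) # sigma = ReLU

# toy graph: 4 nodes connected in a ring
adj = torch.tensor([[0., 1, 0, 1],
                    [1, 0, 1, 0],
                    [0, 1, 0, 1],
                    [1, 0, 1, 0]])
h = torch.randn(4, 8)
print(GNNLayer(8)(h, adj).shape)                     # torch.Size([4, 8])
```

Replacing the sum with a mean corresponds to row-normalizing `adj`; replacing it with an attention-weighted sum gives the attention-based aggregation mentioned above.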

3 Transformers Are GNNs


3.1 Sentences Are Fully-connected Word Graphs

  1. Consider a sentence as a fully-connected graph, where each word is connected to every other word.

  2. Transformers are GNNs with multi-head attention as the neighborhood aggregation function.

    1. Transformers: aggregate over the entire fully-connected graph of words.
    2. GNNs: aggregate over the local neighborhood of each node (the dense-adjacency sketch below makes the correspondence concrete).
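A minimal sketch of the correspondence: the attention weights $w_{ij}$ form a dense, row-normalized 'adjacency matrix' over the fully-connected word graph, so one attention update is a weighted-sum aggregation over every word. The sizes and random weights are illustrative assumptions.

```python
import torch

n, d = 6, 16
h = torch.randn(n, d)                           # word features
Q, K, V = (torch.randn(d, d) for _ in range(3)) # stand-in Query/Key/Value weights

w = torch.softmax((h @ Q) @ (h @ K).T / d ** 0.5, dim=-1)  # (n, n) soft adjacency
h_new = w @ (h @ V)                             # aggregate over every 'neighbour'

print(w.shape, torch.allclose(w.sum(dim=-1), torch.ones(n)))  # each row sums to 1
```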

3.2 What Can We Learn from Each Other?

3.2.1 Are Fully-connected Graphs the Best Input Format for NLP?

  • Linguistic structure
    1. Syntax trees/graphs.
    2. Tree LSTMs.

3.2.2 How to Learn Long-term Dependencies?

  1. Fully-connected graphs make learning very long-term dependencies between words difficult.
    • In an $n$-word sentence, a Transformer/GNN would be doing computations over $n^2$ pairs of words.
  2. Making the attention mechanism sparse or adaptive in terms of input size (a toy local-window attention mask is sketched after this list), e.g.:
    1. Adding recurrence or compression into each layer.
    2. Locality Sensitive Hashing.
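A minimal sketch of one way to make attention sparse: a fixed local-window mask. The window size and masking scheme are illustrative assumptions, not the specific sparse/adaptive or LSH-based methods referenced above.

```python
import torch

n, d, window = 8, 16, 2
h = torch.randn(n, d)
Q, K, V = (torch.randn(d, d) for _ in range(3))

scores = (h @ Q) @ (h @ K).T / d ** 0.5                 # (n, n) dense attention scores
idx = torch.arange(n)
mask = (idx[:, None] - idx[None, :]).abs() <= window    # True only near the diagonal
scores = scores.masked_fill(~mask, float('-inf'))       # drop long-range pairs

w = torch.softmax(scores, dim=-1)                       # each word attends to O(window) words
h_new = w @ (h @ V)
print((w > 0).sum(dim=-1))                              # at most 2*window + 1 nonzeros per row
```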

3.2.3 Are Transformers Learning 'Neural Syntax'?

  1. Using attention to identify which pairs of words are the most interesting enables Transformers to learn something like a task-specific syntax.
  2. Different attention heads can be considered as capturing different syntactic properties.

3.2.4 Why Multiple Heads of Attention? Why Attention?

  1. The optimization view: having multiple attention heads improves learning and overcomes bad random initializations.
  2. GNNs with simpler aggregation functions such as sum or max do not require multiple aggregation heads for stable training.
  3. ConvNet architectures.

3.2.5 Why is Training Transformers so Hard?

  1. Hyper-parameters: learning rate schedule, warmup strategy and decay settings (a minimal warmup-then-decay schedule is sketched below).
  2. The specific permutation of normalization and residual connections within the architecture.
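A minimal sketch of the warmup-then-decay learning-rate schedule from the original Transformer paper (the 'Noam' schedule); the model size and warmup length below are illustrative defaults.

```python
def transformer_lr(step: int, d_model: int = 512, warmup: int = 4000) -> float:
    """Learning rate at a given step: linear warmup, then ~1/sqrt(step) decay."""
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

for s in (1, 100, 4000, 40000):
    print(s, round(transformer_lr(s), 6))   # rises linearly, peaks at `warmup`, then decays
```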
Author: Haowei
Link: http://howiehsu0126.github.io/2023/10/18/Transformers are Graph Neural Networks/
Copyright: Unless otherwise stated, all articles on this blog are licensed under CC BY-NC-SA 4.0. Please credit Haowei Hub when reposting.