Machine learning engineer, Stanford University; London, UK
GRAPH TRANSFORMERS
ABSTRACT
This article provides a comprehensive overview of the evolution from traditional neural network architectures—such as fully connected networks, convolutional neural networks (CNNs), and recurrent neural networks (RNNs)—to the transformative paradigm of transformers. It highlights how transformers revolutionized deep learning by introducing the attention mechanism, enabling efficient parallel processing and capturing long-range dependencies, which earlier architectures struggled to handle. The paper then focuses on applying these principles to graph-structured data. It explores how Graph Neural Networks (GNNs), particularly Graph Attention Networks (GATs), integrate attention mechanisms and positional encodings to effectively model complex relationships among nodes. Emphasis is placed on the practical utility of GATs in diverse domains, from recommendation systems and drug discovery to fraud detection and time-series anomaly detection. Through this synthesis, the article underscores the growing relevance of attention-based methods for handling intricate, interconnected datasets and outlines the ongoing research directions that push the field of graph transformers forward.
Keywords: transformers, attention mechanism, graph neural networks (GNNs), graph attention networks (GATs), positional encoding, long-range dependencies, graph-structured data, recommendation systems, drug discovery, fraud detection, anomaly detection
1. Introduction
In the last decade, deep learning has evolved from simple architectures to complex, data-driven systems. The most transformative development has been the emergence of transformer architectures, which have reshaped domains such as natural language processing (NLP), computer vision, and graph analytics.
The journey to transformers began with foundational architectures: fully connected layers provided the basic building blocks of neural networks, convolutional neural networks (CNNs) excelled at spatial pattern recognition, and recurrent neural networks (RNNs) captured sequential dependencies. While each architecture marked significant progress in specific domains, they also faced fundamental limitations in handling complex, interconnected data structures.
Graph-structured data, which by its nature represents relationships between entities, was particularly problematic for those architectures. Transformers, by introducing the attention mechanism and parallel processing, finally opened new possibilities for handling graph data, bringing rapid progress in fields such as molecular modeling, social network analysis, and recommendation systems.
2. Before Transformers
2.1 Fully Connected Layers: Understanding the Foundation and Limitations
Fully connected, or dense, layers were among the very first methods used to process input data for classification. In a fully connected layer, every input neuron is connected to every output neuron through learnable weights: in essence, a matrix multiplication. In image processing, for example, a 28x28-pixel image must first be flattened into a 784-dimensional vector, which then undergoes a linear transformation followed by a nonlinear activation function.
While this works well for simple tasks, such as handwritten digit recognition on the MNIST dataset, it has major drawbacks when dealing with complex data like images or sequential text. The fundamental limitation lies in the flattening step: turning structured data into a single vector inherently discards spatial relationships, the critical information about how pixels relate to their neighbors.
Figure 1. Example
Consider an image of a panda: humans naturally recognize it through spatial patterns, such as the black patches around the eyes or the texture of the fur, whereas a fully connected layer processes the image as a mere sequence of numbers, blind to these crucial spatial patterns. This architectural blindness to local structure becomes particularly problematic with high-dimensional inputs, where the number of parameters grows rapidly with input size, leading to computational inefficiency and an increased risk of overfitting.
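To make the flattening step concrete, here is a minimal NumPy sketch (random data and an arbitrary hidden size of 128, chosen only for illustration); the reshape call is exactly the point at which neighborhood information between pixels is lost:

```python
import numpy as np

rng = np.random.default_rng(0)

# A toy 28x28 grayscale "image" standing in for an MNIST digit.
image = rng.random((28, 28))

# Flattening discards the 2-D layout: the pixel at (row, col)
# becomes element row*28 + col of a 784-dimensional vector.
x = image.reshape(-1)                       # shape: (784,)

# One fully connected layer: hidden = ReLU(W @ x + b).
n_hidden = 128                              # arbitrary size for the example
W = rng.normal(scale=0.01, size=(n_hidden, 784))
b = np.zeros(n_hidden)
hidden = np.maximum(0.0, W @ x + b)         # shape: (128,)

print(hidden.shape)                         # (128,)
```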
2.2 Convolutional Neural Networks: Exploiting Spatial Structure
To address this loss of spatial structure, convolutional neural networks were introduced. CNNs preserve local patterns in the data by applying convolutional filters, also known as kernels, across the input image, allowing them to learn spatial hierarchies. Instead of flattening the image, CNNs slide small, learnable filters, usually 3x3 or 5x5, across the input in a sliding-window fashion, computing dot products between the filter and local patches to detect specific patterns. These filters act as pattern detectors, learning to identify features from simple to complex levels. The result is a set of feature maps that describe local patterns such as edges, textures, and shapes.
Figure 2. Example
CNNs offer two crucial advantages over fully connected layers: they preserve spatial relationships, and they are parameter-efficient. By sharing parameters, applying the same filter at every position in the image, a CNN significantly reduces the number of learnable parameters while achieving translation invariance. A 3x3 filter, for example, requires only 9 parameters to detect a given pattern in any part of an image. This efficient architecture led to breakthroughs across a wide range of computer vision tasks, from image classification to object detection and semantic segmentation.
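The sketch below spells out the sliding-window computation for a single 3x3 filter in NumPy (a hand-crafted edge detector applied to random data; a real CNN learns many such filters and uses an optimized implementation):

```python
import numpy as np

def conv2d_single_filter(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Slide one 2-D filter over one 2-D image ("valid" padding, stride 1)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.empty((out_h, out_w))
    for r in range(out_h):
        for c in range(out_w):
            # The same 9 weights are reused at every position (parameter sharing).
            out[r, c] = np.sum(image[r:r + kh, c:c + kw] * kernel)
    return out

rng = np.random.default_rng(0)
image = rng.random((28, 28))

# A hand-crafted 3x3 vertical-edge detector: only 9 parameters,
# yet it responds to the same pattern anywhere in the image.
kernel = np.array([[1.0, 0.0, -1.0],
                   [1.0, 0.0, -1.0],
                   [1.0, 0.0, -1.0]])

feature_map = conv2d_single_filter(image, kernel)
print(feature_map.shape)                    # (26, 26)
```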
Still, more complex relationships in the data remain challenging for CNNs. These architectures are very good at capturing local patterns but are far less effective at modeling long-range dependencies between distant elements in an image. Their hierarchical nature also makes it hard to capture global context without stacking many layers: relationships between objects at opposite corners of an image are only seen after information has passed through multiple convolutional layers, and important details may be lost along the way.
2.3 Recurrent Neural Networks: Handling Sequential Data
Recurrent neural networks were developed to handle the sequential nature of many real-world tasks, such as language modeling, machine translation, and speech recognition. Unlike CNNs, which process data in fixed-size chunks, RNNs process input sequences one element at a time and maintain a hidden state that evolves with each new element. This makes RNNs particularly effective for tasks where order matters, such as predicting the next word in a sentence.
The core of RNNs is recurrence: at every time step, the hidden state is updated based on the current input and the previous hidden state. This allows RNNs to pick up temporal relationships in sequential data. The basic RNN model is very simple: at each time step, it applies a linear transformation to the current input and the previous hidden state, followed by a nonlinear activation.
Figure 3. Example
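A vanilla RNN cell can be written in a few lines of NumPy; the sketch below uses random weights and arbitrary dimensions purely for illustration, and the explicit Python loop is what makes the computation inherently sequential:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden = 16, 32                     # illustrative sizes
W_x = rng.normal(scale=0.1, size=(d_hidden, d_in))
W_h = rng.normal(scale=0.1, size=(d_hidden, d_hidden))
b = np.zeros(d_hidden)

def rnn_step(x_t, h_prev):
    # h_t = tanh(W_x x_t + W_h h_{t-1} + b): linear transformation, then nonlinearity.
    return np.tanh(W_x @ x_t + W_h @ h_prev + b)

# Process a sequence of 10 inputs strictly one element at a time.
sequence = rng.random((10, d_in))
h = np.zeros(d_hidden)                      # initial hidden state
for x_t in sequence:
    h = rnn_step(x_t, h)                    # the hidden state carries the past forward

print(h.shape)                              # (32,)
```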
While RNNs have proved effective at many tasks, such as machine translation, they have serious drawbacks. Chief among them is the vanishing gradient problem: the gradients of the loss with respect to the weights become exponentially smaller as they are propagated backwards through the time steps. The problem arises because the same weights are reused in every recurrent computation; if those weights have magnitude less than one, repeated multiplication produces vanishingly small gradients. As a result, the network finds it difficult to capture long-range dependencies. For example, it may fail to remember that “capybara” is the subject when predicting the verb in the sentence “The capybara, who was peacefully swimming in the crystal-clear river near the lush rainforest vegetation, _____”. RNNs therefore struggle to learn relationships that span many time steps, making them less effective for tasks involving long sequences or complex dependencies.
Figure 4. Example
Moreover, RNNs are inherently sequential: computation proceeds step by step. This makes them poorly suited to parallel processing, a serious limitation now that models are routinely scaled to large datasets on distributed systems.
3. Transformers
Transformers were introduced in the seminal 2017 paper “Attention Is All You Need” by Vaswani et al., building on the attention mechanism first proposed by Bahdanau et al. in 2014 for neural machine translation. The key insight of transformers was that, instead of depending on sequential operations like those in RNNs, a model could rely on the attention mechanism to process entire input sequences in parallel. This new architecture allowed transformers to handle long-range dependencies effectively while achieving significant improvements in both training and inference speed through parallelization.
Figure 5. Example
At the core of transformers is self-attention, a mechanism that calculates the relationships between all pairs of elements in an input sequence. This allows transformers to process data not just one element at a time, but to compute attention weights that measure the relevance of each element to every other element in the sequence.
Unlike RNNs, which require information to flow sequentially through the network, transformers can directly model relationships between any two positions in the sequence, no matter how far apart they are. Additionally, the attention mechanism is quite flexible, enabling the model to concentrate on different parts of the input sequence depending on the task at hand. This flexibility is crucial for tasks like machine translation, where certain words in a sentence carry more significance than others.
3.1 Understanding the Attention Mechanism
The attention mechanism can be thought of as a ‘soft’ version of a lookup table. In a traditional lookup table, you would use an exact key to find a specific value. In contrast, attention performs this lookup in a ‘soft’ way—rather than requiring exact matches, it assesses how similar each key is to your query and returns a weighted mixture of values based on these similarities. For instance, if you were to look up the meaning of the word ‘happy’, instead of providing just one definition, it would offer a weighted combination of definitions from related words like ‘joyful’, ‘content’, and ‘pleased’, with the weights reflecting how closely each word relates to ‘happy’.
Figure 6. Example
The key innovation of the attention mechanism is its capacity to dynamically assess relationships between elements in a sequence through learnable parameters. Unlike earlier methods that depended on fixed similarity metrics, contemporary attention mechanisms utilize trainable transformations, enabling the model to discover the best ways to connect various elements.
3.2 Query, Key, and Value: The Foundation of Self-Attention
The attention mechanism operates using three key components: queries (Q), keys (K), and values (V). Each input element is transformed into these three representations through learnable linear transformations: W_Q, W_K, and W_V. Here is how the process works (a minimal code sketch follows these steps):
1. Query-Key Interaction: For each position i in the sequence, its query vector q_i interacts with all key vectors k_j through dot products. In mathematical terms, the similarity E_ij is calculated as:
E_ij = q_i^T k_j
2. Softmax Normalization: To ensure that the weights are normalized and sum to one, we apply the softmax function to the similarity scores: α_ij = softmax(E_ij) = exp(E_ij) / Σ_k exp(E_ik)
3. Value Aggregation: The final output for position i is computed as a weighted sum of values, where the weights are determined by the normalized attention scores: o_i = Σ_j α_ij v_j
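As a minimal NumPy sketch of these three steps (single attention head, random weights, sizes chosen for illustration; the full transformer additionally scales E_ij by 1/sqrt(d_k) before the softmax, which is omitted here to match the equations above):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)     # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """Dot-product self-attention following the three steps above."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V         # 1. project inputs to queries, keys, values
    E = Q @ K.T                                 #    E_ij = q_i . k_j
    A = softmax(E, axis=-1)                     # 2. alpha_ij: each row sums to one
    return A @ V                                # 3. o_i = sum_j alpha_ij v_j

rng = np.random.default_rng(0)
n_tokens, d_model, d_head = 5, 16, 8            # illustrative sizes
X = rng.random((n_tokens, d_model))             # one embedding per token
W_Q, W_K, W_V = (rng.normal(scale=0.1, size=(d_model, d_head)) for _ in range(3))

out = self_attention(X, W_Q, W_K, W_V)
print(out.shape)                                # (5, 8)
```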
3.3 The Challenges With Simple Attention and Their Solutions
While the mathematical approach described above is straightforward and efficient to compute, it suffers from several challenges:
Lack of Sequence Awareness
The vanilla attention mechanism treats input sequences as unordered sets of elements, ignoring crucial positional information. For instance, for the sentences “The park is where I am going” and “The I am where park is going”, the model would compute identical attention patterns despite their different meanings.
To address the lack of sequence awareness, positional encoding is added to input embeddings. Several approaches exist:
● Fixed Sinusoidal Encodings: use sine and cosine functions of different frequencies to encode each position (a minimal sketch follows this list).
● Learnable Positional Encodings: treat position vectors as trainable parameters, allowing the model to optimize position representations for the specific task. While more flexible, these encodings may not generalize well to sequences longer than those seen during training.
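For the sinusoidal variant, a minimal NumPy sketch (even d_model assumed, constants as in Vaswani et al.) looks as follows:

```python
import numpy as np

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> np.ndarray:
    """Fixed encodings: each pair of dimensions is a sine/cosine of a different frequency."""
    positions = np.arange(max_len)[:, None]                 # (max_len, 1)
    dims = np.arange(0, d_model, 2)[None, :]                # even feature indices
    angles = positions / np.power(10000.0, dims / d_model)  # (max_len, d_model // 2)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)                            # even dimensions: sine
    pe[:, 1::2] = np.cos(angles)                            # odd dimensions: cosine
    return pe

pe = sinusoidal_positional_encoding(max_len=50, d_model=16)
# These vectors are added element-wise to the token embeddings
# before the first attention layer, giving each position a unique signature.
print(pe.shape)                                             # (50, 16)
```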
No Nonlinearity
The attention mechanism is purely linear (dot products and weighted sums), limiting its ability to learn complex patterns.
Feed-forward neural networks (FFNs) are introduced between attention layers to address this challenge. These FFNs contain ReLU activations, adding necessary nonlinear transformations to the model.
We cannot “look at the future”.
Figure 7. Example
In tasks like machine translation, we need to prevent the model from “looking at the future”, that is, attending to words that come after the current position. This requires masking future tokens, which adds complexity to the attention mechanism.
To address this challenge, masked self-attention is implemented in the decoder, where a lower triangular mask ensures each position can only attend to previous positions. This maintains causality during generation.
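A sketch of how such a mask is applied to the raw attention scores (NumPy, 4 tokens, random scores for illustration):

```python
import numpy as np

def causal_mask(n_tokens: int) -> np.ndarray:
    """Lower-triangular mask: position i may attend only to positions j <= i."""
    return np.tril(np.ones((n_tokens, n_tokens), dtype=bool))

def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # Future positions are set to -inf before the softmax, so their weight becomes zero.
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.random((4, 4))                     # raw attention scores for 4 tokens
weights = masked_softmax(scores, causal_mask(4))
print(np.round(weights, 2))                     # strictly upper triangle is all zeros
```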
3.4 The Transformer Architecture: Bringing It All Together
The transformer architecture combines multiple innovations to create an efficient and powerful sequence processing model:
Figure 8. Example
● Encoder-Decoder Architecture: The encoder-decoder architecture consists of an encoder that processes the input sequence to create representations, and a decoder that utilizes these representations to produce an output sequence. This architecture can be adapted for various tasks, such as translation (using both encoder and decoder), text generation (using only the decoder), or encoding tasks (using only the encoder).
● Multi-Head Attention
Although a single attention mechanism can be effective, employing multiple attention heads in parallel enables the model to simultaneously capture various types of relationships, thereby enhancing its representational capabilities.
● Multi-Layer Processing: Each layer integrates attention and feed-forward networks, supported by two essential stabilizing mechanisms (a minimal sketch of a full encoder layer follows this list):
○ Residual Connections
■ Combats vanishing gradients by adding input to the transformed output: x + F(x)
○ Layer Normalization
■ Normalizes layer outputs using mean and variance, stabilizing training by controlling the scale of activations
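To show how these pieces fit together, here is a minimal single-head encoder layer in NumPy (no learnable layer-norm gain or bias, no dropout, no multi-head splitting, random weights; a deliberately simplified sketch rather than a faithful reimplementation of the original architecture):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Normalize each token's features to zero mean and unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))  # scaled dot-product attention
    return A @ V

def encoder_layer(X, p):
    # Sub-layer 1: self-attention, residual connection (x + F(x)), then layer norm.
    X = layer_norm(X + attention(X, p["W_Q"], p["W_K"], p["W_V"]))
    # Sub-layer 2: position-wise feed-forward network with ReLU (the nonlinearity),
    # again wrapped in a residual connection and layer norm.
    ffn_out = np.maximum(0.0, X @ p["W_1"]) @ p["W_2"]
    return layer_norm(X + ffn_out)

rng = np.random.default_rng(0)
d_model, d_ff, n_tokens = 16, 64, 5              # illustrative sizes
p = {name: rng.normal(scale=0.1, size=shape) for name, shape in [
    ("W_Q", (d_model, d_model)), ("W_K", (d_model, d_model)),
    ("W_V", (d_model, d_model)), ("W_1", (d_model, d_ff)), ("W_2", (d_ff, d_model)),
]}

X = rng.random((n_tokens, d_model))
print(encoder_layer(X, p).shape)                 # (5, 16)
```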
4. Graph Neural Networks (GNNs) and Graph Attention Networks (GATs)
In this section, we focus on Graph Neural Networks (GNNs), a robust class of models specifically designed for data organized as graphs. Unlike traditional neural networks that typically process sequences or grids, GNNs are adept at capturing the intricate relationships and dependencies among entities represented as nodes in a graph. Previous architectures like CNNs and RNNs often struggle with graph-structured data, where relationships can be complex and varied. In contrast, GNNs, especially Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs), effectively manage long-range dependencies and dynamic attention. For instance, in a social network, a GNN can model the flow of information between friends, their friends, and across entire communities, irrespective of the network’s structure.
4.1 GCNs and Graph Embeddings
At the core of any GNN is the concept of graph embeddings. In a graph, both nodes and edges can have embeddings—a vector representation that encapsulates the features of a node and its relationships. Node embeddings reflect the characteristics of entities, while edge embeddings convey the nature of the relationships between them. This is similar to how language models utilize word embeddings for individual words and position embeddings for the relationships among those words.
Figure 9. Example
In each layer of a Graph Neural Network (GNN), embeddings are processed through a message-passing mechanism, where information is adjusted as it travels along the edges and then combined at each node. To derive node embeddings, we usually aggregate the features from a node’s neighbors. In a Graph Convolutional Network (GCN), this aggregation occurs through a convolution-like operation, where the representation of each node is updated by calculating a weighted sum of its neighbors’ embeddings. This process is often simplified mathematically to mean aggregation, where a node’s new embedding is the average of its current embedding and those of its neighbors. For example, in a molecular graph, the representation of each atom would be updated based on its chemical bonds (edges) with adjacent atoms.
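A minimal NumPy sketch of this mean-aggregation update on a toy 4-node graph (random features and weights; the published GCN uses a symmetrically normalized adjacency matrix rather than a plain mean, so this is the simplified variant described above):

```python
import numpy as np

rng = np.random.default_rng(0)

# A tiny undirected graph with 4 nodes: adjacency matrix A and node features H.
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H = rng.random((4, 8))                          # 8 features per node
W = rng.normal(scale=0.1, size=(8, 8))          # learnable weight matrix

def gcn_layer_mean(A, H, W):
    """One message-passing step with mean aggregation, as described above."""
    A_hat = A + np.eye(A.shape[0])              # self-loops: a node keeps its own features
    deg = A_hat.sum(axis=1, keepdims=True)
    H_agg = (A_hat @ H) / deg                   # average of the node and its neighbors
    return np.maximum(0.0, H_agg @ W)           # linear transform + ReLU

H_new = gcn_layer_mean(A, H, W)
print(H_new.shape)                              # (4, 8)
```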
While this method is effective for basic graph structures, GCNs struggle to assign varying levels of importance to different neighbors. Each neighbor contributes equally to the updated node embedding, which may not be ideal, particularly in graphs where some relationships hold more significance than others. In our molecular example, different types of chemical bonds should exert varying degrees of influence on an atom’s representation.
4.2 Graph Attention Networks (GATs): Introducing Dynamic Weights
Graph Attention Networks (GATs) improve upon GCNs by incorporating attention mechanisms into the graph aggregation process. Rather than treating all neighboring nodes the same, GATs calculate dynamic attention weights for each node. The attention score between nodes i and j is computed from their transformed features as e_ij = LeakyReLU(a^T [W h_i || W h_j]), where W is a shared weight matrix, a is a learnable attention vector, and “||” denotes concatenation. The resulting scores are normalized through a softmax over each node’s neighbors to obtain the attention weights α_ij.
In practice, this allows nodes to learn to emphasize certain connections. For instance, in a citation network, a paper might give more weight to closely related papers instead of treating all citations uniformly.
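A single-head GAT layer on the same kind of toy graph, sketched in NumPy (random weights, no multi-head concatenation or final nonlinearity; a simplification of the layer defined in Veličković et al.):

```python
import numpy as np

rng = np.random.default_rng(0)

A = np.array([[0, 1, 1, 0],                     # toy 4-node undirected graph
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H = rng.random((4, 8))
W = rng.normal(scale=0.1, size=(8, 8))          # shared linear transformation
a = rng.normal(scale=0.1, size=(16,))           # attention vector over [Wh_i || Wh_j]

def gat_layer(A, H, W, a, slope=0.2):
    Wh = H @ W
    n = Wh.shape[0]
    # e_ij = LeakyReLU(a^T [Wh_i || Wh_j]) for every pair of nodes (i, j).
    e = np.array([[np.concatenate([Wh[i], Wh[j]]) @ a for j in range(n)]
                  for i in range(n)])
    e = np.where(e > 0, e, slope * e)           # LeakyReLU
    # Keep only each node's neighborhood (including itself), then softmax row-wise.
    mask = (A + np.eye(n)) > 0
    e = np.where(mask, e, -np.inf)
    e = e - e.max(axis=1, keepdims=True)
    att = np.exp(e)
    att = att / att.sum(axis=1, keepdims=True)  # dynamic attention weights alpha_ij
    return att @ Wh                             # attention-weighted aggregation

print(gat_layer(A, H, W, a).shape)              # (4, 8)
```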
4.3 Positional Encoding and Graphs
Unlike sequences in NLP, graphs lack inherent order, raising the question: How can we incorporate positional information into GNNs?
Two main approaches have emerged:
1. Feature-based positional encoding
a. Shortest Path Distances: Encode node positions using lengths of paths between nodes
b. Spectral Embeddings: Use the eigenvalue decomposition of graph matrices such as the Laplacian (a sketch of this variant follows below)
c. Powers of Adjacency Matrix: Capture k-hop neighborhood information
2. Learnable positional encodings
a. Similar to transformers, the model learns optimal node position representations during training through backpropagation.
These encodings are injected into node embeddings before attention computation, enabling the model to leverage both structural and positional relationships.
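As one concrete example, a spectral (Laplacian-eigenvector) positional encoding can be sketched in a few lines of NumPy (unnormalized Laplacian, concatenation used for the injection; practical implementations also handle the sign ambiguity of eigenvectors, which is ignored here):

```python
import numpy as np

A = np.array([[0, 1, 1, 0],                     # toy 4-node graph
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

def laplacian_positional_encoding(A: np.ndarray, k: int) -> np.ndarray:
    """Use eigenvectors of the graph Laplacian as k-dimensional node 'positions'."""
    D = np.diag(A.sum(axis=1))
    L = D - A                                   # unnormalized graph Laplacian
    eigvals, eigvecs = np.linalg.eigh(L)        # eigenvalues in ascending order
    # Skip the first (constant) eigenvector; keep the next k as positional features.
    return eigvecs[:, 1:k + 1]

pos = laplacian_positional_encoding(A, k=2)
H = np.random.default_rng(0).random((4, 8))     # node feature embeddings
H_with_pos = np.concatenate([H, pos], axis=1)   # inject positions into node embeddings
print(H_with_pos.shape)                         # (4, 10)
```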
4.4 Applications of GNNs and GATs
Graph neural networks, particularly GATs, have a diverse array of applications across various fields. Some of the key applications include:
- Recommendation Systems: Graphs effectively represent user-item interactions as bipartite networks, with users and items as distinct node types and interactions as edges. GATs enhance the accuracy of product recommendations by learning which connections are most indicative of user preferences, allowing them to focus on the most relevant product relationships.
- Drug Discovery: Molecules can be modeled as graphs, where atoms serve as nodes and bonds act as edges. GATs assist in predicting molecular properties and identifying potential drug candidates by discerning which atomic interactions are crucial for specific characteristics.
- Fraud Detection: In the financial sector, GNNs are increasingly utilized for detecting fraud. Fraudulent actors often create hidden networks where nodes (like accounts or transactions) are linked in atypical patterns. GNNs can effectively uncover these anomalous structures, enabling more accurate fraud detection.
- Malware Detection: The detection of malware in computer networks can be approached by modeling the network as a graph of interconnected devices. GNNs can recognize malicious patterns within these networks, aiding cybersecurity systems in identifying malware spread.
- Person Re-identification: In the realms of surveillance and security, GNNs have been employed to tackle the challenge of person re-identification. Even when individuals are captured in different locations or by various cameras, GNNs can connect their images based on the underlying graph structure of the environments they inhabit.
- Time-Series Anomaly Detection: Time-series data can be represented as graphs by linking time points or variables according to their relationships (such as correlation or temporal proximity). GNNs identify anomalies by detecting nodes or subgraphs that diverge from established normal patterns.
5. Conclusion
The integration of attention mechanisms with graph-based structures has established Graph Attention Networks (GATs) as a formidable asset in machine learning. By effectively identifying and focusing on the most pertinent neighbors within a graph, GATs can manage more intricate and dynamic relationships compared to traditional graph convolution models. Additionally, the use of positional encodings in Graph Neural Networks (GNNs) enhances their understanding of structure, enabling these models to capture detailed structural nuances.
The range of potential applications for GNNs is extensive, spanning from recommendation systems to drug discovery and fraud detection. As research progresses and these models are refined, the capacity to represent complex relationships in data is expected to grow, paving the way for new avenues of innovation across various sectors.
References:
- Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. “Neural Machine Translation by Jointly Learning to Align and Translate.” International Conference on Learning Representations (ICLR), 2015. arXiv:1409.0473. https://doi.org/10.48550/arXiv.1409.0473.
- Vaswani, Ashish, Noam Shazeer, Niki Parmar, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. “Attention Is All You Need.” In Advances in Neural Information Processing Systems 30 (NIPS 2017), Long Beach, CA, USA, 2017. https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.
- Kipf, Thomas N., and Max Welling. “Semi-Supervised Classification with Graph Convolutional Networks.” International Conference on Learning Representations (ICLR), 2017. arXiv:1609.02907. https://doi.org/10.48550/arXiv.1609.02907.
- Veličković, Petar, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio. “Graph Attention Networks.” International Conference on Learning Representations (ICLR), 2018. arXiv:1710.10903. https://doi.org/10.48550/arXiv.1710.10903.
- Ma, Liheng, Reihaneh Rabbany, and Adriana Romero-Soriano. “Graph Attention Networks with Positional Embeddings.” In Advances in Knowledge Discovery and Data Mining, 514–527. Lecture Notes in Computer Science. May 2021. https://doi.org/10.1007/978-3-030-75762-5_41.
- Brüel-Gabrielsson, Rickard. “Rewiring with Positional Encodings for Graph Neural Networks.” arXiv preprint, January 2022. https://doi.org/10.48550/arXiv.2201.12674.
- Park, Wonpyo, Woonggi Chang, Donggeon Lee, Juntae Kim, and Seung-won Hwang. “GRPE: Relative Positional Encoding for Graph Transformer.” arXiv preprint, 2022. arXiv:2201.12787. https://doi.org/10.48550/arXiv.2201.12787.
- Shehzad, Ahsan, Feng Xia, Shagufta Abid, Ciyuan Peng, Shuo Yu, Dongyu Zhang, and Karin Verspoor. “Graph Transformers: A Survey.” arXiv preprint, 2024. arXiv:2407.09777. https://doi.org/10.48550/arXiv.2407.09777.