TensorFlow (TF): An AI tool

Graph neural networks in TensorFlow

Objects and their relationships are ubiquitous in the world, and understanding these relationships can be as important as understanding the objects themselves. Examples include transportation networks, production networks, knowledge graphs, and social networks. Discrete mathematics and computer science have long formalized such networks as graphs, consisting of nodes connected by edges in arbitrary, irregular patterns. However, many machine learning (ML) algorithms only handle regular, uniform relationships between inputs (such as grids of pixels or sequences of words) or no relationships at all.

Graph neural networks (GNNs) have emerged as a powerful technique for leveraging both a graph’s connectivity and the input features on its nodes and edges. GNNs can make predictions for entire graphs (e.g., “Does this molecule react in a certain way?”), individual nodes (e.g., “What’s the topic of this document, given its citations?”), or potential edges (e.g., “Is this product likely to be purchased together with that product?”). Beyond making predictions, GNNs can bring graph-structured information into other deep learning systems by encoding a graph’s discrete, relational information as continuous vectors (embeddings).

We are excited to announce the release of TensorFlow GNN 1.0 (TF-GNN), a robust library for building GNNs at scale. TF-GNN supports both modeling and training in TensorFlow, as well as extracting input graphs from massive data stores. It is built for heterogeneous graphs, where different types of objects and relations are represented by distinct sets of nodes and edges. Real-world objects and their relations come in distinct types, so this focus on heterogeneity makes TF-GNN a natural fit for modeling them.

In TensorFlow, these graphs are represented by the tfgnn.GraphTensor object. This composite tensor type (a collection of tensors within one Python class) is treated as a first-class citizen in tf.data.Dataset, tf.function, and other TensorFlow components. It stores both the graph structure and the features attached to nodes, edges, and the graph as a whole. Trainable transformations of GraphTensors can be implemented as Layer objects in the high-level Keras API or directly using the tfgnn.GraphTensor primitive.
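To make the contents of a graph tensor concrete, here is a minimal sketch of the data it holds for a heterogeneous graph. It has named node sets with per-node features, named edge sets whose endpoints refer to node sets by name, and graph-level context features. The classes `NodeSet`, `EdgeSet`, and `ToyGraph` below are hypothetical plain-Python stand-ins for illustration only. In TF-GNN itself, these pieces are assembled with `tfgnn.GraphTensor.from_pieces` together with `tfgnn.NodeSet`, `tfgnn.EdgeSet`, and `tfgnn.Context`, whose exact signatures should be taken from the library documentation.

```python
from dataclasses import dataclass, field


@dataclass
class NodeSet:
    """Hypothetical stand-in for one node type: its size and per-node features."""
    size: int
    features: dict = field(default_factory=dict)  # feature name -> one value per node


@dataclass
class EdgeSet:
    """Hypothetical stand-in for one edge type; endpoints are (node set name, indices)."""
    source: tuple
    target: tuple
    features: dict = field(default_factory=dict)  # feature name -> one value per edge


@dataclass
class ToyGraph:
    """Heterogeneous graph: node sets, edge sets, and graph-level context features."""
    node_sets: dict
    edge_sets: dict
    context: dict = field(default_factory=dict)

    def is_valid(self) -> bool:
        """Check that every edge endpoint refers to an existing node of the named set."""
        for edges in self.edge_sets.values():
            (src_set, src_idx), (dst_set, dst_idx) = edges.source, edges.target
            if len(src_idx) != len(dst_idx):
                return False
            for set_name, indices in ((src_set, src_idx), (dst_set, dst_idx)):
                size = self.node_sets[set_name].size
                if any(not 0 <= i < size for i in indices):
                    return False
        return True


# Three papers and two authors; "cites" links paper -> paper, "writes" links author -> paper.
graph = ToyGraph(
    node_sets={
        "papers": NodeSet(3, {"year": [2018, 2019, 2020]}),
        "authors": NodeSet(2),
    },
    edge_sets={
        "cites": EdgeSet(("papers", [1, 2, 2]), ("papers", [0, 0, 1])),
        "writes": EdgeSet(("authors", [0, 1, 1]), ("papers", [0, 1, 2])),
    },
    context={"num_papers": [3]},
)
print(graph.is_valid())  # True
```

Keeping each node and edge type in its own named set is what lets later layers use separate weights per type.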

GNNs: Making predictions for an object in context

For illustration, let’s explore a typical application of TF-GNN: predicting a property of a specific type of node within a graph derived from cross-referencing tables in a large database. Consider a citation database of Computer Science (CS) arXiv papers, which includes one-to-many cites and many-to-one cited relationships. Our goal is to predict the subject area of each paper.

Similar to other neural networks, a GNN is trained on a dataset containing many labeled examples (potentially millions). However, each training step involves a smaller batch of training examples (typically hundreds). To handle millions of examples, the GNN is trained on a stream of reasonably small subgraphs extracted from the larger graph. Each subgraph includes enough data from the original graph to compute the GNN result for the labeled node at its center and train the model. This process, known as subgraph sampling, is critical for effective GNN training. Most existing tools perform batch sampling, producing static subgraphs for training. In contrast, TF-GNN improves on this by enabling dynamic and interactive sampling.

Pictured, the process of subgraph sampling where small, tractable subgraphs are sampled from a larger graph to create input examples for GNN training.
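The hop-by-hop sampling described above can be sketched in plain Python as follows. This illustrates fan-out-limited neighborhood sampling under our own assumptions: an adjacency dict, a fixed per-hop fan-out, and a seeded random generator. It is not TF-GNN's sampler, which is configured through its own Python API. Drawing a fresh seed at every training step is what makes such sampling dynamic rather than precomputed in a batch job.

```python
import random


def sample_subgraph(neighbors, root, fanouts, seed=0):
    """Sample a rooted subgraph hop by hop with a per-hop fan-out limit.

    neighbors: dict mapping a node id to a list of neighbor ids (e.g. cited papers).
    fanouts:   maximum number of neighbors sampled per node at each hop, e.g. [2, 1].
    Returns (nodes, edges); each edge (nbr, node) points from the sampled
    neighbor toward the node that sampled it, i.e. toward the root.
    """
    rng = random.Random(seed)
    nodes, edges = {root}, set()
    frontier = [root]
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            candidates = neighbors.get(node, [])
            for nbr in rng.sample(candidates, min(fanout, len(candidates))):
                edges.add((nbr, node))
                if nbr not in nodes:  # expand each node at most once
                    nodes.add(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return nodes, edges


neighbors = {0: [1, 2, 3, 4, 5], 1: [6, 7], 2: [7, 8], 3: [9], 4: [0], 5: [6]}
nodes, edges = sample_subgraph(neighbors, root=0, fanouts=[2, 1], seed=42)
print(sorted(nodes), sorted(edges))
```

With fan-outs [2, 1], the subgraph around node 0 contains at most 1 + 2 + 2 nodes, however large the full graph is. That bound is what keeps each training example small and tractable.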

TF-GNN 1.0 introduces a flexible Python API for configuring dynamic or batch subgraph sampling at all relevant scales. Sampling can run interactively in a Colab notebook, efficiently for small datasets stored in the main memory of a single training host, or distributed with Apache Beam for massive datasets stored on a network filesystem (up to hundreds of millions of nodes and billions of edges). For more information, please refer to our user guides on in-memory and Beam-based sampling.

In these sampled subgraphs, the GNN’s task is to compute a hidden (or latent) state at the root node that aggregates and encodes the relevant information from the root node’s neighborhood. A classical approach is the message-passing neural network. In each round of message passing, nodes receive messages from their neighbors along incoming edges and update their own hidden state from them. After n rounds, the hidden state of the root node reflects aggregate information from all nodes within n edges (illustrated below for n = 2). The messages and the new hidden states are computed by hidden layers of the neural network. In heterogeneous graphs, it often makes sense to use separately trained hidden layers for the different types of nodes and edges.

Pictured, a simple message-passing neural network where, at each step, the node state is propagated from outer to inner nodes where it is pooled to compute new node states. Once the root node is reached, a final prediction can be made.
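A stripped-down version of the message-passing rounds described above fits in a few lines. In this sketch, each node state is a single number, a message is simply the sender's state, and incoming messages are mean-pooled. A fixed 50/50 mix of old state and pooled messages stands in for the trained update layer. These simplifications are our own, not TF-GNN code. Running the sketch on a chain shows the receptive-field property: a signal placed n edges away from the root reaches the root only after n rounds.

```python
def message_passing(num_nodes, edges, states, rounds):
    """Run `rounds` rounds of message passing on scalar node states.

    edges: list of (source, target) pairs; messages flow source -> target.
    """
    for _ in range(rounds):
        incoming = [[] for _ in range(num_nodes)]
        for src, dst in edges:
            incoming[dst].append(states[src])  # message = sender's current state
        new_states = []
        for v in range(num_nodes):
            msgs = incoming[v]
            pooled = sum(msgs) / len(msgs) if msgs else 0.0  # mean pooling
            new_states.append(0.5 * states[v] + 0.5 * pooled)  # stand-in for a trained update
        states = new_states
    return states


# Chain 3 -> 2 -> 1 -> 0 with root 0; a signal of 1.0 sits on one node.
chain = [(3, 2), (2, 1), (1, 0)]
two_hops = message_passing(4, chain, [0.0, 0.0, 1.0, 0.0], rounds=2)
three_hops_too_few = message_passing(4, chain, [0.0, 0.0, 0.0, 1.0], rounds=2)
three_hops_enough = message_passing(4, chain, [0.0, 0.0, 0.0, 1.0], rounds=3)
print(two_hops[0], three_hops_too_few[0], three_hops_enough[0])  # 0.25 0.0 0.125
```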

The training setup is completed by adding an output layer on top of the GNN’s hidden state for the labeled nodes. This setup involves computing the loss to measure prediction error and updating the model weights through backpropagation, as is standard in neural network training.
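The output layer, loss, and weight update can be illustrated on a single labeled root node. This sketch simplifies heavily under stated assumptions. The GNN's hidden state is given as a fixed, hypothetical list of numbers. The output layer is a logistic-regression head for a binary label, and only that head is updated. In a full model, the same gradient would be backpropagated further into the GNN layers.

```python
import math


def predict(w, b, h):
    """Output layer: sigmoid of a linear function of the root node's hidden state h."""
    z = sum(wi * hi for wi, hi in zip(w, h)) + b
    return 1.0 / (1.0 + math.exp(-z))


def bce_loss(p, y):
    """Binary cross-entropy between predicted probability p and label y in {0, 1}."""
    eps = 1e-12  # guards against log(0)
    return -(y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps))


def sgd_step(w, b, h, y, lr=0.1):
    """One gradient-descent step on the output layer's weights and bias."""
    grad_z = predict(w, b, h) - y  # dLoss/dz for sigmoid followed by cross-entropy
    new_w = [wi - lr * grad_z * hi for wi, hi in zip(w, h)]
    return new_w, b - lr * grad_z


hidden = [0.5, -1.0, 2.0]  # hypothetical hidden state of a labeled root node
label = 1
w, b = [0.0, 0.0, 0.0], 0.0
before = bce_loss(predict(w, b, hidden), label)
w, b = sgd_step(w, b, hidden, label)
after = bce_loss(predict(w, b, hidden), label)
print(before > after)  # True: the update reduced the loss
```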

In addition to supervised training (minimizing a loss defined by labels), GNNs can also be trained unsupervised (without labels). This allows the computation of a continuous representation (or embedding) of the discrete graph structure and node features. These embeddings can then be used in other ML systems, enabling the inclusion of discrete, relational information encoded by a graph in more typical neural network applications. TF-GNN supports detailed specification of unsupervised objectives for heterogeneous graphs.
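One family of unsupervised objectives is contrastive. A well-known example is Deep Graph Infomax (DGI), where a discriminator learns to tell node embeddings of the real graph apart from embeddings of a corrupted copy. The sketch below is a loose, simplified illustration of that idea, not TF-GNN's implementation. It uses a toy mean-aggregation encoder with no trained weights, corrupts the graph by shuffling node features, and replaces DGI's trained bilinear discriminator with a plain dot product against a graph summary vector.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def encode(features, neighbors):
    """Toy encoder: embedding = mean of a node's own and its neighbors' feature vectors."""
    out = []
    for v, f in enumerate(features):
        group = [f] + [features[u] for u in neighbors[v]]
        out.append([sum(x[d] for x in group) / len(group) for d in range(len(f))])
    return out


def contrastive_loss(real, corrupt):
    """DGI-style loss: real embeddings are positives, corrupted ones negatives."""
    dim = len(real[0])
    # Graph summary: sigmoid of the mean real embedding.
    summary = [1.0 / (1.0 + math.exp(-sum(e[d] for e in real) / len(real))) for d in range(dim)]

    def score(e):
        return sum(a * s for a, s in zip(e, summary))  # dot-product discriminator

    pos = sum(softplus(-score(e)) for e in real)    # -log sigmoid(score), label 1
    neg = sum(softplus(score(e)) for e in corrupt)  # -log(1 - sigmoid(score)), label 0
    return (pos + neg) / (len(real) + len(corrupt))


# Corrupt the graph by shuffling node features while keeping the edges fixed.
features = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
neighbors = {0: [1], 1: [0], 2: [3], 3: [2]}
shuffled = features[:]
random.Random(0).shuffle(shuffled)
loss = contrastive_loss(encode(features, neighbors), encode(shuffled, neighbors))
print(loss)
```

In a trained model, the encoder's weights would be updated to lower this loss. The resulting embeddings can then be reused in other ML systems.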

Building GNN architectures

The TF-GNN library supports building and training GNNs at various levels of abstraction.

At the highest level, users can take any of the predefined models bundled with the library, which are implemented as Keras layers. Besides a small collection of models from the research literature, TF-GNN comes with a highly configurable model template. It provides a curated selection of modeling options that have proven to be strong baselines on many of our in-house problems. The template implements GNN layers, so users only need to initialize the Keras layers:

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import mt_albis

def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
  """Builds a GNN as a Keras model."""
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)

  # Encode input features. `set_initial_node_states` is a user-supplied
  # callback that maps raw node features to initial hidden states (omitted for brevity).
  graph = tfgnn.keras.layers.MapFeatures(
      node_sets_fn=set_initial_node_states)(graph)

  # For each round of message passing...
  for _ in range(2):
    # ... create and apply a Keras layer.
    graph = mt_albis.MtAlbisGraphUpdate(
        units=128, message_dim=64,
        attention_type="none", simple_conv_reduce_type="mean",
        normalization_type="layer", next_state_type="residual",
        state_dropout_rate=0.2, l2_regularization=1e-5,
    )(graph)

  return tf.keras.Model(inputs, graph)

At the lowest level, users can write a GNN model from scratch using primitives for passing data around the graph, such as broadcasting data from a node to all its outgoing edges or pooling data into a node from all its incoming edges (e.g., computing the sum of incoming messages). TF-GNN’s graph data model treats nodes, edges, and entire input graphs equally when it comes to features or hidden states. This makes it straightforward to express not only node-centric models like the message-passing neural network (MPNN) discussed above but also more general forms of GraphNets. This can be done with or without Keras as a modeling framework on top of core TensorFlow. For more details, and for intermediate levels of modeling, refer to the TF-GNN user guide and model collection.
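The two primitives named above can be mimicked on plain Python lists: broadcasting from nodes to edges and pooling from edges to nodes. The function names echo TF-GNN's `tfgnn.broadcast_node_to_edges` and `tfgnn.pool_edges_to_node`. However, these are hypothetical stand-ins with simplified signatures, not the library's API. Chaining them gives one message-passing step. Each source node's state is broadcast onto its outgoing edges, and the resulting messages are then pooled at each target node.

```python
def broadcast_node_to_edges(node_values, endpoint_indices):
    """Give every edge a copy of the value at one of its endpoints (e.g. its source)."""
    return [node_values[i] for i in endpoint_indices]


def pool_edges_to_node(edge_values, target_indices, num_nodes, reduce_type="sum"):
    """Aggregate edge values at each edge's target node ("sum" or "mean")."""
    totals = [0.0] * num_nodes
    counts = [0] * num_nodes
    for value, t in zip(edge_values, target_indices):
        totals[t] += value
        counts[t] += 1
    if reduce_type == "sum":
        return totals
    if reduce_type == "mean":
        return [tot / c if c else 0.0 for tot, c in zip(totals, counts)]
    raise ValueError(f"unsupported reduce_type: {reduce_type!r}")


# One message-passing step on a 3-node graph with edges 0->2, 1->2, 2->0.
sources, targets = [0, 1, 2], [2, 2, 0]
states = [1.0, 2.0, 3.0]
messages = broadcast_node_to_edges(states, sources)
print(pool_edges_to_node(messages, targets, num_nodes=3))                     # [3.0, 0.0, 3.0]
print(pool_edges_to_node(messages, targets, num_nodes=3, reduce_type="mean"))  # [3.0, 0.0, 1.5]
```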

Training orchestration

While advanced users can create custom model training setups, the TF-GNN Runner offers a streamlined way to orchestrate the training of Keras models for common use cases. A simple invocation might look like this:

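import tensorflow as tf
from tensorflow_gnn import runner

runner.run(  # Optimizer, epochs, batch size, and dataset providers omitted for brevity.
    task=runner.RootNodeBinaryClassification("papers", ...),
    model_fn=model_fn,
    trainer=runner.KerasTrainer(tf.distribute.MirroredStrategy(), model_dir="/tmp/model"),
)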

The Runner provides ready-to-use solutions for common ML pain points, such as distributed training and tfgnn.GraphTensor padding for fixed shapes on Cloud TPUs. Beyond training on a single task (as shown above), it supports joint training on multiple tasks in concert. For example, unsupervised tasks can be mixed with supervised ones to shape a final continuous representation (or embedding) with application-specific inductive biases. Callers only need to substitute the task argument with a mapping of tasks:

from tensorflow_gnn import runner
from tensorflow_gnn.models import contrastive_losses

runner.run(
    task={
        "classification": runner.RootNodeBinaryClassification("papers", ...),
        "dgi": contrastive_losses.DeepGraphInfomaxTask("papers"),
    },
    # ... all other arguments as in the single-task example above.
)

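The Runner description above mentions padding GraphTensors to fixed shapes for Cloud TPUs, which require statically known tensor sizes. The idea can be sketched as follows, as our own simplified illustration rather than TF-GNN's padding code. A batch is padded up to fixed node and edge budgets, and the extra nodes are marked as padding in a mask. Every extra edge connects padding nodes only, so no fake message can reach a real node. The sketch assumes the node budget always leaves room for at least one padding node to attach padding edges to.

```python
def pad_to_fixed_size(num_nodes, edges, node_budget, edge_budget):
    """Pad a merged batch of subgraphs to exactly node_budget nodes and edge_budget edges.

    Returns (padded_edges, node_mask); node_mask[i] is False for padding nodes.
    """
    if num_nodes >= node_budget:
        raise ValueError("need room for at least one padding node")
    if len(edges) > edge_budget:
        raise ValueError("too many edges for the budget")
    pad_node = num_nodes  # first padding node; padding edges are self-loops on it
    padded_edges = list(edges) + [(pad_node, pad_node)] * (edge_budget - len(edges))
    node_mask = [True] * num_nodes + [False] * (node_budget - num_nodes)
    return padded_edges, node_mask


# Two small subgraphs merged into one batch: nodes 0-2 and 3-4, three real edges.
padded_edges, mask = pad_to_fixed_size(5, [(1, 0), (2, 0), (4, 3)], node_budget=8, edge_budget=6)
print(len(padded_edges), len(mask))  # 6 8
```

Losses and metrics would then be computed only where the mask is True, so padding never influences training.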

In summary, we aim for TF-GNN to facilitate the widespread application of GNNs in TensorFlow, enabling scalable solutions and fostering innovation in the field. If you’re interested in learning more, we encourage you to try our Colab demo featuring the popular OGBN-MAG benchmark, explore additional user guides and Colabs, or delve into our paper for further insights. TF-GNN is designed to empower researchers and practitioners to push the boundaries of graph neural networks within the TensorFlow ecosystem.
