2.4.0
Release date: 2023-10-12 16:28:59
Latest pyg-team/pytorch_geometric release: 2.5.3 (2024-04-19 19:37:44)
We are excited to announce the release of PyG 2.4 🎉🎉🎉
PyG 2.4 is the culmination of work from 62 contributors who have worked on features and bug-fixes for a total of over 500 commits since torch-geometric==2.3.1.
Highlights
PyTorch 2.1 and torch.compile(dynamic=True) support
The long wait has come to an end! With the release of PyTorch 2.1, PyG 2.4 now brings full support for torch.compile on graphs of varying size via the dynamic=True option, which is especially useful for use-cases that involve DataLoader or NeighborLoader. Examples and tutorials have been updated to reflect this support (#8134), and models and layers in torch_geometric.nn have been tested to produce zero graph breaks:
```python
import torch_geometric
# `model` is any existing torch.nn.Module, e.g., a GNN from torch_geometric.nn.
model = torch_geometric.compile(model, dynamic=True)
```
When the dynamic=True option is enabled, PyTorch attempts up-front to generate a kernel that is as dynamic as possible, so that it does not need to recompile when sizes change across mini-batches. As such, you should only omit dynamic=True when graph sizes are guaranteed never to change. Note that dynamic=True requires PyTorch >= 2.1.0 to be installed.
PyG 2.4 is fully compatible with PyTorch 2.1, and supports the following combinations:
| PyTorch 2.1 | cpu | cu118 | cu121 |
|---|---|---|---|
| Linux | ✅ | ✅ | ✅ |
| macOS | ✅ | | |
| Windows | ✅ | ✅ | ✅ |
You can still install PyG 2.4 on older PyTorch releases, as far back as PyTorch 1.11, in case you are not eager to update your PyTorch version.
OnDiskDataset Interface
We added the OnDiskDataset base class for creating large graph datasets (e.g., molecular databases with billions of graphs), which do not easily fit into CPU memory at once (#8028, #8044, #8046, #8051, #8052, #8054, #8057, #8058, #8066, #8088, #8092, #8106). OnDiskDataset leverages our newly introduced Database backend (sqlite3 by default) for on-disk storage and access of graphs, supports DataLoader out-of-the-box, and is optimized for maximum performance.
OnDiskDataset utilizes a user-specified schema to store data as efficiently as possible (instead of relying on Python pickling). The schema can take int, float, str, object or a dictionary with dtype and size keys (for specifying tensor data) as input, and can be nested as a dictionary. For example,
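```python
import torch
from torch_geometric.data import OnDiskDataset

root = 'data/my_on_disk_dataset'  # placeholder: any writable directory
dataset = OnDiskDataset(root, schema={
    'x': dict(dtype=torch.float, size=(-1, 16)),
    'edge_index': dict(dtype=torch.long, size=(2, -1)),
    'y': float,
})
```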
creates a database with three columns, where x and edge_index are stored as binary data, and y is stored as a float.
Afterwards, you can append data to the OnDiskDataset via dataset.append()/dataset.extend(), and retrieve data from it via dataset.get()/dataset.multi_get(). We added a fully working example on how to set up your own OnDiskDataset here (#8102). You can also convert in-memory dataset instances to an OnDiskDataset instance by running InMemoryDataset.to_on_disk_dataset() (#8116).
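To illustrate the idea behind a schema-driven on-disk backend, here is a minimal, hypothetical sketch built on Python's standard sqlite3 module. It is not PyG's Database implementation. The class name ToyOnDiskStore, its fixed two-column schema and the float64 byte encoding are assumptions chosen for illustration. It only mirrors the append()/extend() and get()/multi_get() access pattern described above: tensor-like data is stored as raw bytes and scalars as native SQL columns, rather than pickling whole objects.

```python
import sqlite3
import struct
from typing import Any, Dict, List, Sequence


class ToyOnDiskStore:
    """Hypothetical sqlite3-backed graph store (not PyG's Database class).

    Fixed illustrative schema: a variable-length float vector `x`, stored
    as raw float64 bytes instead of a pickle, and a scalar float `y`.
    """

    def __init__(self, path: str = ':memory:') -> None:
        self.conn = sqlite3.connect(path)
        self.conn.execute('CREATE TABLE IF NOT EXISTS graphs '
                          '(id INTEGER PRIMARY KEY, x BLOB, y REAL)')

    def __len__(self) -> int:
        return self.conn.execute('SELECT COUNT(*) FROM graphs').fetchone()[0]

    def append(self, sample: Dict[str, Any]) -> None:
        self.extend([sample])

    def extend(self, samples: Sequence[Dict[str, Any]]) -> None:
        offset = len(self)
        # Pack each float vector into bytes; keep the scalar target as REAL.
        rows = [(offset + i, struct.pack(f'{len(s["x"])}d', *s['x']),
                 float(s['y'])) for i, s in enumerate(samples)]
        self.conn.executemany('INSERT INTO graphs VALUES (?, ?, ?)', rows)
        self.conn.commit()

    def get(self, idx: int) -> Dict[str, Any]:
        return self.multi_get([idx])[0]

    def multi_get(self, indices: Sequence[int]) -> List[Dict[str, Any]]:
        out = []
        for idx in indices:
            row = self.conn.execute('SELECT x, y FROM graphs WHERE id = ?',
                                    (idx,)).fetchone()
            if row is None:
                raise IndexError(idx)
            blob, y = row
            x = list(struct.unpack(f'{len(blob) // 8}d', blob))
            out.append({'x': x, 'y': y})
        return out


store = ToyOnDiskStore()
store.append({'x': [1.0, 2.0], 'y': 0.5})
store.extend([{'x': [3.0], 'y': 1.0}, {'x': [4.0, 5.0, 6.0], 'y': 2.0}])
print(len(store), store.get(2))
```

Storing fixed-type columns lets the database return values without unpickling arbitrary objects. The real OnDiskDataset generalizes this to nested schemas and tensors.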
Neighbor Sampling Improvements
Hierarchical Sampling
One drawback of NeighborLoader is that it computes representations for all sampled nodes at all depths of the network. However, nodes sampled in later hops no longer contribute to the representations of the seed nodes in later GNN layers, so computing their embeddings there is wasted work. As a result, NeighborLoader is marginally slower than necessary, since it computes node embeddings for nodes that are no longer needed. This is a trade-off we made to obtain a clean, modular and experimentation-friendly GNN design, which does not tie the definition of the model to its data loading routine.
With PyG 2.4, we introduced the option to eliminate this overhead and further speed up training and inference in mini-batch GNNs, which we call "Hierarchical Neighborhood Sampling" (see here for the full tutorial) (#6661, #7089, #7244, #7425, #7594, #7942). Its main idea is to progressively trim the adjacency matrix of the returned subgraph before inputting it to each GNN layer. This works seamlessly across several models, in both the homogeneous and the heterogeneous graph setting. To support this trimming and implement it effectively, the NeighborLoader implementation in PyG and in pyg-lib additionally returns the number of nodes and edges sampled in each hop. These counts are then used on a per-layer basis to trim the adjacency matrix and the various feature matrices so that only what is required is kept (see the trim_to_layer method):
```python
from typing import List

import torch
from torch import Tensor
from torch.nn import Linear, ModuleList
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import trim_to_layer


class GNN(torch.nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, num_layers: int):
        super().__init__()
        self.convs = ModuleList([SAGEConv(in_channels, hidden_channels)])
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.lin = Linear(hidden_channels, out_channels)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        num_sampled_nodes_per_hop: List[int],
        num_sampled_edges_per_hop: List[int],
    ) -> Tensor:
        for i, conv in enumerate(self.convs):
            # Trim edge and node information to the current layer `i`.
            x, edge_index, _ = trim_to_layer(
                i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop,
                x, edge_index)
            x = conv(x, edge_index).relu()
        return self.lin(x)
```
Corresponding examples can be found here and here.
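To make the savings concrete, the following pure-Python sketch only does the bookkeeping: it reports how many nodes and edges each GNN layer still processes after trimming. It assumes, as in the model above, that num_sampled_nodes_per_hop starts with the seed nodes (one more entry than num_sampled_edges_per_hop), and that each layer after the first drops the outermost remaining hop. The function trimmed_sizes is hypothetical and not part of PyG.

```python
from typing import List, Tuple


def trimmed_sizes(
    num_sampled_nodes_per_hop: List[int],
    num_sampled_edges_per_hop: List[int],
) -> List[Tuple[int, int]]:
    """Return (num_nodes, num_edges) seen by each GNN layer after trimming.

    Hypothetical helper. Assumes hop 0 holds the seed nodes and that there
    is one GNN layer per sampled hop.
    """
    num_layers = len(num_sampled_edges_per_hop)
    num_nodes = sum(num_sampled_nodes_per_hop)
    num_edges = sum(num_sampled_edges_per_hop)
    sizes = []
    for layer in range(num_layers):
        if layer > 0:
            # Drop the outermost remaining hop: its nodes and edges only
            # affect representations that the seed nodes no longer need.
            num_nodes -= num_sampled_nodes_per_hop[-layer]
            num_edges -= num_sampled_edges_per_hop[-layer]
        sizes.append((num_nodes, num_edges))
    return sizes


# Two layers: 4 seed nodes, 20 first-hop and 100 second-hop neighbors.
print(trimmed_sizes([4, 20, 100], [20, 100]))  # [(124, 120), (24, 20)]
```

In this toy configuration, the last layer processes 24 nodes and 20 edges instead of 124 nodes and 120 edges, which is where the speed-up comes from.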
Biased Sampling
Additionally, we added support for weighted/biased sampling in NeighborLoader/LinkNeighborLoader scenarios. For this, simply store your edge weights as an attribute (e.g., edge_weight) and pass its name via weight_attr during NeighborLoader initialization. PyG will then pick up these weights to perform weighted/biased sampling (#8038):
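```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Toy graph: 5 nodes, 6 directed edges, one random weight per edge.
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]])
edge_weight = torch.rand(edge_index.size(1))
data = Data(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight)
loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],
    weight_attr='edge_weight',  # name of the attribute holding edge weights
)
batch = next(iter(loader))
```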
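For intuition on what "biased" means here, below is a pure-Python sketch of weighted neighbor sampling without replacement. It uses the Efraimidis-Spirakis method (each neighbor gets the key u ** (1 / w) for a uniform random u, and the k largest keys are kept). This is a swapped-in textbook technique: PyG and pyg-lib may use a different algorithm internally, and the function name weighted_sample_neighbors is hypothetical.

```python
import random
from typing import List, Sequence


def weighted_sample_neighbors(
    neighbors: Sequence[int],
    weights: Sequence[float],
    k: int,
    rng: random.Random,
) -> List[int]:
    """Sample up to k neighbors without replacement, biased by weight.

    Hypothetical helper using Efraimidis-Spirakis keys u ** (1 / w); the k
    largest keys win. Neighbors with non-positive weight are never chosen.
    """
    keyed = [
        (rng.random() ** (1.0 / w), node)
        for node, w in zip(neighbors, weights)
        if w > 0
    ]
    keyed.sort(reverse=True)
    return [node for _, node in keyed[:k]]


rng = random.Random(0)
print(weighted_sample_neighbors([1, 2, 3, 4], [1.0, 0.0, 2.0, 0.5], k=2, rng=rng))
```

With k=1, neighbor i is picked with probability proportional to its weight, so edges with larger edge_weight values are more likely to end up in the sampled subgraph.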
New models, datasets, examples & tutorials
As part of our algorithm and documentation sprints (#7892), we have added:
- Model components:
  - MixHopConv: “MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing” (examples/mixhop.py) (#8025)
  - LCMAggregation: “Learnable Commutative Monoids for Graph Neural Networks” (examples/lcm_aggr_2nd_min.py) (#7976, #8020, #8023, #8026, #8075)
  - DirGNNConv: “Edge Directionality Improves Learning on Heterophilic Graphs” (examples/dir_gnn.py) (#7458)
  - Support for Performer in GPSConv: “Recipe for a General, Powerful, Scalable Graph Transformer” (examples/graph_gps.py) (#7465)
  - PMLP: “Graph Neural Networks are Inherently Good Generalizers: Insights by Bridging GNNs and MLPs” (examples/pmlp.py) (#7470, #7543)
  - RotatE: “RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space” (examples/kge_fb15k_237.py) (#7026)
  - NeuralFingerprint: “Convolutional Networks on Graphs for Learning Molecular Fingerprints” (#7919)
- Datasets: HM (#7515), BrcaTcga (#7994), MyketDataset (#7959), Wikidata5M (#7864), OSE_GVCS (#7811), MovieLens1M (#7479), AmazonBook (#7483), GDELTLite (#7442), IGMCDataset (#7441), MovieLens100K (#7398), EllipticBitcoinTemporalDataset (#7011), NeuroGraphDataset (#8112), PCQM4Mv2 (#8102)
- Tutorials:
- Examples:
  - Heterogeneous link-level GNN explanations via CaptumExplainer (examples/captum_explainer_hetero_link.py) (#7096)
  - Training LightGCN on AmazonBook for recommendation (examples/lightgcn.py) (#7603)
  - Using the Kùzu remote backend as FeatureStore (examples/kuzu) (#7298)
  - Multi-GPU training on ogbn-papers100M (examples/papers100m_multigpu.py) (#7921)
  - The OGC model on Cora (examples/ogc.py) (#8168)
  - Distributed training via graphlearn-for-pytorch (examples/distributed/graphlearn_for_pytorch) (#7402)
Join our Slack here if you're interested in joining community sprints in the future!
Breaking Changes
- Data.keys() is now a method instead of a property (#7629):
  ```python
  # <= 2.3
  data = Data(x=x, edge_index=edge_index)
  print(data.keys)  # ['x', 'edge_index']

  # 2.4
  data = Data(x=x, edge_index=edge_index)
  print(data.keys())  # ['x', 'edge_index']
  ```
- Dropped Python 3.7 support (#7939)
- Removed FastHGTConv in favor of HGTConv (#7117)
- Removed the layer_type argument from GraphMaskExplainer (#7445)
- Renamed the dest argument to dst in utils.geodesic_distance (#7708)
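Code that must run against both PyG <= 2.3 and 2.4 can bridge the Data.keys change with a small compatibility helper. The sketch below is hypothetical (get_keys, OldStyle and NewStyle are illustrative names, not PyG APIs). It relies only on the fact stated above: keys is a property before 2.4 and a method from 2.4 on. The two stand-in classes let the helper be tried without PyG installed.

```python
from typing import Any, List


def get_keys(data: Any) -> List[str]:
    """Hypothetical helper: attribute names on PyG <= 2.3 and >= 2.4.

    `keys` was a property up to 2.3 and is a method from 2.4 onwards.
    """
    keys = data.keys
    if callable(keys):  # 2.4+: a bound method, so call it
        keys = keys()
    return list(keys)


# Stand-ins for the two API styles, so the helper runs without PyG.
class OldStyle:
    @property
    def keys(self) -> List[str]:
        return ['x', 'edge_index']


class NewStyle:
    def keys(self) -> List[str]:
        return ['x', 'edge_index']


print(get_keys(OldStyle()), get_keys(NewStyle()))
```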
Deprecations
- Deprecated contrib.explain.GraphMaskExplainer in favor of explain.algorithm.GraphMaskExplainer (#7779)
Features
Data and HeteroData improvements
- Added a warning for isolated/non-existing node types in HeteroData.validate() (#7995)
- Added HeteroData support in to_networkx (#7713)
- Added Data.sort() and HeteroData.sort() (#7649)
- Added padding capabilities to HeteroData.to_homogeneous() in case feature dimensionalities do not match (#7374)
- Added torch.nested_tensor support in Data and Batch (#7643, #7647)
- Added keep_inter_cluster_edges option to ClusterData to support inter-subgraph edge connections when doing graph partitioning (#7326)
Data-loading improvements
- Added support for floating-point slicing in Dataset, e.g., dataset[:0.9] (#7915)
- Added save and load methods to InMemoryDataset (#7250, #7413)
- Beta: Added IBMBNodeLoader and IBMBBatchLoader data loaders (#6230)
- Beta: Added HyperGraphData to support hypergraphs (#7611)
- Added CachedLoader (#7896, #7897)
- Allowed GPU tensors as input to NodeLoader and LinkLoader (#7572)
- Added PrefetchLoader capabilities (#7376, #7378, #7383)
- Added manual sampling interface to NodeLoader and LinkLoader (#7197)
Better support for sparse tensors
- Added SparseTensor support to WLConvContinuous, GeneralConv, PDNConv and ARMAConv (#8013)
- Changed torch_sparse.SparseTensor logic to utilize torch.sparse_csr instead (#7041)
- Added support for torch.sparse.Tensor in DataLoader (#7252)
- Added support for torch.jit.script within MessagePassing layers without torch_sparse being installed (#7061, #7062)
- Added unbatching logic for torch.sparse.Tensor (#7037)
- Added support for Data.num_edges for native torch.sparse.Tensor adjacency matrices (#7104)
- Accelerated sparse tensor conversion routines (#7042, #7043)
- Added a sparse cross_entropy implementation (#7447, #7466)
Integration with 3rd-party libraries
torch_geometric.transforms
- All transforms are now immutable, i.e., they perform a shallow copy of the data and therefore no longer modify data in-place (#7429)
- Added the HalfHop graph upsampling augmentation (#7827)
- Added an interval argument to the Cartesian, LocalCartesian and Distance transformations (#7533, #7614, #7700)
- Added an optional add_pad_mask argument to the Pad transform (#7339)
- Added a NodePropertySplit transformation for creating node-level splits using structural node properties (#6894)
- Added an AddRemainingSelfLoops transformation (#7192)
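Why does a shallow copy suffice to make transforms immutable? The transform only re-binds attributes on the copy, so the caller's object keeps its original attributes, while large values such as feature tensors are shared rather than duplicated. The pure-Python sketch below illustrates this pattern. ToyData, ToyTransform and AddNumNodes are hypothetical stand-ins, not PyG's Data or BaseTransform. It also shows the caveat: an in-place mutation of a shared value would still be visible to the caller.

```python
import copy
from typing import Any


class ToyData:
    """Hypothetical stand-in for a graph data object."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)


class ToyTransform:
    """Sketch of a copy-on-call transform (not PyG's BaseTransform)."""

    def __call__(self, data: ToyData) -> ToyData:
        # Shallow copy: the copy gets its own attribute dict, but attribute
        # values (e.g. feature lists) are shared with the original.
        return self.forward(copy.copy(data))

    def forward(self, data: ToyData) -> ToyData:
        raise NotImplementedError


class AddNumNodes(ToyTransform):
    def forward(self, data: ToyData) -> ToyData:
        data.num_nodes = len(data.x)  # re-binds an attribute on the copy only
        return data


original = ToyData(x=[[0.0], [1.0], [2.0]])
out = AddNumNodes()(original)
print(out.num_nodes, hasattr(original, 'num_nodes'))
```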
Bugfixes
- Fixed HeteroConv for layers that have a non-default argument order, e.g., GCN2Conv (#8166)
- Handle reserved keywords as keys in ModuleDict and ParameterDict (#8163)
- Fixed DynamicBatchSampler.__len__ to raise an error in case num_steps is undefined (#8137)
- Enabled pickling of DimeNet models (#8019)
- Fixed a bug in which batch.e_id was not correctly computed on unsorted graph inputs (#7953)
- Fixed from_networkx conversion from nx.stochastic_block_model graphs (#7941)
- Fixed the usage of bias_initializer in HeteroLinear (#7923)
- Fixed broken URLs in HGBDataset (#7907)
- Fixed an issue where SetTransformerAggregation produced NaN values for isolated nodes (#7902)
- Fixed summary on modules with uninitialized parameters (#7884)
- Fixed tracing of add_self_loops for a dynamic number of nodes (#7330)
- Fixed device issue in PNAConv.get_degree_histogram (#7830)
- Fixed the shape of edge_label_time when using temporal sampling on homogeneous graphs (#7807)
- Fixed edge_label_index computation in LinkNeighborLoader for the homogeneous+disjoint mode (#7791)
- Fixed CaptumExplainer for binary classification tasks (#7787)
- Raise error when collecting non-existing attributes in HeteroData (#7714)
- Fixed get_mesh_laplacian for normalization="sym" (#7544)
- Use dim_size to initialize output size of the EquilibriumAggregation layer (#7530)
- Fixed empty edge indices handling in SparseTensor (#7519)
- Move the scaler tensor in GeneralConv to the correct device (#7484)
- Fixed HeteroLinear bug when used via mixed precision (#7473)
- Fixed gradient computation of edge weights in utils.spmm (#7428)
- Fixed an index-out-of-range bug in QuantileAggregation when dim_size is passed (#7407)
- Fixed a bug in LightGCN.recommendation_loss() to only use the embeddings of the nodes involved in the current mini-batch (#7384)
- Fixed a bug in which inputs were modified in-place in to_hetero_with_bases (#7363)
- Do not load node_default and edge_default attributes in from_networkx (#7348)
- Fixed HGTConv utility function _construct_src_node_feat (#7194)
- Fixed subgraph on unordered inputs (#7187)
- Allow missing node types in HeteroDictLinear (#7185)
- Fix numpy incompatibility when reading files for Planetoid datasets (#7141)
- Fixed crash of heterogeneous data loaders if node or edge types are missing (#7060, #7087)
- Allowed CaptumExplainer to be called multiple times in a row (#7391)
Changes
- Enabled dense eigenvalue computation in AddLaplacianEigenvectorPE for small-scale graphs (#8143)
- Accelerated and simplified top_k computation in TopKPooling (#7737)
- Updated GIN implementation in benchmarks to apply sequential batch normalization (#7955)
- Updated QM9 data pre-processing to include the SMILES string (#7867)
- Warn user when using the training flag in to_hetero modules (#7772)
- Changed add_random_edge to only add true negative edges (#7654)
- Allowed the usage of BasicGNN models in DeepGraphInfomax (#7648)
- Added a num_edges parameter to the forward method of HypergraphConv (#7560)
- Added a max_num_elements parameter to the forward method of GraphMultisetTransformer, GRUAggregation, LSTMAggregation, SetTransformerAggregation and SortAggregation (#7529, #7367)
- Re-factored ClusterLoader to integrate the pyg-lib METIS routine (#7416)
- The filter_per_worker option will now be automatically inferred by default based on the device of the underlying data (#7399)
- Added the option to pass fill_value as a torch.tensor to utils.to_dense_batch (#7367)
- Updated examples to use NeighborLoader instead of NeighborSampler (#7152)
- Extend dataset summary to create stats for each node/edge type (#7203)
- Added an optional batch_size argument to avg_pool_x and max_pool_x (#7216)
- Optimized from_networkx memory footprint by reducing unnecessary copies (#7119)
- Added an optional batch_size argument to LayerNorm, GraphNorm, InstanceNorm, GraphSizeNorm and PairNorm (#7135)
- Accelerated attention-based MultiAggregation (#7077)
- Edges in HeterophilousGraphDataset are now undirected by default (#7065)
- Added optional batch_size and max_num_nodes arguments to the MemPooling layer (#7239)
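The add_random_edge change means newly added edges are "true negatives", i.e., edges that do not already exist in the graph. A minimal pure-Python sketch of that idea via rejection sampling is shown below. The function name add_random_negative_edges, the directed-edge convention and the exclusion of self-loops are our own assumptions, not details of PyG's implementation.

```python
import random
from typing import List, Set, Tuple

Edge = Tuple[int, int]


def add_random_negative_edges(edges: List[Edge], num_nodes: int,
                              num_new: int, rng: random.Random) -> List[Edge]:
    """Sample `num_new` new directed edges that are absent from `edges`.

    Hypothetical helper using rejection sampling; self-loops are excluded
    as a design choice of this sketch.
    """
    existing: Set[Edge] = set(edges)
    num_non_loops = sum(1 for u, v in existing if u != v)
    if num_new > num_nodes * (num_nodes - 1) - num_non_loops:
        raise ValueError('not enough non-edges to sample from')
    added: List[Edge] = []
    while len(added) < num_new:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u == v or (u, v) in existing:
            continue  # reject self-loops and already existing edges
        existing.add((u, v))  # also prevents duplicates among new edges
        added.append((u, v))
    return added


edges = [(0, 1), (1, 2), (2, 0)]
new_edges = add_random_negative_edges(edges, num_nodes=4, num_new=5,
                                      rng=random.Random(0))
print(new_edges)
```

Rejection sampling is cheap on sparse graphs, where a random pair is almost always a non-edge. The capacity check keeps the loop from spinning forever on dense graphs.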
Full Changelog
Full Changelog: https://github.com/pyg-team/pytorch_geometric/compare/2.3.0...2.4.0
New Contributors
- @zoryzhang made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7027
- @DomInvivo made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7037
- @OlegPlatonov made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7065
- @hbenedek made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7053
- @rishiagarwal2000 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7011
- @sisaman made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7104
- @amorehead made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7110
- @EulerPascal404 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7093
- @Looong01 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7143
- @kamil-andrzejewski made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7135
- @andreazanetti made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7089
- @akihironitta made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7195
- @kjkozlowski made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7216
- @vstenby made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7221
- @piotrchmiel made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7239
- @vedal made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7272
- @gvbazhenov made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/6894
- @Saydemr made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7313
- @HaoyuLu1022 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7325
- @Vuenc made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7330
- @mewim made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7298
- @volltin made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7355
- @kasper-piskorski made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7377
- @happykygo made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7384
- @ThomasKLY made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7398
- @sky-2002 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7421
- @denadai2 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7456
- @chrisgo-gc made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7484
- @furkanakkurt1335 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7507
- @mzamini92 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7497
- @n-patricia made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7543
- @SalvishGoomanee made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7573
- @emalgorithm made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7458
- @marshka made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7595
- @djm93dev made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7598
- @NripeshN made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7770
- @ATheCoder made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7774
- @ebrahimpichka made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7775
- @kaidic made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7814
- @Wesxdz made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7811
- @daviddavo made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7888
- @frinkleko made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7907
- @chendiqian made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7917
- @rajveer43 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7885
- @erfanloghmani made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7959
- @xnuohz made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7937
- @Favourj-bit made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7905
- @apfelsinecode made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7996
- @ArchieGertsman made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7976
- @bkmi made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8019
- @harshit5674 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7919
- @erikhuck made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8024
- @jay-bhambhani made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8028
- @Barcavin made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8049
- @royvelich made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8048
- @CodeTal made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/7611
- @filipekstrm made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8117
- @Anwar-Said made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8122
- @xYix made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8168