2.5.0
Release date: 2024-02-16 15:16:56
Latest pyg-team/pytorch_geometric release: 2.5.3 (2024-04-19 19:37:44)
We are excited to announce the release of PyG 2.5 🎉🎉🎉
PyG 2.5 is the culmination of work from 38 contributors who have worked on features and bug-fixes for a total of over 360 commits since torch-geometric==2.4.0.
Highlights
torch_geometric.distributed
We are thrilled to announce the first in-house distributed training solution for PyG via the torch_geometric.distributed sub-package. Developers and researchers can now take full advantage of distributed training on large-scale datasets that cannot be fully loaded into the memory of a single machine. This implementation doesn't require any additional packages to be installed on top of the default PyG stack.
Key Advantages
- Balanced graph partitioning via METIS ensures minimal communication overhead when sampling subgraphs across compute nodes.
- Using DDP for model training in conjunction with RPC for remote sampling and feature fetching routines (over the TCP/IP protocol with the gloo communication backend) enables data parallelism with a distinct data partition on each node.
- The implementation via custom GraphStore and FeatureStore APIs provides a flexible and tailored interface for distributing large graph structure information and feature storage.
- Distributed neighbor sampling is capable of sampling in both local and remote partitions through RPC communication channels. All advanced functionality of single-node sampling is also applicable to distributed training, e.g., heterogeneous sampling, link-level sampling, temporal sampling, etc.
- Distributed data loaders offer a high-level abstraction for managing sampler processes, ensuring simplicity and seamless integration with standard PyG data loaders.
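To illustrate the idea of sampling across local and remote partitions, here is a toy, single-process sketch. It assumes a precomputed node-to-partition map (such as a METIS partitioning would produce) and groups seed nodes into those whose neighbors can be read from the local partition and those that would have to be requested from a remote partition. The names `split_neighbors`, `adj` and `node_to_part` are hypothetical and are not part of the torch_geometric.distributed API.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def split_neighbors(
    seeds: List[int],
    adj: Dict[int, List[int]],
    node_to_part: Dict[int, int],
    local_part: int,
) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:
    """Hypothetical helper: resolve local seeds, batch remote ones per partition.

    Seeds owned by `local_part` have their neighbor lists read directly;
    all other seeds are grouped by owning partition, so that one request
    per remote partition could be issued (e.g., via RPC).
    """
    local: Dict[int, List[int]] = {}
    remote: Dict[int, List[int]] = defaultdict(list)
    for seed in seeds:
        owner = node_to_part[seed]
        if owner == local_part:
            local[seed] = adj.get(seed, [])
        else:
            remote[owner].append(seed)
    return local, dict(remote)


# Toy graph split into two partitions:
adj = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1]}
node_to_part = {0: 0, 1: 0, 2: 1, 3: 1}
local, remote = split_neighbors([0, 2, 3], adj, node_to_part, local_part=0)
```

Here node 0 is served locally, while nodes 2 and 3 are batched into a single request to partition 1.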
See here for the accompanying tutorial. In addition, we provide two distributed examples in examples/distributed/pyg to get started:
- Distributed node-level classification on ogbn-products
- Distributed temporal link prediction on MovieLens
EdgeIndex Tensor Representation
torch-geometric==2.5.0 introduces the EdgeIndex class. EdgeIndex is a torch.Tensor that holds an edge_index representation of shape [2, num_edges]. Edges are given as pairwise source and destination node indices in sparse COO format. While EdgeIndex sub-classes a general torch.Tensor, it can hold additional (meta)data, i.e.:
- sparse_size: The underlying sparse matrix size
- sort_order: The sort order (if present), either by row or column
- is_undirected: Whether edges are bidirectional
Additionally, EdgeIndex caches data for fast CSR or CSC conversion in case its representation is sorted (i.e., its rowptr or colptr). Caches are filled on demand (e.g., when calling EdgeIndex.sort_by()) or when explicitly requested via EdgeIndex.fill_cache_(), and are maintained and adjusted over its lifespan (e.g., when calling EdgeIndex.flip()).
```python
import torch
from torch_geometric import EdgeIndex

edge_index = EdgeIndex(
    [[0, 1, 1, 2],
     [1, 0, 2, 1]],
    sparse_size=(3, 3),
    sort_order='row',
    is_undirected=True,
    device='cpu',
)
# >>> EdgeIndex([[0, 1, 1, 2],
# ...            [1, 0, 2, 1]])
assert edge_index.is_sorted_by_row
assert edge_index.is_undirected

# Flipping order:
edge_index = edge_index.flip(0)
# >>> EdgeIndex([[1, 0, 2, 1],
# ...            [0, 1, 1, 2]])
assert edge_index.is_sorted_by_col
assert edge_index.is_undirected

# Filtering:
mask = torch.tensor([True, True, True, False])
edge_index = edge_index[:, mask]
# >>> EdgeIndex([[1, 0, 2],
# ...            [0, 1, 1]])
assert edge_index.is_sorted_by_col
assert not edge_index.is_undirected

# Sparse-Dense Matrix Multiplication:
out = edge_index.flip(0) @ torch.randn(3, 16)
assert out.size() == (3, 16)
```
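The rowptr cache mentioned above is the compressed row pointer of the CSR format: entry i marks where the edges of source node i start in the row-sorted edge list. As a minimal pure-Python sketch (not the EdgeIndex implementation), a row-sorted COO edge list can be compressed like this; the function name `coo_to_rowptr` is hypothetical.

```python
from typing import List


def coo_to_rowptr(row: List[int], num_rows: int) -> List[int]:
    """Compress a row-sorted COO row vector into a CSR row pointer."""
    if any(a > b for a, b in zip(row, row[1:])):
        raise ValueError("row indices must be sorted")
    # Count the number of edges per source node.
    counts = [0] * num_rows
    for r in row:
        counts[r] += 1
    # Prefix sum: rowptr[i + 1] - rowptr[i] == out-degree of node i.
    rowptr = [0] * (num_rows + 1)
    for i, c in enumerate(counts):
        rowptr[i + 1] = rowptr[i] + c
    return rowptr


# Same graph as in the EdgeIndex example above (sorted by row):
row, col = [0, 1, 1, 2], [1, 0, 2, 1]
rowptr = coo_to_rowptr(row, num_rows=3)
# The neighbors of node 1 are a contiguous slice of `col`:
neighbors_of_1 = col[rowptr[1]:rowptr[2]]
```

Keeping such a pointer around is what makes neighbor lookups and CSR-based kernels cheap once the edges are sorted.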
EdgeIndex is implemented by extending torch.Tensor via the __torch_function__ interface (see here for the highly recommended tutorial). EdgeIndex enables optimal computation in GNN message passing schemes, while preserving the ease-of-use of regular COO-based PyG workflows. EdgeIndex will fully deprecate the usage of SparseTensor from torch-sparse in later releases, leaving us with just a single source of truth for representing graph structure information in PyG.
RecSys Support
Previously, most of our link prediction models were trained and evaluated using binary classification metrics. However, this usually requires a set of candidate links to be available in advance, from which we can then infer the existence of links. This is not necessarily practical, since in most cases we want to find the top-k most likely links from the full set of O(N^2) pairs.
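For intuition, exact maximum inner product search can be written as a brute-force scan: score every destination embedding against a source embedding and keep the k highest scores. This costs O(N) per query, and thus O(N^2) over all source-destination pairs, which is the cost that a dedicated index such as MIPSKNNIndex is meant to reduce. The sketch below is plain Python, and the name `topk_mips` is hypothetical.

```python
import heapq
from typing import List, Sequence


def topk_mips(query: Sequence[float], items: List[Sequence[float]], k: int) -> List[int]:
    """Return indices of the k items with the largest inner product with `query`."""
    scores = [sum(q * x for q, x in zip(query, item)) for item in items]
    # Keep the k best (score, index) pairs; equal scores fall back to comparing indices.
    best = heapq.nlargest(k, ((s, i) for i, s in enumerate(scores)))
    return [i for _, i in best]


dst_emb = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]]
src = [1.0, 0.5]
top2 = topk_mips(src, dst_emb, k=2)  # scores: 1.0, 0.5, 1.05, -1.0
```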
torch-geometric==2.5.0 brings full support for using GNNs as a recommender system (#8452), including support for
- Maximum Inner Product Search (MIPS) via MIPSKNNIndex
- Retrieval metrics such as f1@k, map@k, precision@k, recall@k and ndcg@k, including mini-batch support
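```python
from torch_geometric.metrics import LinkPredNDCG, LinkPredPrecision, LinkPredRecall
from torch_geometric.nn.pool import MIPSKNNIndex

# `model`, `src_loader`, `dst_emb`, `k` and `edge_label_index` come from the full example.
retrieval_metrics = [LinkPredPrecision(k), LinkPredRecall(k), LinkPredNDCG(k)]
mips = MIPSKNNIndex(dst_emb)  # Index all destination embeddings once.
for src_batch in src_loader:
    src_emb = model(src_batch.x_dict, src_batch.edge_index_dict)
    # Retrieve the top-k destination indices for every source node:
    _, pred_index_mat = mips.search(src_emb, k)
    for metric in retrieval_metrics:
        metric.update(pred_index_mat, edge_label_index)
for metric in retrieval_metrics:
    print(type(metric).__name__, float(metric.compute()))
```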
See here for the accompanying example.
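To make the retrieval metrics concrete, here is a minimal, unbatched sketch of precision@k, recall@k and f1@k for a single source node, given its top-k predicted destinations and its set of ground-truth destinations. The function `retrieval_at_k` is hypothetical, ignores the mini-batch accumulation that PyG's metrics perform, and returns 0.0 for degenerate cases (e.g., an empty ground-truth set) as an assumption.

```python
from typing import Dict, List, Set


def retrieval_at_k(pred: List[int], truth: Set[int], k: int) -> Dict[str, float]:
    """Compute precision@k, recall@k and f1@k for one query."""
    hits = sum(1 for p in pred[:k] if p in truth)
    precision = hits / k                                # fraction of predictions that are relevant
    recall = hits / len(truth) if truth else 0.0        # fraction of relevant items retrieved
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall > 0 else 0.0)           # harmonic mean of the two
    return {"precision": precision, "recall": recall, "f1": f1}


m = retrieval_at_k(pred=[2, 0, 5], truth={0, 3}, k=3)
```

With one hit among three predictions and two relevant items, precision@3 is 1/3, recall@3 is 1/2 and f1@3 is 0.4.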
PyTorch 2.2 Support
PyG 2.5 is fully compatible with PyTorch 2.2 (#8857), and supports the following combinations:
| PyTorch 2.2 | cpu | cu118 | cu121 |
|---|---|---|---|
| Linux | ✅ | ✅ | ✅ |
| macOS | ✅ | | |
| Windows | ✅ | ✅ | ✅ |
You can still install PyG 2.5 with older PyTorch releases, going back as far as PyTorch 1.12, in case you are not eager to update your PyTorch version.
Native torch.compile(...) and TorchScript Support
torch-geometric==2.5.0 introduces a full re-implementation of the MessagePassing interface, which makes it natively applicable to both torch.compile and TorchScript. As such, torch_geometric.compile is now fully deprecated in favor of torch.compile:
```diff
- model = torch_geometric.compile(model)
+ model = torch.compile(model)
```
and MessagePassing.jittable() is now a no-op:
```diff
- conv = torch.jit.script(conv.jittable())
+ conv = torch.jit.script(conv)
```
In addition, torch.compile usage has been fixed to no longer require disabling extension packages such as torch-scatter or torch-sparse.
New Tutorials, Examples, Models and Improvements
- Tutorials
  - Multi-Node Training using SLURM (#8071)
  - Point Cloud Processing (#8015)
- Examples
  - Distributed training via torch_geometric.distributed (examples/distributed/pyg/) (#8713)
  - Edge-level temporal sampling on a heterogeneous graph (examples/hetero/temporal_link_pred.py) (#8383)
  - Edge-level temporal sampling on a heterogeneous graph with distributed training (examples/distributed/pyg/temporal_link_movielens_cpu.py) (#8820)
  - Distributed training on XPU device (examples/multi_gpu/distributed_sampling_xpu.py) (#8032)
  - Multi-node multi-GPU training on ogbn-papers100M (examples/multi_gpu/papers100m_gcn_multinode.py) (#8070)
  - Naive model parallelism on multiple GPUs (examples/multi_gpu/model_parallel.py) (#8309)
- Models
  - Added the equivariant ViSNet from "ViSNet: an equivariant geometry-enhanced graph neural network with vector-scalar interactive message passing for molecules" (#8287)
- Improvements
  - Enabled multi-GPU evaluation in distributed sampling example (examples/multi_gpu/distributed_sampling.py) (#8880)
Breaking Changes
- GATConv now initializes modules differently depending on whether their input is bipartite or non-bipartite (#8397). This will lead to issues when loading model state for GATConv layers trained on earlier PyG versions.
Deprecations
- Deprecated torch_geometric.compile in favor of torch.compile (#8780)
- Deprecated torch_geometric.nn.DataParallel in favor of torch.nn.parallel.DistributedDataParallel (#8250)
- Deprecated MessagePassing.jittable (#8781, #8731)
- Deprecated torch_geometric.data.makedirs in favor of os.makedirs (#8421)
Features
Package-wide Improvements
- Added support for type checking via mypy (#8254)
- Added fsspec as file system backend (#8379, #8426, #8434, #8474)
- Added fallback code path for segment-based reductions in case torch-scatter is not installed (#8852)
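A segment-based reduction aggregates all values that share the same target index, which is the core operation behind message aggregation. The pure-Python sketch below shows the scatter-style sum, mean and max that such a fallback has to reproduce; it is an illustration only, not PyG's actual fallback code, and the name `scatter_reduce_1d` is hypothetical. Empty segments are filled with 0.0 here, which is an assumption of the sketch.

```python
from typing import List


def scatter_reduce_1d(src: List[float], index: List[int], dim_size: int,
                      reduce: str = "sum") -> List[float]:
    """Reduce `src[i]` into `out[index[i]]` using 'sum', 'mean' or 'max'."""
    if reduce not in ("sum", "mean", "max"):
        raise ValueError(f"unsupported reduce: {reduce}")
    out = [0.0] * dim_size
    count = [0] * dim_size
    for value, idx in zip(src, index):
        if reduce == "max":
            # The first value of a segment initializes it; later values compete.
            out[idx] = value if count[idx] == 0 else max(out[idx], value)
        else:
            out[idx] += value
        count[idx] += 1
    if reduce == "mean":
        out = [o / c if c else 0.0 for o, c in zip(out, count)]
    return out


src, index = [1.0, 2.0, 3.0, 4.0], [0, 0, 2, 2]
sums = scatter_reduce_1d(src, index, dim_size=3, reduce="sum")
means = scatter_reduce_1d(src, index, dim_size=3, reduce="mean")
maxes = scatter_reduce_1d(src, index, dim_size=3, reduce="max")
```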
Temporal Graph Support
- Added support for edge-level temporal sampling in NeighborLoader and LinkNeighborLoader (#8372, #8428)
- Added Data.{sort_by_time,is_sorted_by_time,snapshot,up_to} for temporal graph use-cases (#8454)
- Added support for graph partitioning for temporal data in torch_geometric.distributed (#8718, #8815)
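The temporal helpers can be pictured as filters over per-edge timestamps. The sketch below mimics the intent of sorting by time, keeping everything up to a time t, and slicing a snapshot between two times, using plain lists of (src, dst, timestamp) tuples. The function names are hypothetical, and the inclusive time bounds are an assumption of this sketch rather than a statement about the exact semantics of Data.up_to or Data.snapshot.

```python
from typing import List, Tuple

Edge = Tuple[int, int, int]  # (src, dst, timestamp)


def sort_by_time(edges: List[Edge]) -> List[Edge]:
    """Order edges chronologically by their timestamp."""
    return sorted(edges, key=lambda e: e[2])


def up_to(edges: List[Edge], end_time: int) -> List[Edge]:
    """Keep edges with timestamp <= end_time (inclusive bound assumed)."""
    return [e for e in edges if e[2] <= end_time]


def snapshot(edges: List[Edge], start_time: int, end_time: int) -> List[Edge]:
    """Keep edges with start_time <= timestamp <= end_time (inclusive bounds assumed)."""
    return [e for e in edges if start_time <= e[2] <= end_time]


edges = sort_by_time([(0, 1, 5), (1, 2, 1), (2, 0, 3), (0, 2, 8)])
history = up_to(edges, end_time=5)
window = snapshot(edges, start_time=3, end_time=5)
```

Filtering by time in this way is what prevents a model from seeing edges "from the future" when it is trained or evaluated at a given point in time.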
torch_geometric.datasets
- Added the Risk Commodity Detection Dataset (RCDD) from "Datasets and Interfaces for Benchmarking Heterogeneous Graph Neural Networks" (#8196)
- Added the StochasticBlockModelDataset(num_graphs: int) argument (#8648)
- Added support for floating-point average degree numbers in FakeDataset and FakeHeteroDataset (#8404)
- Added InMemoryDataset.to(device) (#8402)
- Added the force_reload: bool = False argument to Dataset and InMemoryDataset in order to enforce re-processing of datasets (#8352, #8357, #8436)
- Added the TreeGraph and GridMotif generators (#8736)
torch_geometric.nn
- Added KNNIndex exclusion logic (#8573)
- Added support for MRR computation in KGEModel.test() (#8298)
- Added support for nn.to_hetero_with_bases on static graphs (#8247)
- Addressed graph breaks in ModuleDict, ParameterDict, MultiAggregation and HeteroConv for better support for torch.compile (#8363, #8345, #8344)
torch_geometric.metrics
- Added support for f1@k, map@k, precision@k, recall@k and ndcg@k metrics for link-prediction retrieval tasks (#8499, #8326, #8566, #8647)
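Unlike precision@k and recall@k, ndcg@k also rewards placing relevant items near the top of the ranking. A minimal single-query sketch, assuming binary relevance (an item is either a ground-truth link or not) and the common log2 position discount, looks as follows; the function name `ndcg_at_k` is hypothetical and this is not PyG's implementation.

```python
import math
from typing import List, Set


def ndcg_at_k(pred: List[int], truth: Set[int], k: int) -> float:
    """Normalized discounted cumulative gain at k with binary relevance."""
    # DCG: a hit at (0-based) rank i contributes 1 / log2(i + 2).
    dcg = sum(1.0 / math.log2(i + 2) for i, p in enumerate(pred[:k]) if p in truth)
    # IDCG: the best possible DCG, with all relevant items ranked first.
    ideal_hits = min(k, len(truth))
    idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


score = ndcg_at_k(pred=[2, 0, 5], truth={0, 3}, k=3)
```

Here the only hit sits at the second position, so the score is (1 / log2(3)) / (1 + 1 / log2(3)), roughly 0.387.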
torch_geometric.explain
- Enabled skipping explanations of certain message passing layers via conv.explain = False (#8216)
- Added support for visualizing explanations with node labels via visualize_graph(node_labels: list[str] | None) argument (#8816)
torch_geometric.transforms
- Added a faster dense computation code path in AddRandomWalkPE (#8431)
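AddRandomWalkPE encodes each node by the probabilities that a random walk returns to it after 1, ..., K steps, i.e., the diagonals of successive powers of the row-normalized transition matrix D^{-1}A. The dense sketch below computes these with plain Python lists to show the idea; it is not PyG's code path, the names are hypothetical, and isolated nodes are simply given all-zero transition rows as an assumption.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense matrix product of two list-of-lists matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def random_walk_pe(adj: Matrix, walk_length: int) -> Matrix:
    """Return an [N, walk_length] matrix of return probabilities."""
    n = len(adj)
    # Row-normalized transition matrix D^{-1} A; isolated nodes keep zero rows.
    trans = []
    for row in adj:
        deg = sum(row)
        trans.append([v / deg if deg > 0 else 0.0 for v in row])
    pe = [[0.0] * walk_length for _ in range(n)]
    power = trans
    for step in range(walk_length):
        for i in range(n):
            pe[i][step] = power[i][i]  # probability of being back at node i
        if step + 1 < walk_length:
            power = matmul(power, trans)
    return pe


# Path graph 0 - 1 - 2:
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
pe = random_walk_pe(adj, walk_length=2)
```

On the path graph, no walk can return in one step, while a two-step walk from the middle node always returns.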
Other Improvements
- Added support for returning multi graphs in utils.to_networkx (#8575)
- Added noise scheduler utilities utils.noise_scheduler.{get_smld_sigma_schedule,get_diffusion_beta_schedule} for diffusion-based graph generative models (#8347)
- Added a relabel node functionality to utils.dropout_node via relabel_nodes: bool argument (#8524)
- Added support for weighted utils.cross_entropy.sparse_cross_entropy (#8340)
- Added support for profiling on XPU device via profile.profileit("xpu") (#8532)
- Added METIS partitioning with CSC/CSR format selection in ClusterData (#8438)
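Conceptually, relabeling after node dropout means keeping only edges between surviving nodes and renumbering the survivors to a contiguous range 0..n'-1. The sketch below shows that bookkeeping on plain lists for a fixed keep-mask (the random sampling of the mask is left out). The function name `drop_and_relabel` is hypothetical, and this is not the utils.dropout_node implementation.

```python
from typing import List, Tuple

EdgeList = Tuple[List[int], List[int]]  # (row, col)


def drop_and_relabel(edge_index: EdgeList, node_mask: List[bool]) -> Tuple[EdgeList, List[bool]]:
    """Keep edges whose endpoints both survive, and relabel kept nodes to 0..n'-1."""
    # Map each surviving node to its new, contiguous index.
    mapping = {}
    for old, keep in enumerate(node_mask):
        if keep:
            mapping[old] = len(mapping)
    row, col = edge_index
    edge_mask = [node_mask[r] and node_mask[c] for r, c in zip(row, col)]
    new_row = [mapping[r] for r, m in zip(row, edge_mask) if m]
    new_col = [mapping[c] for c, m in zip(col, edge_mask) if m]
    return (new_row, new_col), edge_mask


edge_index = ([0, 1, 1, 2, 3], [1, 0, 3, 3, 2])
(new_row, new_col), edge_mask = drop_and_relabel(edge_index, [True, False, True, True])
```

Dropping node 1 removes its three incident edges, and nodes 2 and 3 are renumbered to 1 and 2.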
Bugfixes
- Fixed dummy value creation of boolean tensors in HeteroData.to_homogeneous() (#8858)
- Fixed Google Drive download issues (#8804)
- Fixed InMemoryDataset to reconstruct the correct data class when a pre_transform has modified it (#8692)
- Fixed a bug in which transforms were not applied for OnDiskDataset (#8663)
- Fixed mini-batch computation in DMoNPooling loss function (#8285)
- Fixed NaN handling in SQLDatabase (#8479)
- Fixed CaptumExplainer in case no index is passed (#8440)
- Fixed edge_index construction in the UPFD dataset (#8413)
- Fixed TorchScript support in AttentionalAggregation and DeepSetsAggregation (#8406)
- Fixed GraphMaskExplainer for GNNs with more than two layers (#8401)
- Fixed input_id computation in NeighborLoader in case a mask is given (#8312)
- Respect current device when deep-copying Linear layers (#8311)
- Fixed Data.subgraph()/HeteroData.subgraph() in case edge_index is not defined (#8277)
- Fixed empty edge handling in MetaPath2Vec (#8248)
- Fixed AttentionExplainer usage within AttentiveFP (#8244)
- Fixed load_from_state_dict in lazy Linear modules (#8242)
- Fixed pre-trained DimeNet++ performance on QM9 (#8239)
- Fixed GNNExplainer usage within AttentiveFP (#8216)
- Fixed to_networkx(to_undirected=True) in case the input graph is not undirected (#8204)
- Fixed sparse-sparse matrix multiplication support on Windows in TwoHop and AddRandomWalkPE transformations (#8197, #8225)
- Fixed mini-batching of HeteroData objects converted via ToSparseTensor() when torch-sparse is not installed (#8356)
Changes
- Disallow the usage of add_self_loops=True in GCNConv(normalize=False) (#8210)
- Changed the default inference mode for use_segment_matmul based on benchmarking results (from a heuristic-based version) (#8615)
- Sparse node features in NELL and AttributedGraphDataset are now represented as torch.sparse_csr_tensor instead of torch_sparse.SparseTensor (#8679)
- Accelerated mini-batching of torch.sparse tensors (#8670)
- ExplainerDataset will now contain node labels for any motif generator (#8519)
- Made utils.softmax faster via the in-house pyg_lib.ops.softmax_csr kernel (#8399)
- Made utils.mask.mask_select faster (#8369)
- Added a warning when calling Dataset.num_classes on regression datasets (#8550)
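utils.softmax normalizes values within groups of edges that share a target node. When edges are sorted by target, each group is a contiguous segment delimited by a CSR pointer, which is the layout a CSR-based softmax kernel can exploit. A plain-Python sketch of that computation follows; the name `segment_softmax_csr` is hypothetical, it is not the pyg_lib kernel, and it is numerically stabilized by subtracting each segment's maximum.

```python
import math
from typing import List


def segment_softmax_csr(src: List[float], ptr: List[int]) -> List[float]:
    """Softmax over each segment src[ptr[i]:ptr[i + 1]]."""
    out = [0.0] * len(src)
    for start, end in zip(ptr, ptr[1:]):
        if start == end:
            continue  # empty segment: nothing to normalize
        seg_max = max(src[start:end])
        exps = [math.exp(v - seg_max) for v in src[start:end]]
        total = sum(exps)
        for offset, e in enumerate(exps):
            out[start + offset] = e / total
    return out


# Three segments: [1.0, 1.0], [0.0, 2.0] and [5.0].
probs = segment_softmax_csr([1.0, 1.0, 0.0, 2.0, 5.0], ptr=[0, 2, 4, 5])
```

Each segment sums to one independently, so equal scores split evenly and a single-element segment always receives probability 1.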
New Contributors
- @stadlmax made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8207
- @joaquincabezas made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8215
- @SZiesche made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8201
- @666even666 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8268
- @irustandi made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8274
- @joao-alex-cunha made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8223
- @chaous made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8309
- @songsong0425 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8298
- @rachitk made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8356
- @pmpalang made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8372
- @flxmr made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8353
- @GuyAglionby made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8401
- @plutonium-239 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8404
- @asherbondy made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8379
- @wwang-chcn made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8285
- @SimonPop made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8326
- @brovatten made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8566
- @XJTUNR made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8287
- @kativenOG made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8648
- @ilsenatorov made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8663
- @dependabot made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8706
- @Sutongtong233 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8736
- @m-atalla made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8755
- @A-LOST-WAPITI made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8818
- @AtomicVar made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8825
- @vahanhov made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8816
- @rraadd88 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8842
- @mashaan14 made their first contribution in https://github.com/pyg-team/pytorch_geometric/pull/8884
Full Changelog: https://github.com/pyg-team/pytorch_geometric/compare/2.4.0...2.5.0