v0.4.0
Release date: 2018-04-25 04:49:48
PyTorch 0.4.0 release notes
Table of Contents
- Major Core Changes
  - Tensor / Variable merged
  - Zero-dimensional Tensors
  - dtypes
  - migration guide
- New Features
  - Tensors
    - Full support for advanced indexing
    - Fast Fourier Transforms
  - Neural Networks
    - Trade-off memory for compute
    - bottleneck - a tool to identify hotspots in your code
  - torch.distributions
    - 24 basic probability distributions
    - Added cdf, variance, entropy, perplexity etc.
  - Distributed Training
    - Launcher utility for ease of use
    - NCCL2 backend
  - C++ Extensions
  - Windows Support
  - ONNX Improvements
    - RNN support
- Tensors
- Performance improvements
- Bug fixes
Major Core changes
Here is a summary of the updates to the most important core features users will use daily.
Major Changes and Potentially Breaking Changes:
- Tensors and Variables have merged
- Some operations now return 0-dimensional (scalar) Tensors
- Deprecation of the volatile flag

Improvements:
- dtypes, devices, and NumPy-style Tensor creation functions added
- Support for writing device-agnostic code
We wrote a migration guide that should help you transition your code to new APIs and style. Please read it if you have code in a previous version of PyTorch that you would like to migrate.
The contents of this section (Major Core changes) are included in the migration guide.
Merging the Tensor and Variable classes
torch.autograd.Variable and torch.Tensor are now the same class. More precisely, torch.Tensor is capable of tracking history and behaves like the old Variable; Variable wrapping continues to work as before but returns an object of type torch.Tensor. This means that you don't need the Variable wrapper everywhere in your code anymore.
The type() of a Tensor has changed
Note also that the type() of a Tensor no longer reflects the data type. Use isinstance() or x.type() instead:
>>> x = torch.DoubleTensor([1, 1, 1])
>>> print(type(x)) # was torch.DoubleTensor
<class 'torch.autograd.variable.Variable'>
>>> print(x.type()) # OK: 'torch.DoubleTensor'
'torch.DoubleTensor'
>>> print(isinstance(x, torch.DoubleTensor)) # OK: True
True
When does autograd start tracking history now?
requires_grad, the central flag for autograd, is now an attribute on Tensors. Let's see how this change manifests in code.
autograd uses the same rules previously used for Variables: it starts tracking history when any input Tensor of an operation has requires_grad=True. For example,
>>> x = torch.ones(1) # create a tensor with requires_grad=False (default)
>>> x.requires_grad
False
>>> y = torch.ones(1) # another tensor with requires_grad=False
>>> z = x + y
>>> # both inputs have requires_grad=False. so does the output
>>> z.requires_grad
False
>>> # then autograd won't track this computation. let's verify!
>>> z.backward()
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>>
>>> # now create a tensor with requires_grad=True
>>> w = torch.ones(1, requires_grad=True)
>>> w.requires_grad
True
>>> # add to the previous result that has requires_grad=False
>>> total = w + z
>>> # the total sum now requires grad!
>>> total.requires_grad
True
>>> # autograd can compute the gradients as well
>>> total.backward()
>>> w.grad
tensor([ 1.])
>>> # and no computation is wasted to compute gradients for x, y and z, which don't require grad
>>> z.grad == x.grad == y.grad == None
True
Manipulating the requires_grad flag
Other than directly setting the attribute, you can change this flag in-place using my_tensor.requires_grad_(requires_grad=True), or, as in the above example, at creation time by passing it in as an argument (default is False), e.g.,
>>> existing_tensor.requires_grad_()
>>> existing_tensor.requires_grad
True
>>> my_tensor = torch.zeros(3, 4, requires_grad=True)
>>> my_tensor.requires_grad
True
What about .data?
.data was the primary way to get the underlying Tensor from a Variable. After this merge, calling y = x.data still has similar semantics. So y will be a Tensor that shares the same data with x, is unrelated to the computation history of x, and has requires_grad=False.
However, .data can be unsafe in some cases. Any changes on x.data wouldn't be tracked by autograd, and the computed gradients would be incorrect if x is needed in a backward pass. A safer alternative is to use x.detach(), which also returns a Tensor that shares data with requires_grad=False, but will have its in-place changes reported by autograd if x is needed in backward.
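Below is a minimal sketch (not part of the original notes) of why .detach() is the safer choice; the values are arbitrary:
import torch

x = torch.tensor([2.], requires_grad=True)
out = x * x          # backward of * needs the value of x, so autograd saves it

y = x.detach()       # shares storage with x, requires_grad=False
y.zero_()            # this in-place edit is visible to autograd via the version counter

# out.backward()     # would now raise a RuntimeError about an in-place modification;
#                    # the same edit through x.data would go unnoticed, and backward()
#                    # would silently produce a wrong gradient.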
Some operations now return 0-dimensional (scalar) Tensors
Previously, indexing into a Tensor vector (1-dimensional tensor) gave a Python number, but indexing into a Variable vector gave (inconsistently!) a vector of size (1,)! Similar behavior existed with reduction functions, i.e. tensor.sum() would return a Python number, but variable.sum() would return a vector of size (1,).
Fortunately, this release introduces proper scalar (0-dimensional tensor) support in PyTorch! Scalars can be created using the new torch.tensor function (which will be explained in more detail later; for now just think of it as the PyTorch equivalent of numpy.array). Now you can do things like:
>>> torch.tensor(3.1416) # create a scalar directly
tensor(3.1416)
>>> torch.tensor(3.1416).size() # scalar is 0-dimensional
torch.Size([])
>>> torch.tensor([3]).size() # compare to a vector of size 1
torch.Size([1])
>>>
>>> vector = torch.arange(2, 6) # this is a vector
>>> vector
tensor([ 2., 3., 4., 5.])
>>> vector.size()
torch.Size([4])
>>> vector[3] # indexing into a vector gives a scalar
tensor(5.)
>>> vector[3].item() # .item() gives the value as a Python number
5.0
>>> sum = torch.tensor([2, 3]).sum()
>>> sum
tensor(5)
>>> sum.size()
torch.Size([])
Accumulating losses
Consider the widely used pattern total_loss += loss.data[0] used before 0.4.0. loss was a Variable wrapping a tensor of size (1,), but in 0.4.0 loss is now a scalar with 0 dimensions. Indexing into a scalar doesn't make sense (it gives a warning now, but will be a hard error in 0.5.0): use loss.item() to get the Python number from a scalar.
Note that if you don't convert to a Python number when accumulating losses, you may find increased memory usage in your program. This is because the right-hand side of the above expression used to be a Python float, while it is now a zero-dimensional Tensor. The total loss would thus accumulate Tensors and their gradient history, which may keep around large autograd graphs for much longer than necessary.
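For illustration, a minimal training-loop sketch (with a hypothetical model and random data, not from the original notes) that follows the recommended pattern:
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
criterion = nn.MSELoss()
total_loss = 0.0
for _ in range(10):
    data, target = torch.randn(8, 4), torch.randn(8, 1)
    loss = criterion(model(data), target)
    loss.backward()
    total_loss += loss.item()   # .item() returns a Python float, so no autograd graph is retained
print(total_loss)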
Deprecation of the volatile flag
The volatile flag is now deprecated and has no effect. Previously, any computation involving a Variable with volatile=True wasn't tracked by autograd. This has now been replaced by a set of more flexible context managers, including torch.no_grad(), torch.set_grad_enabled(grad_mode), and others.
>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>>
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True) # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
dtypes, devices and NumPy-style creation functions
In previous versions of PyTorch, we used to specify the data type (e.g. float vs double), device type (cpu vs cuda) and layout (dense vs sparse) together as a "tensor type". For example, torch.cuda.sparse.DoubleTensor was the Tensor type representing double data, living on CUDA devices, and with COO sparse tensor layout.
In this release, we introduce the torch.dtype, torch.device and torch.layout classes to allow better management of these properties via NumPy-style creation functions.
torch.dtype
Below is a complete list of available torch.dtypes (data types) and their corresponding tensor types.

Data type | torch.dtype | Tensor types |
---|---|---|
32-bit floating point | torch.float32 or torch.float | torch.*.FloatTensor |
64-bit floating point | torch.float64 or torch.double | torch.*.DoubleTensor |
16-bit floating point | torch.float16 or torch.half | torch.*.HalfTensor |
8-bit integer (unsigned) | torch.uint8 | torch.*.ByteTensor |
8-bit integer (signed) | torch.int8 | torch.*.CharTensor |
16-bit integer (signed) | torch.int16 or torch.short | torch.*.ShortTensor |
32-bit integer (signed) | torch.int32 or torch.int | torch.*.IntTensor |
64-bit integer (signed) | torch.int64 or torch.long | torch.*.LongTensor |

Use torch.set_default_dtype and torch.get_default_dtype to manipulate the default dtype for floating point tensors.
torch.device
A torch.device contains a device type ('cpu' or 'cuda') and an optional device ordinal (id) for the device type. It can be initialized with torch.device('{device_type}') or torch.device('{device_type}:{device_ordinal}').
If the device ordinal is not present, this represents the current device for the device type; e.g., torch.device('cuda') is equivalent to torch.device('cuda:X') where X is the result of torch.cuda.current_device().
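For example (illustrative REPL output, not part of the original notes):
>>> torch.device('cuda:1')
device(type='cuda', index=1)
>>> torch.device('cpu')
device(type='cpu')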
torch.layout
torch.layout represents the data layout of a Tensor. Currently torch.strided (dense tensors) and torch.sparse_coo (sparse tensors in COO format) are supported.
Creating Tensors
Methods that create a Tensor now also take dtype, device, layout, and requires_grad options to specify the desired attributes of the returned Tensor. For example,
>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
>>> x
tensor([[-0.6344, 0.8562, -1.2758],
[ 0.8414, 1.7962, 1.0589],
[-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True
torch.tensor
torch.tensor is one of the newly added tensor creation methods. It takes in array-like data of all kinds and copies the contained values into a new Tensor. As mentioned earlier, torch.tensor is the PyTorch equivalent of NumPy's numpy.array constructor. Unlike the torch.*Tensor methods, you can also create zero-dimensional Tensors (aka scalars) this way (a single Python number is treated as a size in the torch.*Tensor methods). Moreover, if a dtype argument isn't given, it will infer the suitable dtype from the data. It is the recommended way to create a tensor from existing data like a Python list. For example,
>>> cuda = torch.device("cuda")
>>> torch.tensor([[1], [2], [3]], dtype=torch.half, device=cuda)
tensor([[ 1],
[ 2],
[ 3]], device='cuda:0')
>>> torch.tensor(1) # scalar
tensor(1)
>>> torch.tensor([1, 2.3]).dtype # type inference
torch.float32
>>> torch.tensor([1, 2]).dtype # type inference
torch.int64
We've also added more tensor creation methods. Some of them have torch.*_like and/or tensor.new_* variants.

- torch.*_like takes in an input Tensor instead of a shape. It returns a Tensor with the same attributes as the input Tensor by default unless otherwise specified:

  >>> x = torch.randn(3, dtype=torch.float64)
  >>> torch.zeros_like(x)
  tensor([ 0., 0., 0.], dtype=torch.float64)
  >>> torch.zeros_like(x, dtype=torch.int)
  tensor([ 0, 0, 0], dtype=torch.int32)

- tensor.new_* can also create Tensors with the same attributes as tensor, but it always takes in a shape argument:

  >>> x = torch.randn(3, dtype=torch.float64)
  >>> x.new_ones(2)
  tensor([ 1., 1.], dtype=torch.float64)
  >>> x.new_ones(4, dtype=torch.int)
  tensor([ 1, 1, 1, 1], dtype=torch.int32)

To specify the desired shape, you can either use a tuple (e.g., torch.zeros((2, 3))) or variable arguments (e.g., torch.zeros(2, 3)) in most cases.
Name | Returned Tensor | torch.*_like variant | tensor.new_* variant |
---|---|---|---|
torch.empty | uninitialized memory | ✔ | ✔ |
torch.zeros | all zeros | ✔ | ✔ |
torch.ones | all ones | ✔ | ✔ |
torch.full | filled with a given value | ✔ | ✔ |
torch.rand | i.i.d. continuous Uniform[0, 1) | ✔ | |
torch.randn | i.i.d. Normal(0, 1) | ✔ | |
torch.randint | i.i.d. discrete Uniform in given range | ✔ | |
torch.randperm | random permutation of {0, 1, ..., n - 1} | | |
torch.tensor | copied from existing data (list, NumPy ndarray, etc.) | | ✔ |
torch.from_numpy* | from NumPy ndarray (sharing storage without copying) | | |
torch.arange, torch.range, and torch.linspace | uniformly spaced values in a given range | | |
torch.logspace | logarithmically spaced values in a given range | | |
torch.eye | identity matrix | | |

*: torch.from_numpy only takes in a NumPy ndarray as its input argument.
Writing device-agnostic code
Previous versions of PyTorch made it difficult to write code that was device agnostic (i.e. that could run on both CUDA-enabled and CPU-only machines without modification).
PyTorch 0.4.0 makes this easier in two ways:
- The device attribute of a Tensor gives the torch.device for all Tensors (get_device only works for CUDA tensors)
- The to method of Tensors and Modules can be used to easily move objects to different devices (instead of having to call cpu() or cuda() based on the context)
We recommend the following pattern:
# at beginning of the script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
...
# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)
Tensors
Full support for Advanced indexing
PyTorch now has full support for advanced indexing, following numpy's advanced indexing rules. The following examples are now possible:
a = torch.rand(10, 10, 10, 10)
# the indexing elements can have other shapes than 1
b = a[[[3, 2]], :, [[1, 3]]]
# broadcasting also supported in the indices, as well as lists,
# negative indices, slices, ellipses, numbers
c = a[[1, -2], 2:4, :, [1]]
# can also support tensors as indices
index = torch.tensor([2, 4])
d = a[index]
# and the indices can be on the GPU
# or CPU
e = a[index.cuda()]
f = a.cuda()[index]
mask = torch.rand(10) > 0.5
# we can now index with a mask that has fewer
# dimensions than the indexing tensor
c = a[mask, :5]
Fast Fourier Transform
- Add new FFT methods #5856
- Add torch.stft (short time Fourier transform) and hann/hamming/bartlett window functions. #4095
- Support arbitrary number of batch dimensions in *FFT #6528
New and updated Torch operators
- Added torch.log2 and torch.log10 #6272
- Added torch.isnan #5273
- Add torch.reshape, which is similar to numpy.reshape. It is roughly equivalent to tensor.contiguous().view(), but avoids copying in certain cases #5575
- Add CPU implementation of torch.unique, which outputs the unique elements of a Tensor #5503
- Add torch.det, torch.logdet and torch.slogdet, for computing the (log-)determinant of square 2D tensors. For negative determinants, torch.logdet returns nan, while torch.slogdet returns the sign of the log-determinant and the log of the absolute value of the determinant. #3816 and #5393
- Add nn.functional.gumbel_softmax, which lets you use the reparametrization trick for discrete variables #3341
- Add torch.take and Tensor.put_. These functions are equivalent to numpy.take and numpy.put, and are the base for full support of advanced indexing in PyTorch #3263
- Add torch.randint, similar to numpy.random.randint #6136
- Add torch.diagonal and torch.diagflat, similar to numpy.diagonal and numpy.diagflat. They are meant as a replacement for torch.diag, which handled both the cases of constructing a diagonal tensor as well as extracting the diagonal of a matrix #5622
- Add torch.einsum, equivalent to numpy.einsum. einsum allows you to perform operations using Einstein's notation. #5503

  a = torch.arange(0, 9).reshape(3, 3)
  # the following transposes a
  b = torch.einsum('ij->ji', (a,))

- Add torch.expm1, a numerically stable exp(x)-1 for small x. #4350
- Allow users to specify individual split sizes with torch.split #3837
- Add torch.where(condition, tensor1, tensor2), which returns a tensor of elements selected from tensor1 or tensor2 based on condition (see the sketch after this list). #4259
- Add Tensor.norm(dim) for sparse tensors. #4882
- Implement torch.neg for all types. #4075
- Implement gradient calculation for torch.trtrs. #3972
- Deprecate out-of-place Tensor.resize and Tensor.resize_as. These have weird semantics and are hard to use correctly. Please use their in-place variants Tensor.resize_ and Tensor.resize_as_. #4886
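As a quick illustration of torch.where (a small sketch, not from the original notes):
import torch

a = torch.tensor([1.0, -2.0, 3.0])
b = torch.zeros(3)
c = torch.where(a > 0, a, b)   # keep positive entries of a, take b elsewhere
print(c)                       # tensor([ 1., 0., 3.])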
Rename async argument in .cuda() to non_blocking
The async keyword argument in conversion calls is now deprecated in PyTorch and has been replaced by non_blocking. This was necessary because async will become a keyword in Python 3.7.
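For instance (a small sketch that only runs meaningfully on a CUDA machine):
import torch

if torch.cuda.is_available():
    x = torch.randn(4)
    y = x.cuda(non_blocking=True)   # previously spelled x.cuda(async=True)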
Neural Networks
A new autograd container that lets you trade compute for memory
The new checkpoint container allows you to only store a subset of the outputs necessary for backpropagation. If an output is missing (to save memory), the checkpoint container will recompute the intermediate outputs from the closest checkpoint, so that memory usage can be reduced (with an increase in computation time).
Here is an example:
# input
input = torch.rand(1, 10)
# suppose we have a very deep model
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)
output = model(input)
The above model uses a lot of memory, because it needs to keep the intermediate values of every operation for backpropagation. checkpoint lets you reduce the memory requirements:
# create the input tensors and set the requires_grad=True
# NOTE: the requires_grad=True for the input is a current
# limitation of checkpointing. At least one of the
# model inputs should have requires_grad=True.
# If you don't do it, you might have empty gradients.
input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]
# define function that will define where
# we will checkpoint and store
# intermediate gradients. In this case,
# we will only store one intermediate
# gradient, in the middle of the
# model
def run_first_half(*args):
x = args[0]
for layer in layers[:500]:
x = layer(x)
return x
def run_second_half(*args):
x = args[0]
for layer in layers[500:-1]:
x = layer(x)
return x
# now uses the new checkpoint functionality
from torch.utils.checkpoint import checkpoint
x = checkpoint(run_first_half, input)
x = checkpoint(run_second_half, x)
# the last output needs to be run without checkpoint
x = layers[-1](x)
x.sum().backward() # works!
For sequential modules (which can have arbitrary blocks inside), a helper function checkpoint_sequential is provided, which takes care of the most common use cases:
input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)
from torch.utils.checkpoint import checkpoint_sequential
# split in two blocks
num_segments = 2
x = checkpoint_sequential(model, num_segments, input)
x.sum().backward() # works!
bottleneck - a tool to identify hotspots in your code
torch.utils.bottleneck
(#5216, #6425) is a tool that can be used as an initial step for
debugging bottlenecks in your program. It summarizes runs of your script with
the Python profiler and PyTorch’s autograd profiler. See the bottleneck docs for more details.
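A typical invocation looks like the following (my_script.py is a placeholder for your own script; see the bottleneck docs for the exact options):
python -m torch.utils.bottleneck my_script.py [script args]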
reduce=False Losses
As of this release, all of our loss functions support the reduce keyword. Specifying reduce=False gives a Tensor per unit of loss instead of a single reduced loss. #4924, #5346, #5646, #4231, #4705, #5680
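For example (a minimal sketch, not from the original notes; later PyTorch versions spell this reduction='none'):
import torch
import torch.nn as nn

loss_fn = nn.MSELoss(reduce=False)             # one loss value per element
input = torch.randn(3, 2, requires_grad=True)
target = torch.randn(3, 2)
print(loss_fn(input, target).shape)            # torch.Size([3, 2])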
New modules and module improvements
- Add DistributedDataParallelCPU. This is similar to DistributedDataParallel, but with specific support for models running on the CPU (contrary to DistributedDataParallel, which targets GPU), and supports mpi, gloo and tcp backends #5919.
- Add Group Normalization (nn.GroupNorm), an alternative to batch normalization that doesn't suffer from the same issues as BatchNorm for small batch sizes (see the sketch after this list)
- Add Layer Normalization (nn.LayerNorm), an alternative to batch normalization often used in NLP tasks. #4922
- Add Local Response Normalization (nn.LocalResponseNorm). #4922
- MaxPool3d now supports double backwards. MaxPool3d and MaxUnpool3d now use indices consistent with the rest of the pooling layers. #5328
- All loss functions now support a reduce argument to return a batch of losses. #264
- Add util to clip gradient value in torch.nn.utils.clip_grad and add param to He initialization scheme in torch.nn.init. #6173
- Renamed torch.nn.init.* methods to have an underscore at the end, as they operate in-place, and deprecate the old versions #6093
- Added support for returning dictionaries in DataParallel #6113
- Added support for N-D tensors in torch.nn.Bilinear #5764
- Add Embedding.from_pretrained factory. This allows initializing an Embedding layer with an existing tensor, bypassing the initial random initialization of its weights.
- You can now slice nn.Sequential, nn.ModuleList, and nn.ParameterList #4491
- Registered nn.Module integer parameters and buffers are now immune to module.float(), module.double() and module.half() calls. #3820
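A small usage sketch of the new normalization layers (illustrative only; the shapes are arbitrary):
import torch
import torch.nn as nn

x = torch.randn(2, 6, 4, 4)                        # (N, C, H, W)
gn = nn.GroupNorm(num_groups=3, num_channels=6)    # normalizes over groups of 2 channels
ln = nn.LayerNorm(x.shape[1:])                     # normalizes over (C, H, W) per sample
print(gn(x).shape, ln(x).shape)                    # both preserve the input shape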
torch.distributions
torch.distributions has expanded to include 24 basic probability distributions: Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2, Dirichlet, Exponential, FisherSnedecor, Gamma, Geometric, Gumbel, Laplace, LogNormal, Multinomial, MultivariateNormal, Normal, OneHotCategorical, Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical, StudentT, and Uniform.
The Distribution interface has expanded to include many methods including .cdf(), .icdf(), .mean(), .variance(), .entropy(), and .perplexity(). Distributions now split tensor dimensions into sample_shape + batch_shape + event_shape. Most continuous distributions now also implement a differentiable .rsample() method to compute pathwise derivatives, aka the reparameterization trick (check .has_rsample for availability):
>>> loc = torch.tensor(0., requires_grad=True)
>>> scale = torch.tensor(1., requires_grad=True)
>>> samples = Normal(loc, scale).rsample(sample_shape=(1000,))
>>> loss = (samples - 0.5).pow(4).mean() # average over 1000 monte carlo samples
>>> grad(loss, [loc, scale])
(tensor(-7.5092), tensor(15.2704))
Most discrete distributions implement an .enumerate_support() method to make it easy to sum over all possible sample values (check .has_enumerate_support for availability).
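For instance, with a hypothetical 3-way Categorical (not from the original notes):
>>> from torch.distributions import Categorical
>>> d = Categorical(torch.tensor([0.25, 0.25, 0.5]))
>>> d.enumerate_support()
tensor([0, 1, 2])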
kl_divergence is defined for many pairs of distributions, e.g.
>>> x = torch.tensor(1.0, requires_grad=True)
>>> kl = kl_divergence(Uniform(-x, x), Normal(0., 1.))
>>> grad(kl, [x])[0]
tensor(-0.6667)
Distribution Transforms
New distributions can be created by combining TransformedDistribution with any number of Transform objects from the torch.distributions.transforms library, including: ExpTransform, PowerTransform, SigmoidTransform, AbsTransform, AffineTransform, SoftmaxTransform, StickBreakingTransform, LowerCholeskyTransform, and their inverses via the .inv property.
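For example, a log-normal distribution can be assembled from a Normal base and an ExpTransform (an illustrative sketch, not part of the original notes):
>>> from torch.distributions import Normal, TransformedDistribution
>>> from torch.distributions.transforms import ExpTransform
>>> log_normal = TransformedDistribution(Normal(0., 1.), [ExpTransform()])
>>> log_normal.sample().item() > 0   # samples are always positive
True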
Distribution Constraints
Distributions provide metadata about the constraints of their .support and about their arguments (.arg_constraints). These Constraint objects are registered with transforms using transform_to() and biject_to(). Together, constraints and transforms make it easy to specify new distributions in a generic way:
>>> scale = torch.tensor(1., requires_grad=True)
>>> p = Normal(0., scale)
>>> assert p.arg_constraints['scale'] == constraints.positive
>>> prior = TransformedDistribution(Normal(0., 1.),
... transform_to(constraints.positive))
Constraints in the torch.distributions.constraints library include: boolean, greater_than(lower_bound), integer_interval(lower_bound, upper_bound), interval(lower_bound, upper_bound), lower_cholesky, lower_triangular, nonnegative_integer, positive, positive_definite, positive_integer, real, real_vector, simplex, and unit_interval.
Distributed
Helper utility for launching Distributed Training jobs
We have added a utility function to help launch jobs in a distributed setup. In order to launch a script that leverages DistributedDataParallel on either a single node or multiple nodes, we can make use of torch.distributed.launch as follows:
python -m torch.distributed.launch my_script.py --arg1 --arg2 --arg3
The script simplifies day-to-day usability of the distributed package.
You can read about its usage here: http://pytorch.org/docs/stable/distributed.html#launch-utility
A new distributed backend based on NCCL 2.0
PyTorch now has a new distributed backend, which leverages NCCL 2.0 for maximum speed. It also provides new APIs for collective operations on multiple GPUs. You can enable the new backend via
torch.distributed.init_process_group("nccl")
Other distributed improvements
- Coalesce many small broadcasts to improve performance #4978
- Add mixed-precision support for distributed training #4891
- Release NCCL distributed backend. Previously it was marked as experimental. #4921
- Enable Infiniband support for Gloo data channel with automatic IB device detection #4795
C++ extensions
Previously, the official way of writing extensions using C or CUDA for custom modules was through the cffi extension. The drawback of this method was that it required a separate step for compiling the CUDA kernels, which could be a bit messy.
PyTorch now provides a better system for writing your own C++ / CUDA extensions. Example implementations using this new extension support can be found in the pytorch/cpp_extensions repo.
We provide two compilation modes:
- ahead-of-time compilation: you write a setup.py script using the new CppExtension or CUDAExtension, which is an extension of the setuptools.Extension module;
- just-in-time compilation: you pass the list of C++ / CUDA files that you want to compile to torch.utils.cpp_extension.load, and it will compile on the fly and cache the libraries for you. Here is an example illustrating how easy it is to implement an extension:
In C++
// my_implementation.cpp
#include <torch/torch.h>
#include <unordered_set>
// can use templates as well. But let's keep it
// simple
using scalar_t = float;
at::Tensor unique_float(at::Tensor input_) {
// only works for floats
AT_ASSERT(input_.type().scalarType() == at::ScalarType::Float, "input must be a float tensor");
// and CPU tensors
AT_ASSERT(!input_.type().is_cuda(), "input must be a CPU tensor");
// make the input contiguous, to simplify the implementation
at::Tensor input = input_.contiguous();
// get the pointer that holds the data
scalar_t* input_data = input.data<scalar_t>();
// let's use a function from the std library to implement
// the unique function
std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
// create the output tensor, with size set.size()
at::Tensor output = input.type().tensor({static_cast<int64_t>(set.size())});
scalar_t* output_data = output.data<scalar_t>();
// copy the content of the set to the output tensor
std::copy(set.begin(), set.end(), output_data);
return output;
}
// this defines the functions exposed to Python
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("unique_float", &unique_float, "Unique for float tensors");
}
And then in Python
import torch
from torch.utils.cpp_extension import load as load_ext
# pass the source files, they will be compiled on the fly
# and will return a python module
_C = load_ext('my_unique_lib', sources=['my_implementation.cpp'])
# now can use the functions implemented in C++
unique = _C.unique_float
a = torch.tensor([1.0, 2.0, 1.0])
print(unique(a))
# tensor([ 2., 1.])
Windows support
PyTorch now officially supports Windows. We provide pre-compiled Conda binaries and pip wheels for Python 3.5 and 3.6.
PyTorch on Windows doesn't support distributed training and might be a tad slower than Linux / OSX because Visual Studio supports an older version of OpenMP.
As always, you can use the commands at http://pytorch.org to install PyTorch on Windows. We have an FAQ that answers most questions you might have around Windows here: http://pytorch.org/docs/stable/notes/windows.html
ONNX Improvements
New ONNX operators
- Support export of torch.max(input, dim) and torch.min(input, dim) #6220
- Add symbolic for ReLU to support exporting to ONNX #5759
- Add sum, prod, sqrt and improve log_softmax #4579
- Add ONNX support for InstanceNorm #4626
- Add ONNX symbolic for Elu #3453
- Add ONNX symbolic for UpsamplingNearest2d #3450
Improvements
- Print source location when ONNX export fails for a node #5652
- Export onnx protobuf bindings to python #6651
- Support output_padding in ConvTranspose #4583
Better RNN support
PyTorch can now export a subset of RNNs to ONNX #4409
- Add Elman RNN export to ONNX #4613
- Support batch-first in ONNX export of padded sequences #5360
- Bidirectional Elman RNN export to ONNX #5120
- Handle sequence lengths correctly when exporting RNNs to ONNX #4695
- Support GRU export to ONNX #4390
Bugfixes
- Fix a bug in ONNX symbolic of 3d average pooling #6101
- Fix onnx export of replication/reflection pad #4263
Miscellaneous improvements
- Implement __dir__ for Tensors, so that editors can automatically auto-complete and query for the possible fields in Tensors
- Add numpy() and from_numpy() to HalfTensor
- Enable TensorDataset to have any number of input tensors.
- Add padding_value to torch.nn.utils.rnn.pad_sequence (see the sketch at the end of this list)
- Add total_length option to pack_padded_sequence, which is useful when using DataParallel, as we can ensure that we have sequences of the same length.
- Improve numerical precision of torch.arange, making it consistent with numpy.arange
- torch.load() and torch.save() support arbitrary file-like objects
- torch.nn.functional.grid_sample now supports 2D (spatial) and 3D (volumetric) inputs
- Set the Python random seed in DataLoader workers, in order to improve experiment reproducibility
- Add __delitem__ to nn.Sequential. Now one can delete arbitrary elements of a nn.Sequential.
For example:
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
del model[1] # deletes nn.ReLU
- ReduceLROnPlateau is now serializable #5300
- Add option to flush denormal numbers on CPU. #5294
- PyTorch now exposes the gradients of conv1d, conv2d and conv3d with respect to the input and the weights #5408
- Add support for calling pack_padded_sequence with either a list or a Tensor #5133
- Support negative indexing for padding_idx in nn.Embedding #4496
- Implement backward pass for pack_padded_sequence #4512
- Add nn.utils.rnn.pad_sequence and nn.utils.rnn.pack_sequence to pad lists of variable length Tensors with 0 and to pack a list of variable length Tensors.
- Add torch.cuda.memory_cached, torch.cuda.max_memory_cached, torch.cuda.memory_allocated, and torch.cuda.max_memory_allocated methods for checking CUDA memory usage #4511
- Allow viewing on noncontiguous tensors if the new view size is compatible with the tensor's original size and stride. #4062
- NLLLoss and CrossEntropyLoss now support more than 2 dimensions. #4654
- Add an option to not show model_zoo download progress bar #4135
- You can now assign modules to indices of nn.Sequential. #4931
- You can create tensors with a numpy np.longlong array #4367
- Change the autograd execution order to use good heuristics. This greatly improves memory usage for large models. #4746
- Add AMSgrad mode to Adam and SparseAdam optimizers. #4034
- Better torch.autograd.profiler support for CUDA profiling using the cudaEvent API. #3734
- torch.set_num_threads also sets the respective MKL option so you won't need to use an environment variable to control it. #4949
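As referenced in the list above, a short illustration of pad_sequence with an explicit padding value (not from the original notes):
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.ones(4), torch.ones(2), torch.ones(3)]
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
print(padded.shape)   # torch.Size([3, 4]); shorter sequences are padded with 0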
Performance improvements
- Speed up CPU nn.EmbeddingBag, making training overall 30% faster #5433
- Move nn.MarginRankingLoss, nn.CosineEmbeddingLoss, nn.HingeEmbeddingLoss, and nn.TripletMarginLoss from Python to our ATen backend, resulting in some cases in up to a 3x performance gain. #5346, #5646, #5080, #5680
- Implement pin_memory() as a NativeFunction #4094
- Save self.numel() for backward computation instead of self to save memory #5747
- Rearrange dimensions for pointwise operations for up to 10x better performance in one case. #4174
- Vectorize normal_ for a 5-6x speed up in a small case #4312
- Allow usage of GPU Direct within PyTorch for the Broadcast operation #4183
- Speed up nn.Linear for the 3D input case #5279
- Speed up Conv3D on the CPU by parallelizing vol2col and col2vol #4824
- Add AVX2 implementation for sigmoid function, showing around 10x speedup #5010
- Use fast integer division algorithm to avoid division ops inside kernels. #5054
- Improve occupancy for CUDA random number generation #5710
- Add optimization to norm for common norms #5722
- Add a fast fused GLU backward #5782
- Optimize unique sorting by using std::vector+sort instead of std::set, giving up to 5x speedup. #5913
- Speed up sum over a dimension #6026
- Enable MKLDNN convolution forward and backward. #6062
- Parallelize non-contiguous point-wise operations with OpenMP #2764
- Add cudnn Tensor Core ops to RNNs for Volta #3409
- Vectorize exp, log, sin, cos #6078
- Reuse intermediate results over multiple backwards grad_inputs #3526
Distributed
- DistributedDataParallel: 10% of NCCL backend perf improvements with mixed-precision support #5064
- Slightly improve DistributedDataParallel (single-GPU binding) multi-process distributed training performance #4870
Bug fixes
torch operators
- Improve torch.digamma precision near poles #6517
- Fix incorrect behavior of Tensor.random_ on negative inputs #6463
- Fix undefined behavior in backward pass for tensor.permute(dims) with negative dims #5945
- Fix integer overflow in torch.remainder operator (it would break with a divisor above 2**48) #5906
- Fix memory leak in torch.bmm #5744
- Make dimension checker of scatter_add_ consistent with scatter_'s #5659
- Fix CPU torch.multinomial with noncontiguous probability tensor input (previously, it would overwrite input data) #5093
- Fix CUDA torch.multinomial using incorrect strides and being able to select zero-probability events. #5774, #5238
- Support empty index tensor for index_select #3429
- Support empty indices tensor in CUDA Tensor.put_ #4486
- Improve stability of torch.cat with empty tensors #3602, #5971, #5819
- Fix torch.fft in the case where any of the input dimensions is not aligned #6118
- Improve the CUDA btrifact error message #5644
- Return zeros for eigenvector tensor when not requested in torch.symeig #3411
- Fix torch.btrifact on tensors. #4318
- Fix torch.pstrf on tensors. #4883
- Fix memory leak in torch.median #6889
- Fix SVD backward on non-square matrices when some=False #6870
core
- Detect re-initialization of _C shared library that would often result in segfaults on exit #6232
- Fix indexing with all zero ByteTensors #3926
- Only allow dense floating-point types as the default tensor type. #5674
- Initialize CUDA before setting CUDA tensor types as default to prevent crash #4788
- Fix a bug where from_dlpack fails if CUDA is not initialized. #4182
- Fix crash in creating a CUDA tensor with a numpy array #5850
- Fix broken sharing of empty tensor in multiprocessing on some OSes #6229
autograd
- Restore allow_unused functionality: throw error when differentiated input is unused or unreachable. #6553
- Fix output_nr not being incremented correctly. This caused crashes in the backward pass of operations that don't require grad on some inputs. #4812
- Fix nvprof parsing in the torch.autograd.profiler #5840
nn layers
- Support only specifying size in certain dimension for adaptive pooling #3127
- Fix reflection padding boundary checks to not cause invalid memory access #6438
- Improve error messages for NLLLoss. #5299, #6072
- Fix kl_div backward on CUDA. Previously it would not respect gradOutput when computing gradInput. #5814
- Fix incorrect bias size assert for Linear #5992
- Fix incorrect nn.functional.convNd and nn.functional.conv_transposeNd error message #5701
- Check that shape for input and target matches instead of number of elements for some loss functions #5085
- Fix torch.diag backward returning square grad with non-square input #4538
- Fix convolution type mismatch error message #5815
- Add align_corners option to linearly interpolating upsampling and make the default upsampling behavior more consistent with other frameworks #5927
- Prevent numerical issues with poisson_nll_loss when log_input=False #3336
CUDA
- Ensure convolution weights are contiguous to fix CUDA ConvTranspose double backward #4543
- Fix CUDA double backwards #4460
sparse
- Fix embedding with sparse=True #4686
- Fix sparse embedding backward when input contains only padding_idx #6211
- Handle copying empty sparse tensors to/from CPU, GPU. #5361
dataloader
- Add argument checks to the torch.utils.data.Sampler classes, fixing a bug where DataLoader tries to load the entire dataset on non-integer batch_size. #6249
- Set dataloader.batch_size = None when batch_sampler is given, fixing a bug where DataLoader would report batch_size as 1. #6108
- Improve signal handling in DataLoader #4643
- Ignore FileNotFoundError when shutting down #5380
- Make preprocessing deterministic #4640
optim
- Cast tensors when loading optimizer state dicts to improve usability #3658
- List model parameters in deterministic order to improve stability of load_state_dict() #6031
- Add parameter range checks for all optimizers #6000
- Fix AMSGrad mode for SparseAdam #4314