v1.9.0
Release date: 2021-06-16 00:06:52
Latest pytorch/pytorch release: v2.4.1 (2024-09-05 03:59:29)
PyTorch 1.9 Release Notes
- Highlights
- Backwards Incompatible Changes
- Deprecations
- New Features
- Improvements
- Bug Fixes
- Performance
- Documentation
Highlights
We are excited to announce the release of PyTorch 1.9. The release is composed of more than 3,400 commits since 1.8, made by 398 contributors. Highlights include:
- Major improvements to support scientific computing, including torch.linalg, torch.special, and Complex Autograd
- Major improvements in on-device binary size with Mobile Interpreter
- Native support for elastic, fault-tolerant training through the upstreaming of TorchElastic into PyTorch Core
- Major updates to the PyTorch RPC framework to support large scale distributed training with GPU support
- New APIs to optimize performance and packaging for model inference deployment
- Support for Distributed training, GPU utilization and SM efficiency in the PyTorch Profiler
We’d like to thank the community for their support and work on this latest release. We’d especially like to thank Quansight and Microsoft for their contributions.
You can find more details on all the highlighted features in the PyTorch 1.9 Release blogpost.
Backwards Incompatible Changes
Python API
- torch.divide with rounding_mode='floor' now returns infinity when a non-zero number is divided by zero (#56893). This fixes the rounding_mode='floor' behavior so that it returns the same non-finite values as the other rounding modes when there is a division by zero. Previously it always produced NaN, but in IEEE floating point arithmetic a non-zero number divided by zero should return +/- infinity. Note that this does not affect torch.floor_divide or the floor division operator, which currently use rounding_mode='trunc' (and are deprecated for that reason).
1.8.1:
>>> a = torch.tensor([-1.0, 0.0, 1.0])
>>> b = torch.tensor([0.0])
>>> torch.divide(a, b, rounding_mode='floor')
tensor([nan, nan, nan])
1.9.0:
>>> a = torch.tensor([-1.0, 0.0, 1.0])
>>> b = torch.tensor([0.0])
>>> torch.divide(a, b, rounding_mode='floor')
tensor([-inf, nan, inf])
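To illustrate the IEEE 754 behavior described above, here is a pure-Python sketch of torch.divide's rounding modes applied to scalars. It is not PyTorch's implementation. The helper names ieee_div and divide are hypothetical. Python's own float / raises ZeroDivisionError, so the sketch handles the zero-divisor case by hand.

```python
import math
from typing import Optional


def ieee_div(a: float, b: float) -> float:
    """True division that follows IEEE 754 when the divisor is zero.

    Python's float / raises ZeroDivisionError, so that case is handled
    explicitly: 0/0 (or nan/0) is nan, and x/0 is +/-inf, with the sign
    taken from the signs of x and of the (possibly negative) zero.
    """
    if b == 0.0:
        if a == 0.0 or math.isnan(a):
            return math.nan
        sign = math.copysign(1.0, a) * math.copysign(1.0, b)
        return math.copysign(math.inf, sign)
    return a / b


def divide(a: float, b: float, rounding_mode: Optional[str] = None) -> float:
    """Hypothetical scalar analogue of torch.divide's rounding modes."""
    q = ieee_div(a, b)
    if rounding_mode is None:
        return q  # no rounding: plain true division
    if math.isinf(q) or math.isnan(q):
        return q  # non-finite results pass through unchanged in every mode
    if rounding_mode == "floor":
        return float(math.floor(q))  # round towards -inf
    if rounding_mode == "trunc":
        return float(math.trunc(q))  # round towards zero
    raise ValueError(f"unknown rounding_mode: {rounding_mode!r}")


# Mirrors the 1.9.0 example: [-1.0, 0.0, 1.0] / 0.0 with rounding_mode='floor'
print([divide(x, 0.0, "floor") for x in (-1.0, 0.0, 1.0)])  # [-inf, nan, inf]
```

The early return for non-finite quotients reproduces the point of the fix: every rounding mode now agrees on inf and nan.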
- Legacy tensor constructors and Tensor.new no longer support passing both Tensor and device as inputs (#58108). This fixes a bug in which 1-element integer tensors were misinterpreted as specifying the tensor size, yielding an uninitialized tensor. As noted in the error message, use the new-style torch.tensor(...) or torch.as_tensor(...) to copy or alias an existing tensor. If you want to create an uninitialized tensor, use torch.empty(...).
1.8.1:
>>> a = torch.tensor([1])
>>> torch.LongTensor(a, device='cpu') # uninitialized
tensor([7022349217739848992])
>>> a.new(a, device='cpu')
tensor([4294967295]) # uninitialized
1.9.0:
>>> a = torch.tensor([1])
>>> torch.LongTensor(a, device='cpu')
RuntimeError: Legacy tensor constructor of the form torch.Tensor(tensor, device=device) is
not supported. Use torch.tensor(...) or torch.as_tensor(...) instead.
>>> a.new(a, device='cpu')
RuntimeError: Legacy tensor new of the form tensor.new(tensor, device=device) is not
supported. Use torch.as_tensor(...) instead.
- torch.divide with rounding_mode='true' is replaced with rounding_mode=None (#51988). The undocumented rounding_mode='true' option of torch.divide has been removed. Pass rounding_mode=None instead to indicate that no rounding should take place, which is equivalent to omitting the argument entirely.
1.8.1:
>>> a, b = torch.full((2,), 4.2), torch.full((2,), 2)
>>> torch.divide(a, b, rounding_mode='true')
tensor([2.1000, 2.1000])
1.9.0:
>>> a, b = torch.full((2,), 4.2), torch.full((2,), 2)
>>> torch.divide(a, b, rounding_mode=None) # equivalent to torch.divide(a, b, rounding_mode='true') from the prior release
tensor([2.1000, 2.1000])
- import torch.tensor as tensor is no longer supported (#53424). Instead, use from torch import tensor.
1.8.1:
>>> import torch.tensor as tensor
>>> torch.tensor(1.)
tensor(1.)
1.9.0:
>>> import torch.tensor as tensor
ModuleNotFoundError: No module named 'torch.tensor'
>>> from torch import tensor
>>> tensor(1.)
tensor(1.)
- Binary release: numpy is no longer a required dependency. If you require numpy (and don't already have it installed), you will need to install it separately.
Autograd
- torch.autograd.gradcheck.get_numerical_jacobian and torch.autograd.gradcheck.get_analytical_jacobian no longer support functions that return complex-valued output. For any input function, they also no longer accept a grad_out value other than 1 (#55692). This change is part of a refactor of gradcheck's internals. Note that gradcheck itself still supports functions with complex output; the new restriction only applies to calls to these two internal helper functions. As a workaround, you can wrap your function so that it returns either the real or the imaginary component of its output before calling these helpers. Note also that these helper functions are being deprecated in this release.
1.8.1:
get_numerical_jacobian(torch.complex, (a, b), grad_out=2.0)
1.9.0:
def wrapped(fn):
    # Return only the real component so the helper sees a real-valued output
    def wrapper(*inputs):
        return torch.real(fn(*inputs))
    return wrapper
get_numerical_jacobian(wrapped(torch.complex), (a, b), grad_out=1.0)
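To make the workaround concrete, the sketch below computes a central-difference numerical Jacobian in plain Python. It then applies that Jacobian to a real-part wrapper around a complex-valued function, mirroring wrapped(torch.complex). It only illustrates the idea behind a numerical Jacobian. numerical_jacobian, make_complex and real_part are hypothetical helpers, not the internals of torch.autograd.gradcheck.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def numerical_jacobian(fn: Callable[[Vector], Vector],
                       inputs: Sequence[float],
                       eps: float = 1e-6) -> List[Vector]:
    """Central differences: jac[i][j] ~= d fn(x)[i] / d x[j]."""
    n_out = len(fn(list(inputs)))
    jac = [[0.0] * len(inputs) for _ in range(n_out)]
    for j in range(len(inputs)):
        plus, minus = list(inputs), list(inputs)
        plus[j] += eps
        minus[j] -= eps
        f_plus, f_minus = fn(plus), fn(minus)
        for i in range(n_out):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return jac


def make_complex(xs: Vector) -> List[complex]:
    """Scalar stand-in for torch.complex(a, b): returns [a + b*1j]."""
    a, b = xs
    return [complex(a, b)]


def real_part(fn: Callable[[Vector], List[complex]]) -> Callable[[Vector], Vector]:
    """Wrap a complex-valued function so that it returns only real components."""
    def wrapper(xs: Vector) -> Vector:
        return [z.real for z in fn(xs)]
    return wrapper


jac = numerical_jacobian(real_part(make_complex), [1.0, 2.0])
print(jac)  # approximately [[1.0, 0.0]]
```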
- torch.autograd.gradcheck now throws GradcheckError (#55656). This change is part of a refactor of gradcheck's internals. All errors that can be silenced by raise_exception=False now raise GradcheckError (which inherits from RuntimeError). If you explicitly check that the type of the error is RuntimeError, you'll need to update your code to check for GradcheckError instead. If you instead use something like except or isinstance, no changes are necessary.
1.8.1:
# An example of a situation that will now raise GradcheckError instead of
# RuntimeError is a jacobian mismatch, which can happen, for example,
# when you forget to specify float64 for your inputs.
try:
    torch.autograd.gradcheck(torch.sin, (torch.ones(1, requires_grad=True),))
except RuntimeError as e:
    assert type(e) is RuntimeError  # explicitly check type -> NEEDS UPDATE
1.9.0:
try:
    torch.autograd.gradcheck(torch.sin, (torch.ones(1, requires_grad=True),))
except RuntimeError as e:
    # GradcheckError inherits from RuntimeError so you can still catch this
    # with RuntimeError (No change necessary!)
    # BUT, if you explicitly check type...
    assert type(e) is torch.autograd.GradcheckError
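Only exact type checks need updating because of ordinary Python exception subclassing. The sketch below uses a hypothetical stand-in class named GradcheckError that, like the real one, derives from RuntimeError. It does not import torch.

```python
class GradcheckError(RuntimeError):
    """Hypothetical stand-in: like torch's GradcheckError, a RuntimeError subclass."""


def failing_check() -> None:
    raise GradcheckError("Jacobian mismatch for output 0 with respect to input 0")


def caught_by_except_runtimeerror() -> bool:
    try:
        failing_check()
    except RuntimeError:  # still catches the subclass: no change needed
        return True
    return False


def exact_type_check_passes() -> bool:
    try:
        failing_check()
    except RuntimeError as e:
        return type(e) is RuntimeError  # exact type check: now False
    return False


print(caught_by_except_runtimeerror())  # True
print(exact_type_check_passes())        # False
```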
- Finished the deprecation cycle for in-place view error checks (#56093). In-place modification of a view will now raise an error if the view was created by a custom function, by a function that returns multiple views, or in no-grad mode. Modifying a view in-place in these situations is error-prone and has been deprecated since v1.5.0. Such in-place modifications are now forbidden. For more information on how to work around this, see the related sections of the release notes linked below:
torch.nn
- Fixed a regression in nn.MultiheadAttention so that the bias flag now applies to both the input and output projection layers (#52537). PyTorch 1.6 introduced a regression that caused the bias flag of nn.MultiheadAttention to apply only to the input projection layer. As a result, the output projection layer always included a bias parameter, even when bias=False was specified. The regression is fixed in PyTorch 1.9, and the bias flag again applies to both the input and output projection layers. This fix is BC-breaking for the bias=False case, which will now result in no bias parameter for the output projection layer.
v1.6 - v1.8.1:
>>> mha = torch.nn.MultiheadAttention(4, 2, bias=False)
>>> print(mha.out_proj.bias)
Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
Pre-1.6 & 1.9.0:
>>> mha = torch.nn.MultiheadAttention(4, 2, bias=False)
>>> print(mha.out_proj.bias)
None
- Updated nn.Module to fire full backward hooks even when no input requires grad (#56693). Before this release, full backward hooks were not fired when no input required gradients. Full backward hooks now always fire during the backward pass, regardless of whether any input requires gradients. If you use full backward hooks, be aware that they may fire more often than they did before 1.9.
1.8.1:
>>> m = torch.nn.Linear(2, 3)
>>> def hook(mod, grad_input, grad_output):
...     print('hook called:', grad_input, grad_output)
...
>>> m.register_full_backward_hook(hook)
>>> input_no_grad = torch.rand(1, 2, requires_grad=False)
>>> m(input_no_grad).sum().backward()
>>> input_grad = torch.rand(1, 2, requires_grad=True)
>>> m(input_grad).sum().backward()
hook called: (tensor([[0.1478, 0.6517]]),) (tensor([[1., 1., 1.]]),)
1.9.0:
>>> m = torch.nn.Linear(2, 3)
>>> def hook(mod, grad_input, grad_output):
...     print('hook called:', grad_input, grad_output)
...
>>> m.register_full_backward_hook(hook)
>>> input_no_grad = torch.rand(1, 2, requires_grad=False)
>>> m(input_no_grad).sum().backward()
hook called: (None,) (tensor([[1., 1., 1.]]),)
>>> input_grad = torch.rand(1, 2, requires_grad=True)
>>> m(input_grad).sum().backward()
hook called: (tensor([[0.1478, 0.6517]]),) (tensor([[1., 1., 1.]]),)
Dataloader
- Added NumPy seeding to DataLoader workers (#56488). By default, a DataLoader with num_workers > 0 now sets an independent random seed for NumPy random functions in each worker. Users no longer need to seed NumPy in worker_init_fn to make NumPy random operations deterministic and independent across DataLoader workers. This PR doesn't affect users who already seed NumPy random functions in worker_init_fn.
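import multiprocessing as mp
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
class RandomIntDataset(Dataset):  # hypothetical: each item is [index, numpy.random.randint(1, 10000)]
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        return np.array([idx, np.random.randint(1, 10000)])
ctx = mp.get_context('fork')
gen = torch.Generator().manual_seed(0)
dl = DataLoader(RandomIntDataset(), batch_size=2, num_workers=2, multiprocessing_context=ctx, generator=gen)
for epoch in range(2):
    print("=" * 10, "Epoch", epoch, "=" * 10)
    for batch in dl:
        print(batch)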
1.8.1:
# When using fork, each worker has the same random seed for NumPy random functions at each epoch.
========== Epoch 0 ==========
tensor([[   0,  340],
        [   1, 7512]])
tensor([[   2,  340],
        [   3, 7512]])
========== Epoch 1 ==========
tensor([[   0,  340],
        [   1, 7512]])
tensor([[   2,  340],
        [   3, 7512]])
1.9.0:
# Random seeds for NumPy are different across `DataLoader` workers in each epoch.
========== Epoch 0 ==========
tensor([[   0, 8715],
        [   1, 5555]])
tensor([[   2, 6379],
        [   3, 1432]])
========== Epoch 1 ==========
tensor([[   0, 1374],
        [   1,  996]])
tensor([[   2,  143],
        [   3, 3507]])
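The difference between the two outputs comes down to seeding. The stdlib-only sketch below models it with random.Random. Forked workers that inherit an identical RNG state draw identical numbers. Seeding each worker from a base seed plus its worker id instead gives distinct, reproducible streams. The base_seed + worker_id rule is how the DataLoader documentation describes seeding torch's own RNG in workers. The exact derivation PyTorch 1.9 uses for the NumPy seed is not reproduced here, and all helper names are hypothetical.

```python
import random
from typing import List


def draw(seed: int, n: int) -> List[int]:
    """n draws in [1, 10000] from an RNG seeded with `seed`."""
    rng = random.Random(seed)
    return [rng.randint(1, 10000) for _ in range(n)]


def forked_without_reseed(parent_seed: int, num_workers: int, n: int) -> List[List[int]]:
    # Every forked worker starts from an identical copy of the parent's RNG state.
    return [draw(parent_seed, n) for _ in range(num_workers)]


def seeded_per_worker(base_seed: int, num_workers: int, n: int) -> List[List[int]]:
    # Each worker derives its own seed from the base seed and its worker id.
    return [draw(base_seed + worker_id, n) for worker_id in range(num_workers)]


shared = forked_without_reseed(parent_seed=0, num_workers=2, n=4)
independent = seeded_per_worker(base_seed=1234, num_workers=2, n=4)
print(shared[0] == shared[1])            # True: duplicated "random" data
print(independent[0] == independent[1])  # False: streams differ
```

Because the per-worker seeds are derived deterministically, rerunning with the same base seed reproduces the same streams. In the DataLoader, the base seed itself is drawn from the generator, so passing a seeded generator keeps runs reproducible.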
- Added static type-checking enforcement for DataPipe (#54020).
A new attribute named type has been introduced for IterableDataset; it is set from the typing annotation in each class declaration. This attribute lets IterableDataset support type inference and lazy initialization as part of the new DataLoader architecture. However, the feature also introduces several BC-breaking restrictions.
1.8.1:
# Users can use a string to bypass an invalid type annotation without any error.
# And, incorrect type annotations attached to the `__iter__` function are ignored.
1.9.0:
from typing import Iterator
from torch.utils.data import IterableDataset
# The following scenarios will now raise different Exceptions
# 1) The type annotation is required to be valid now. Previous workarounds,
# like using a string to represent an invalid type annotation, are no longer supported.
# Raises Exception from the evaluation `eval("invalid_type", globals, locals)`
class DS(IterableDataset["invalid_type"]):
    ...
# Raises TypeError if the return type of __iter__ is not an Iterator
class DS(IterableDataset[str]):
    def __iter__(self) -> str:
        ...
# Raises TypeError if the return type of __iter__ is of the form Iterator[X],
# but the argument type X is not a subtype of the IterableDataset.type attribute.
class DS(IterableDataset[str]):
    def __iter__(self) -> Iterator[int]:
        ...
# IterableDataset now has a metaclass, which will conflict with
# existing user-defined metaclasses on IterableDatasets
class DS(IterableDataset[str], metaclass=MyMeta):
    ...
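The return-annotation rules above can be approximated with the standard typing module. This is a simplified sketch, not PyTorch's implementation, which uses its own subtype logic and also handles generic element types. The function iter_annotation_error and the example classes are hypothetical. Only plain-class element types are supported.

```python
import collections.abc
from typing import Iterator, Optional, get_args, get_origin, get_type_hints


def iter_annotation_error(cls: type, declared_type: type) -> Optional[str]:
    """Return an error message if cls.__iter__'s return annotation breaks the rules, else None."""
    ret = get_type_hints(cls.__iter__).get("return")
    if ret is None:
        return None  # unannotated __iter__: nothing to check
    if get_origin(ret) is not collections.abc.Iterator:
        return f"__iter__ must be annotated to return an Iterator, got {ret!r}"
    args = get_args(ret)
    if args and not issubclass(args[0], declared_type):
        return f"Iterator element type {args[0]!r} is not a subtype of {declared_type!r}"
    return None


class IntStream:
    def __iter__(self) -> Iterator[int]:
        yield 1


class ReturnsStr:
    def __iter__(self) -> str:
        return "not an iterator"


print(iter_annotation_error(IntStream, int))   # None
print(iter_annotation_error(IntStream, str))   # element type is not a subtype of str
print(iter_annotation_error(ReturnsStr, str))  # return annotation is not an Iterator
```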
Meta API
- Gave Tensor a non-trivial (for now) metaclass, _TensorMeta (#56147). This shouldn't be user-observable, because Tensor already inherits from a C-defined class and is therefore incompatible with other typical metaclasses. There may, however, be unanticipated interactions with other Python language features. This PR changes the metaclass of torch.Tensor, so type(type(torch.tensor([1]))) now prints <class 'torch._C._TensorMeta'> (it used to print <class 'type'>).
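The caveat about metaclasses is plain Python behavior: a class cannot combine bases whose metaclasses are unrelated. The sketch below reproduces the conflict with hypothetical LibraryMeta and UserMeta classes, which are stand-ins and not torch._C._TensorMeta. It then shows the usual fix of deriving the user metaclass from the library's. The same consideration applies to the IterableDataset metaclass mentioned in the Dataloader section.

```python
from typing import Optional


class LibraryMeta(type):
    """Hypothetical stand-in for a metaclass supplied by a framework base class."""


class LibraryBase(metaclass=LibraryMeta):
    pass


class UserMeta(type):
    """An unrelated user metaclass."""


def conflicting_definition_error() -> Optional[str]:
    try:
        class Bad(LibraryBase, metaclass=UserMeta):  # noqa: F841
            pass
    except TypeError as e:
        return str(e)
    return None


class CompatibleMeta(LibraryMeta):
    """Deriving from the library metaclass resolves the conflict."""


class Good(LibraryBase, metaclass=CompatibleMeta):
    pass


print(conflicting_definition_error())  # metaclass conflict: ...
print(type(type(Good())))              # the CompatibleMeta class
```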
C++ API
- Changed in-place resize functions to return const Tensor& (#55351).
The C++ signatures of resize_, resize_as_, resize_as_sparse_, sparse_resize_, and sparse_resize_and_clear_ now return a const Tensor& instead of a Tensor&. This may break users' TORCH_LIBRARY operators that called these functions but returned a non-const Tensor&. Ideally, users should change their operators to also consume and return const Tensor&. Casting the result of the changed function with const_cast<Tensor&> is also an option.
1.8.1:
const at::Tensor a = at::randn({2, 2});
const at::Tensor b = at::ones({1, 4}, at::kInt);
at::Tensor& out = at::resize_as_(a, b);  // success
1.9.0:
const at::Tensor a = at::randn({2, 2});
const at::Tensor b = at::ones({1, 4}, at::kInt);
at::Tensor& out = at::resize_as_(a, b);
// error: binding value of type 'const at::Tensor' to reference to type 'at::Tensor' drops 'const' qualifier
const at::Tensor& out = at::resize_as_(a, b);  // Success
- Some ATen reduction ops, as well as kron_out, now throw an error when an undefined tensor is passed as the out argument (#53218, #53640).
  - The C++ API for reduction ops like sum_out, nansum_out, prod_out, and std_var_out now requires users to allocate the result Tensor before calling these ops. The C++ API allocate_reduction_result has been changed to resize_reduction_result, which disallows allocating the result Tensor in these reduction ops.
  - The following code compiles but raises a c10::Error when executed. In the prior release it compiled and executed successfully.
at::Tensor out;  // Undefined Tensor
const at::Tensor a = at::randn({2, 2});
at::IntArrayRef dim = {1};
at::sum_out(out, a, dim);
// c10::Error: Expected a Tensor of type Variable but found an undefined Tensor for argument #4 'out'
- The C++ API utility functions expand_inplace and expand_outplace now return c10::MaybeOwned<Tensor> instead of std::tuple<Tensor> (#55065, #55245). The rationale for this change is to avoid unnecessary Tensor creation, which improves performance. Functions in ExpandUtils return c10::MaybeOwned<Tensor> because expansion may not actually be needed. In that case, they can return c10::MaybeOwned<Tensor>::borrowed(to_expand) and save work. However, this means you need to be careful: the returned c10::MaybeOwned<Tensor> must not outlive the original Tensor object that to_expand referred to! The deleted rvalue reference overloads of these functions help by preventing trivial use of a temporary resulting from a function call, but it is still possible to make a mistake.
TorchScript
- Added recursive scripting for class type module attributes (#55124).
  - This change is BC-breaking because class type module attributes are now scripted when a module instance is scripted. In previous versions, such attributes were ignored unless their class type was also marked with @torch.jit.script. The new feature attempts to script the type. If scripting fails, it falls back to the old behavior of marking the class type attribute as "failed". However, if the class definition does not have type annotations, the scripted class can differ from what users might expect (see code sample). If needed, users can explicitly disable scripting of a class type attribute by adding its name to the __jit_ignored_attributes__ class attribute of the module being scripted.
1.8.1:
class MyClass:
    def __init__(self, a):
        self.attr = a
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = MyClass(4)
sm = torch.jit.script(MyModule())
1.9.0:
class MyClass:
    def __init__(self, a):
        self.attr = a
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = MyClass(4)
# RuntimeError: Could not cast attribute 'attr' to type Tensor: Unable to cast Python instance of type <class 'int'> to C++ type 'at::Tensor'
sm = torch.jit.script(MyModule())
This error occurs because MyClass is automatically scripted, but self.attr is inferred to be a Tensor instead of an int because a is not annotated. To fix this, annotate a with the right type, int, or mark attr as an attribute that the scripting process should ignore and not process recursively:
class MyModule(torch.nn.Module):
    __jit_ignored_attributes__ = ["attr"]
    def __init__(self):
        super().__init__()
        self.attr = MyClass(4)
Quantization
- The debug argument of torch.quantization.quantize_fx.convert_fx has been renamed to is_reference (#52179).
1.8.1:
>>> import torch.quantization.quantize_fx as quantize_fx
>>> m = quantize_fx.convert_fx(m, debug=True)
(Runs successfully)
1.9.0:
>>> m = quantize_fx.convert_fx(m, is_reference=True) # Runs successfully
>>> m = quantize_fx.convert_fx(m, debug=True)
Traceback (most recent call last):
File "
- torch.cat is now quantized to torch.cat instead of torch.ops.quantized.cat (#54924). Previously, quantization produced torch.ops.quantized.cat, which dequantized its inputs and requantized them with new qparams. Quantization now produces torch.cat directly. torch.cat uses the same observer/fake_quant instance for all inputs and the output. It assumes all inputs share the same qparams and produces a quantized Tensor with those same qparams. Using torch.cat is expected to be more efficient because it does not introduce extra quant/dequant steps.
  - Version 1.8.1: torch.cat was quantized to torch.ops.quantized.cat.
  - Version 1.9: torch.cat is quantized to torch.cat (torch.cat works on both floating point and quantized Tensors).
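The efficiency argument can be seen with affine quantization arithmetic, q = round(x / scale) + zero_point. The pure-Python sketch below uses hypothetical helpers and assumes an unsigned 8-bit range. If the inputs already share one (scale, zero_point) pair, concatenating them is just concatenating their integer values. That gives the same result as dequantizing and requantizing with those qparams, so the round trip adds nothing in this case.

```python
from typing import List, Tuple

QParams = Tuple[float, int]  # (scale, zero_point)


def quantize(xs: List[float], qp: QParams, qmin: int = 0, qmax: int = 255) -> List[int]:
    scale, zero_point = qp
    return [min(qmax, max(qmin, round(x / scale) + zero_point)) for x in xs]


def dequantize(qs: List[int], qp: QParams) -> List[float]:
    scale, zero_point = qp
    return [(q - zero_point) * scale for q in qs]


def cat_shared_qparams(a_q: List[int], b_q: List[int]) -> List[int]:
    # Inputs already share qparams: concatenate the integer values directly.
    return a_q + b_q


def cat_requantize(a_q: List[int], a_qp: QParams,
                   b_q: List[int], b_qp: QParams, out_qp: QParams) -> List[int]:
    # Round-trip path: dequantize every input, then requantize with the output qparams.
    return quantize(dequantize(a_q, a_qp) + dequantize(b_q, b_qp), out_qp)


qp = (0.25, 0)
a_q = quantize([0.0, 0.5, 1.0], qp)  # [0, 2, 4]
b_q = quantize([0.25, 0.75], qp)     # [1, 3]
print(cat_shared_qparams(a_q, b_q))  # [0, 2, 4, 1, 3]
print(cat_requantize(a_q, qp, b_q, qp, qp) == cat_shared_qparams(a_q, b_q))  # True
```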
Distributed
- DistributedDataParallel: Removed support for intra-process device replication in DDP (#54454, #54825, #54826, #55212, #55253, #55353). DistributedDataParallel now errors out when users try to use it in single-process multi-device mode, where a module is replicated across more than one device within a single process. This mode was previously deprecated and is now removed. Use cases should switch to spawning one process for each device used in replication. That is the performant way to use DistributedDataParallel, and it supports a variety of newly developed features.
1.8.1:
>>> # Assume the below is run on 2 ranks in a distributed setting.
>>> rank_to_devices = { 0: [0, 1], 1: [2, 3] }
>>> # Each rank replicates model across 2 GPUs.
>>> model_ddp = torch.nn.parallel.DistributedDataParallel(
...     model,
...     device_ids=rank_to_devices[rank]
... )
>>> # No error is raised, but the warning below is produced.
UserWarning: Single-Process Multi-GPU is not the recommended mode for DDP. In this mode, each DDP instance operates on multiple devices and creates multiple module replicas within one process. The overhead of scatter/gather and GIL contention in every forward pass can slow down training. Please consider using one DDP instance per device or per module replica by explicitly setting device_ids or CUDA_VISIBLE_DEVICES.
1.9.0:
>>> # Assume the below is run on 2 ranks in a distributed setting.
>>> rank_to_devices = { 0: [0, 1], 1: [2, 3] }
>>> # Each rank replicates model across 2 GPUs.
>>> model_ddp = torch.nn.parallel.DistributedDataParallel(
...     model,
...     device_ids=rank_to_devices[rank]
... )
>>> # Single process multi-GPU mode now produces an error on initialization.
ValueError: device_ids can only be None or contain a single element.
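Migrating away from single-process multi-device mode mostly means changing how devices are assigned to processes. This is a minimal sketch that assumes one new process (rank) is launched per GPU. The hypothetical helper one_process_per_device flattens the old rank_to_devices mapping so that each rank owns exactly one device. validate_device_ids mimics the 1.9 check quoted above. Neither function is part of torch.distributed.

```python
from typing import Dict, List, Optional


def one_process_per_device(rank_to_devices: Dict[int, List[int]]) -> Dict[int, List[int]]:
    """Give every device its own rank, preserving the old device order."""
    devices = [d for rank in sorted(rank_to_devices) for d in rank_to_devices[rank]]
    return {new_rank: [device] for new_rank, device in enumerate(devices)}


def validate_device_ids(device_ids: Optional[List[int]]) -> None:
    """Mimics the constraint quoted in the 1.9 error message."""
    if device_ids is not None and len(device_ids) != 1:
        raise ValueError("device_ids can only be None or contain a single element.")


old = {0: [0, 1], 1: [2, 3]}
new = one_process_per_device(old)
print(new)  # {0: [0], 1: [1], 2: [2], 3: [3]}
for ids in new.values():
    validate_device_ids(ids)  # passes: one device per process
```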
- torch.distributed.elastic: Replaced torch.distributed.launch with torch.distributed.elastic_launch (#56037, #56214).
  - --logdir → --log_dir. Both the name of the stdout/stderr log directory argument and the log file destination changed. The file destination changed from $logdir/node_{}_local_rank_{}_stdout to $log_dir/$rank/stdout.log. If you used the --logdir argument introduced in PyTorch 1.8, you now need to use the --log_dir parameter instead.
1.8.1:
#!/bin/bash
# Assumes training script train.py exists.
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port="29500" --logdir test_logdir train.py
# Logs are written to $logdir/node_{}_local_rank_{}_stdout
1.9.0:
#!/bin/bash
# Assumes training script train.py exists.
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port="29500" --log_dir test_logdir train.py
# Logs are written to $log_dir/$rank/stdout.log
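For launch scripts that build the command line programmatically, renaming the flag is mechanical. The sketch below is a hypothetical helper, not part of PyTorch, that rewrites --logdir to --log_dir in an argument list. It handles both the separate-value form and the --logdir=value form. It rewrites every occurrence, including any that appear after the training script, so apply it only to the launcher's own arguments.

```python
from typing import List


def migrate_logdir_flag(argv: List[str]) -> List[str]:
    """Rename the 1.8-era --logdir launcher flag to the 1.9 --log_dir spelling."""
    migrated = []
    for arg in argv:
        if arg == "--logdir":
            migrated.append("--log_dir")
        elif arg.startswith("--logdir="):
            migrated.append("--log_dir=" + arg[len("--logdir="):])
        else:
            migrated.append(arg)
    return migrated


old_args = ["--nproc_per_node=2", "--logdir", "test_logdir", "train.py"]
print(migrate_logdir_flag(old_args))
# ['--nproc_per_node=2', '--log_dir', 'test_logdir', 'train.py']
```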
Deprecations
Python API
- torch.floor_divide has been deprecated in favor of torch.div(..., rounding_mode='floor') (#50281).
  - torch.floor_divide incorrectly divides and then truncates (rounds towards zero) instead of dividing and then flooring (rounds "down"). torch.floor_divide will be removed in a future PyTorch release. Use the rounding_mode argument of torch.div instead, to indicate whether you want truncation division or floor division.
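The two behaviors differ only when the operands have opposite signs. Python's own // operator floors, matching rounding_mode='floor'. Exact integer truncation reproduces what torch.floor_divide actually computed (rounding_mode='trunc'). A stdlib-only sketch with a hypothetical helper name:

```python
from typing import Tuple


def floor_and_trunc(a: int, b: int) -> Tuple[int, int]:
    """Return (floor division, truncation division) of integers a and b."""
    floor_q = a // b  # rounds towards -inf, like rounding_mode='floor'
    # Exact integer truncation: divide magnitudes, then restore the sign.
    trunc_q = abs(a) // abs(b)
    if (a < 0) != (b < 0):
        trunc_q = -trunc_q
    return floor_q, trunc_q


for a, b in [(7, 2), (-7, 2), (7, -2), (-7, -2)]:
    print(a, b, floor_and_trunc(a, b))
# 7 2 (3, 3)
# -7 2 (-4, -3)
# 7 -2 (-4, -3)
# -7 -2 (3, 3)
```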
- Older linear algebra operations have been deprecated in favor of their new linalg module counterparts. Namely:
  - torch.{cholesky, qr, symeig, chain_matmul, solve, eig, matrix_rank, lstsq} have been deprecated in favor of torch.linalg.{cholesky, qr, eigh, multi_dot, solve, eig, matrix_rank, lstsq} (#57725, #57745, #57732, #53453, #57741, #57727, #57734, #57743).
  - torch.norm has been deprecated in favor of the new linalg module norm functions: torch.linalg.vector_norm, torch.linalg.matrix_norm, and torch.linalg.norm (#57986).
  - Aliased torch.det, torch.slogdet, torch.matrix_power, torch.inverse, and torch.pinverse to their linalg module counterparts (#57821).
Autograd
- [cpp] Renamed AutoNonVariableTypeMode to AutoDispatchBelowAutograd and added a warning (#56422). AutoNonVariableTypeMode is deprecated and will be removed in the 1.10 release. For kernel implementations, please use AutoDispatchBelowAutograd instead. Check out more details on how to migrate your kernel here. If you are looking for a user-facing API to run an inference-only workload, please use c10::InferenceMode. Using AutoDispatchBelowAutograd in user code risks silently wrong results in some edge cases.
1.8.1:
{
at::AutoNonVariableTypeMode guard(true);
}
1.9.0:
{
c10::AutoDispatchBelowAutograd guard(true); // for kernel implementations
// c10::InferenceMode guard(true); --> consider inference mode if you are looking for a user-facing API
}
- Removed the logic for old-style custom autograd Function (#57357). Instantiating a custom autograd function is now deprecated and raises a warning. Users should call .apply() on the class itself rather than on an instance, because apply is not meant to be called on instances.
1.8.1:
# Instantiating custom function will raise a warning in 1.9
Func().apply
1.9.0:
# You should directly call the `apply` (classmethod) on the class
Func.apply
- Deprecated get_analytical_jacobian and get_numerical_jacobian (#54378, #54049). torch.autograd.gradcheck.get_analytical_jacobian and torch.autograd.gradcheck.get_numerical_jacobian are internal-facing functions and not part of our public API. We've refactored some PyTorch internals to work without them and will remove them in a future release. For gradient checking, please use torch.autograd.gradcheck.
C++ API
- Removed the redundant linalg_ prefix from the torch::linalg::linalg_det and torch::linalg::linalg_norm C++ APIs (#57464). C++ code that used to call torch::linalg::{linalg_det, linalg_norm} should be updated to call torch::linalg::{det, norm}.
Distributed
- torch.distributed.rpc: Added a warning message to retire the ProcessGroup RPC backend (#55616).
  - The ProcessGroup RPC backend is being deprecated, and 1.9 is the last release that will carry it. The default RPC backend is TensorPipe, which is the recommended backend to use instead of ProcessGroup.
New Features
Python API
- Added BFloat16 support for torch.{ceil, floor, frac, round, trunc, lerp, roll, diag, logaddexp, logaddexp2, nan_to_num, exp2, expm1, rsqrt, erfc, atan2, hypot} on CUDA (#57910, #57907, #57916, #57908, #58063, #57913, #57905).
- Added torch.pow() for torch.{float16, BFloat16} on CPU (#55280).
- Added torch.{index_select, argmax, argmin, min, max, amin, amax} for torch.{float16, BFloat16} (#53898, #52582, #51244, #52579).
- Added torch.dot for BFloat16 on CUDA (#57903).
- Added support for tensor inputs for the min and max arguments of torch.clamp (#52695, #56367).
- Added a new torch.special namespace similar to scipy.special (#52296).
- Added the following new operators in PyTorch similar to those in NumPy:
- Added a new keyword argument alpha to torch.index_add (#54176).
- Added torch.assert_async (#53086).
- Added a new keyword argument interpolation to torch.quantile (#49267).
- Added a correction parameter to std/var (#50903).
  - Added overloads for torch.{std, var, std_mean, var_mean} with a correction argument specifying the difference between the sample size and the number of degrees of freedom.
- Added support for integer types for torch.{logit, rad2deg, deg2rad, polygamma} (#52028, #51853, #57462).
- Added support for a stable sort algorithm on CPU via a new kwarg, stable (#51790).
- The torch.linalg module, analogous to NumPy's linalg module but with several additional functions, is stable! Added torch.linalg.{multi_dot, lstsq, vector_norm, matrix_norm, matrix_power, det, eig, eigvals, svdvals, cholesky_ex, inv_ex} (#51807, #49093, #51099, #57127, #52608, #53119, #52491, #56684, #56724, #58039).
- Added a new device=meta API (#53143).
  - “meta” is a new device, like CPU/CUDA, that doesn't allocate any memory for data. Operators that are passed meta tensor inputs perform shape inference without running the actual kernel computation. For example, torch.ones(2, device='meta') + torch.ones(1, 2, device='meta') returns a new meta tensor of size [1, 2] (performing broadcasting), without allocating memory or running an actual kernel.
  - The device=meta API is implemented for upsample_linear1d (#51917), upsample_bilinear2d and upsample_bicubic2d (#52012), upsample_nearest3d (#52065), sin (#52277), mul (#52692), pow (#53669), sub (#53679), div (#53680), copysign (#55040), atan2 (#55130), sinh (#55538), acosh (#55540), cosh (#55563), cos (#55564), replication_padding1d (#55481), replication_padding3d (#55499), replication_pad1d_backward (#55537), fractional_max_pool2d (#55581), reflection_pad1d (#55531), replication_pad2d (#55511), addmv (#55746), all unary float functions (#56082), adaptive_max_pool2d (#56317), adaptive_max_pool3d (#56320), all non-float unary operators (and rsqrt) (#56151), adaptive_max_pool2d_backward (#56799), adaptive_max_pool3d_backward (#56800), neg (#57212), max_pool2d_with_indices (#56459), trunc (#57350), floor (#57587), sign (#57588), ceil (#57589), gcd (#57624), nextafter (#57625), igamma and igammac (#57626), hypot (#57627), lcm (#57628), logaddexp and logaddexp2 (#57629), maximum and minimum (#57630), topk (#57790), max_pool2d_with_indices_backward (#57797), threshold (#57810), addmm (#57417), heaviside (#57933), elu (#57619), softplus (#57620), leaky_relu (#57621), hardsigmoid (#57622), softshrink (#57623), silu (#58050), empty_strided (#53397), and non-composite in-place operators (#54901).
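The correction argument generalizes the familiar biased/unbiased switch: the sum of squared deviations is divided by N - correction. So correction=0 gives the population variance, and correction=1 gives the Bessel-corrected sample variance. Below is a pure-Python sketch of that formula, cross-checked against the statistics module. The function name is hypothetical and this is not PyTorch's kernel.

```python
import statistics
from typing import Sequence


def var_with_correction(xs: Sequence[float], correction: int = 1) -> float:
    """Sum of squared deviations from the mean, divided by N - correction."""
    n = len(xs)
    if n - correction <= 0:
        return float("nan")  # this sketch returns nan when no degrees of freedom remain
    mean = sum(xs) / n
    return sum((x - mean) ** 2 for x in xs) / (n - correction)


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
print(var_with_correction(data, correction=0))  # 4.0 (population variance)
print(var_with_correction(data, correction=1))  # 32/7, about 4.571
print(statistics.pvariance(data), statistics.variance(data))  # same two values
```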
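The broadcasting rule alone is enough to illustrate the kind of shape inference meta tensors perform. Align the shapes from the right; each pair of sizes must then match, or one of them must be 1. The stdlib sketch below uses a hypothetical broadcast_shape helper. It computes only the output shape and never touches data, which is the point of the meta device. It is not how PyTorch implements meta kernels.

```python
from itertools import zip_longest
from typing import Tuple

Shape = Tuple[int, ...]


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """Output shape of an elementwise op on shapes a and b (NumPy-style broadcasting)."""
    result = []
    # Walk both shapes from the trailing dimension; missing dims act as size 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


# torch.ones(2, device='meta') + torch.ones(1, 2, device='meta') -> size [1, 2]
print(broadcast_shape((2,), (1, 2)))  # (1, 2)
```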
Complex Numbers
- Added complex autograd support for torch.{masked_fill, polar, cumsum, lerp, prod, rsub, unfold, symeig, index_copy} (#52483, #52488, #53240, #53689, #48125, #53702, #52999, #55085, #52203).
- Added complex support for torch.lerp (#54129) and torch.sigmoid (#55975) on CUDA.
- Added complex support for torch.index_copy, torch.take, and torch.Tensor.put_ on both CPU and CUDA (#52203, #53356).
- Added complex support to TorchScript.
  - Added logic to teach the TorchScript frontend to parse complex literals and complex lists (#52881).
  - Added TorchScript support for:
    - the complex constructor and torch.{add, mul, sub, as_tensor} (#52881).
    - cmath unary ops: cmath.{phase, log, log10, sqrt, exp, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh} (#54089).
    - cmath.{infj, nanj} (#54328).
    - cmath.{isinf, isnan, isfinite, rect} (#54541).
    - the real and imag tensor attributes (tensor.real/imag) (#54692).
  - Fixed test_variant_consistency_jit_addmm for complex types (#54917, #57129).
- Added initial operator support for sparse complex tensors (#57125).
  - Added complex support for torch.{sparse_coo_tensor, coalesce, to_dense, to_sparse, sparse_add, sspaddmm, saddmm}.
- Added torch.Tensor.{cfloat, cdouble} functions (#58137).
- Added complex support for all reductions for torch.{std, var}, returning a real-valued output tensor for complex inputs (#58066).
- Updated autograd formulas for many linear algebra operations to support complex tensors:
torch.nn
- New torch.nn modules: nn.LazyBatchNorm*d (#51862), nn.HuberLoss (#50553), nn.Mish (#58375).
- New parametrization functionality (#33344, #58142, #55456, #57784).
- nn.Conv*d: Added padding='same' mode for non-strided convolutions (#45667).
- nn.EmbeddingBag: Added padding_idx support (#49237, #56065, #56618).
- Added mish activation function (#58648).
- [memory format] Added channels last support for MaxPool2d (#56361).
- Added the option to build PyTorch with DNNL + AMD BLIS path (#54953).
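For stride 1, 'same' padding must add dilation * (kernel_size - 1) elements in total along each spatial dimension so that the output length equals the input length. The sketch below does that arithmetic in plain Python. Splitting the total as total // 2 on the left and the remainder on the right is an assumption made for illustration; the split only matters when the total is odd, e.g. for even kernel sizes. The helper names are hypothetical.

```python
from typing import Tuple


def same_padding(kernel_size: int, dilation: int = 1) -> Tuple[int, int]:
    """(left, right) padding that keeps the length unchanged for stride 1."""
    total = dilation * (kernel_size - 1)
    left = total // 2
    return left, total - left


def conv1d_output_length(length: int, kernel_size: int, padding: Tuple[int, int],
                         stride: int = 1, dilation: int = 1) -> int:
    """Standard convolution output-size formula."""
    left, right = padding
    return (length + left + right - dilation * (kernel_size - 1) - 1) // stride + 1


for k, d in [(3, 1), (4, 1), (3, 2)]:
    pad = same_padding(k, d)
    print(k, d, pad, conv1d_output_length(10, k, pad, dilation=d))  # length stays 10
```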
Profiler
- Added the skip_first parameter to the default schedule (#58025).
- Added support for trace metadata (#56575).
- Added gzip format support for chrome tracing (#56554).
- Added sequenceNr and fwdThreadId to the trace (#57182).
- Enabled Kineto in CPU builds (#53174).
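The skip_first parameter shifts the start of the profiler's wait/warmup/active cycle. The sketch below mirrors that step-to-action logic in plain Python, following our understanding of the documented semantics of torch.profiler.schedule. The function name and the string action labels stand in for the real ProfilerAction enum. repeat=0 is taken to mean that cycles continue indefinitely.

```python
def profiler_action(step: int, wait: int, warmup: int, active: int,
                    repeat: int = 0, skip_first: int = 0) -> str:
    """Action for a given step under a wait/warmup/active schedule."""
    if step < skip_first:
        return "NONE"  # initial steps are skipped entirely
    step -= skip_first
    cycle = wait + warmup + active
    if repeat > 0 and step // cycle >= repeat:
        return "NONE"  # all requested cycles are done
    pos = step % cycle
    if pos < wait:
        return "NONE"
    if pos < wait + warmup:
        return "WARMUP"
    # The last active step of each cycle also hands off the collected trace.
    return "RECORD_AND_SAVE" if pos == cycle - 1 else "RECORD"


print([profiler_action(s, wait=1, warmup=1, active=2, repeat=1, skip_first=2)
       for s in range(7)])
# ['NONE', 'NONE', 'NONE', 'WARMUP', 'RECORD', 'RECORD_AND_SAVE', 'NONE']
```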
Autograd
- Added a new inference mode in both C++ (#54403, #53343) and Python (#58045, #57480).
- Added a fast_mode argument to autograd.gradcheck (#54480).
- Added support for non-Tensor inputs and outputs to torch.utils.checkpoint functions (#52422).
Dataloader
- Implemented FilterIterDataPipe (#51783).
- Added a context manager for runtime type validation (#55936).
- Added typing enforcement for DataPipe at construct-time (#54066).
- Added typing enforcement for DataPipe at runtime (#54544).
- Implemented issubtype for DataLoader type hints (#54299).
- Added a type hint for SequentialSampler (#56374).
- Added ConcatDataPipe (#53301).
- Introduced a deterministic context to DataLoader (#53271).
- Added ZipIterDataPipe (#53554).
- Added a switch for guaranteed determinism and added an option to non_deterministic (#53532).
- Added TransformsIterDataPipe (#52604).
- Renamed Callable to MapIterDataPipe (#51879).
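The IterDataPipes listed above compose like chained iterators. As a rough illustration of that style only, here are filter, map, zip and concat stages built on plain Python iterables. These small classes are hypothetical and are not the DataPipe implementations in torch.utils.data.

```python
from typing import Callable, Generic, Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class Filter(Generic[T]):
    def __init__(self, source: Iterable[T], predicate: Callable[[T], bool]) -> None:
        self.source, self.predicate = source, predicate

    def __iter__(self) -> Iterator[T]:
        return (x for x in self.source if self.predicate(x))


class Map(Generic[T, U]):
    def __init__(self, source: Iterable[T], fn: Callable[[T], U]) -> None:
        self.source, self.fn = source, fn

    def __iter__(self) -> Iterator[U]:
        return (self.fn(x) for x in self.source)


class Zip:
    def __init__(self, *sources: Iterable) -> None:
        self.sources = sources

    def __iter__(self) -> Iterator[Tuple]:
        return zip(*self.sources)  # stops at the shortest source


class Concat:
    def __init__(self, *sources: Iterable) -> None:
        self.sources = sources

    def __iter__(self) -> Iterator:
        for source in self.sources:
            yield from source


evens = Filter(range(10), lambda x: x % 2 == 0)  # 0, 2, 4, 6, 8
squares = Map(evens, lambda x: x * x)            # 0, 4, 16, 36, 64
print(list(Zip(evens, squares)))  # [(0, 0), (2, 4), (4, 16), (6, 36), (8, 64)]
print(list(Concat(range(2), evens)))  # [0, 1, 0, 2, 4, 6, 8]
```

Each stage creates a fresh generator in __iter__, so a pipeline can be iterated more than once. That is why evens can feed both inputs of Zip.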
CUDA
- Added the following new features to CUDA Graphs:
- Added support for the 'max' reduction for torch.segment_reduce (#56704).
- Added support for the CUDA allocator to handle multiple streams seamlessly (#55860).
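A segment reduction with a 'max' reduction and a lengths argument takes the maximum over consecutive segments of the data. The stdlib sketch below shows those semantics on a flat list. segment_max is a hypothetical helper, and empty segments are deliberately not handled.

```python
from typing import List, Sequence


def segment_max(data: Sequence[float], lengths: Sequence[int]) -> List[float]:
    """Max over consecutive segments; lengths must be positive and sum to len(data)."""
    if sum(lengths) != len(data):
        raise ValueError("lengths must sum to the number of elements")
    out, start = [], 0
    for n in lengths:
        if n <= 0:
            raise ValueError("empty segments are not handled in this sketch")
        out.append(max(data[start:start + n]))
        start += n
    return out


print(segment_max([1.0, 5.0, 2.0, 3.0, 9.0, 0.0], lengths=[2, 3, 1]))  # [5.0, 9.0, 0.0]
```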
C++ API
- Added torch::nn::functional::huber_loss (#50553).
- Added learning rate schedulers to the C++ API (#52268).
- Added padding='same' mode to torch::conv{1,2,3}d (#45667).
- Added a padding_idx argument to EmbeddingBag (#49237).
- Added mish activation function (#58648) (#58940).
TorchScript
- Added reductions to NNC python bindings (#52492).
- Added Python bindings for ExternalCalls (#52905).
- Added an API to reorder multiple loops (#55568).
- Added NNC support for pow on CPU (#56308).
- Enabled horizontal fusion of all loops (#56324).
- Added an API for Buffer Compression (#55853).
- Added API to distribute loops (#53865).
- Added matmul for NNC lowering/unified dtypes (#56456).
- Added a method to compute conv without bias (#57512).
- Added support for computing conv with dynamic shapes (#57514).
- Added NNC lowerings for t/transpose/permute/expand (#57426).
- Updated external functions for mobile build (#56850).
- Added GELU to NNC (#57753).
- Implemented GELU backward (#58249).
- Added a mobile NNC backend skeleton (#56852).
- Added support for torch.type (#51904).
- Added dict() constructor (#51934).
- Added a new torch::deploy to manage multiple python interpreters in a single process to deploy PyTorch models packaged with torch.package (#51754).
- Reintroduced static dispatch (#51957).
- Added TS support for torch.any (#52360).
- Added a demo backend with compiler (#52603).
- Added MKLDNN fuser (#51600).
- Added a context manager for hiding source ranges (#53188).
- Implemented embedding_bag for SR (#52429).
- Allowed the use of AliasDb in Python (#51336).
- Added support for DictConstruct (#54438).
- Added sliceHead/sliceTail APIs with short parameter list (#55115).
- Added logic to infer argument types in TorchScript (#56832).
- Added support for custom Python classes in CUDAFuture (#56516).
- Added a concat optimization pass (#55474).
- Added initial support for PEP-585 types (#57363).
- Added logic to infer types for arguments of methods not invoked directly by MonkeyType (#57202).
- Added support for torch.jit.ignore as a context manager (#55172).
- Implemented hardswish/hardsigmoid on MKLDNN tensors (#55218).
- Added model_dump tool for model inspection (#56868).
- Added static method support for TorchBind (#51177).
- Added TS support for pow (#52374).
- Added support for default argument values to TorchBind (#51253).
- Added support for AST rewriting for submodules (#52297).
- Added optimize_for_inference API (#58193).
- Registered aten::index_out (#51742).
- Added PYTORCH_TENSOREXPR_DONT_FUSE env variable to disable fusion on specified operators (#55650).
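The following sketch applies the new optimize_for_inference API to a small convolution block and assumes a CPU build. The layer sizes are arbitrary. The tolerance is an assumption chosen to allow for fused or MKLDNN kernels.

```python
import torch

# A small conv -> batchnorm -> relu block in eval mode (sizes are arbitrary).
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).eval()

scripted = torch.jit.script(model)
# optimize_for_inference freezes the module if it is not frozen yet, then
# applies inference-only rewrites (for example folding batchnorm into the conv).
optimized = torch.jit.optimize_for_inference(scripted)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    expected = model(x)
    actual = optimized(x)
print(torch.allclose(expected, actual, atol=1e-4))
```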
torch.package
- Allowed TorchScript models to be contained in the package format (#54891, #56299, #54893, #57573, #54894, #57678).
Mobile
- Added 8x1 block sparse kernels for ARM and AArch64 (#51118, #51119, #51120).
- Made NNAPI converter handle binary ops combining NHWC+NCHW in some cases (#48812).
- Improved support for multiple inputs and outputs in NNAPI (#54697).
- Added flexible size support for NNAPI (#54701).
- Added new ops for Metal (concat, mul/sub/div, transpose, view, reshape, mean, chunk, reflection_pad2d) (#53950, #54107, #54522, #56073, #56074, #58263).
- Added python binding to use mobile cpu allocator (#52323).
- Added lightweight RandomSampler for mobile (#58201).
- Added support for:
- new ops to NNAPI converter (size, unsqueeze, cat, mean) (#52026, #48811).
- multi-dimension tensors in Metal via MPSImage (#54106).
- multiple output tensors in Metal (#56072).
- methods other than forward in optimize_for_mobile (#53314).
- ChannelsLast in TensorImageUtils on Android (#48990).
- loading “extra files” in Java/Android (#55644).
- loading “extra files” in Lite interpreter (#52635).
- querying bytecode version in Lite interpreter and bytecode models (#56948, #56948).
- exporting some older bytecode versions for Lite interpreter (#56802).
- querying available ops (#57570).
- Added SqueezeNet to PyTorch Playground (71d0b5632b).
- Added libtorch lite build (#51419).
Distributed
- torch.distributed.Store
- torch.distributed.rpc
- DistributedDataParallel
  - Adds a flag to ddp join context manager that enables throwing an error across all ranks when this flag is specified (#56755)
  - Enable static graph training in DDP (#55248, #54995)
  - Log unused parameter names in DDP when crashing due to unused parameters (#55075)
  - Introduce torch.distributed.algorithms.default_hooks.fp16_compress_wrapper wrapper that can be combined with other communication hooks (#53808)
  - Support loading a non-DP/DDP model from a DP/DDP state_dict (#53224)
  - Enhanced logging in DDP for performance metrics (#52957, #53145, #54647)
- torch.distributed
  - Support work.result API for MPI backend (#57168)
  - Support work.result for ProcessGroupGloo::AsyncWork objects (#57565)
  - Support work.get_future() API for ProcessGroupMPI and ProcessGroupGloo (#57818, #57214)
  - New torch.distributed.monitored_barrier API (Gloo-only) (#53773, #53787, #55009, #55010, #55197, #55265, #55989, #55990)
  - Allow passing options field to process group initialization APIs (#53662, #54090, #53663)
  - Enable profiling for distributed collectives (#51822, #52004, #52031, #52949, #55204, #56412, #56216, #56427)
  - Allow user to specify TORCH_DISTRIBUTED_DEBUG environment variable (#52481)
  - Added compareSet method for torch.distributed.{HashStore, FileStore} (#53803).
- Added new torch.distributed.elastic module that upstreams pytorch/elastic
  - Introduce RendezvousSettings (#56537)
  - Introduce a new from_backend static constructor for DynamicRendezvousHandler (#57150)
  - Introduce the implementation of DynamicRendezvousHandler (#57151)
  - Add support for the new error file format (#57084)
  - Introduce the delay utility function (#56533)
  - Make torchelastic launcher compatible with the caffe2.distributed.launch (#55687)
  - Introduce PeriodicTimer (#55919)
  - Introduce DynamicRendezvousHandler and RendezvousBackend (#55635)
  - Introduce C10dRendezvousBackend (#55636)
  - Introduce EtcdRendezvousBackend (#55637)
  - Added torch.distributed.elastic.launchers.api, torch.distributed.elastic.metrics, torch.distributed.events, torch.distributed.rendezvous, torch.distributed.elastic.agent modules (#55471, #53870, #53574, #53760, #53172, #54343)
  - Upstreamed timer and multiprocessing classes to torch.distributed.elastic.timer and torch.distributed.elastic.multiprocessing (#53574)
- torch.distributed.nn.RemoteModule: Enable RemoteModule to directly send GPU tensors over the wire on TensorPipe RPC backend if a device map is provided (#57288)
- torch.distributed.optim:
torch.fx
- Added torch.fx.Node.format_node() (#51737).
- Added a Transformer to normalize args/kwargs of torch.nn.functional calls into only kwargs (#51816).
- Added submodule manipulation APIs on GraphModule (#52358).
- Added Graph.eliminate_dead_code (#52658).
- Added a function to retrieve inspect.Signature instances for PyTorch operations (#53830).
- Experimental type annotation pass using Python signatures (#53831).
- Added a transformer to normalize torch namespace operations (#53832).
- Extended NormalizeArgs to work on torch namespace operations (#54236).
- Added FX optimize_for_inference for Intel CPUs (#53805, #58293).
- Added a metadata dict to Node and switched shape-prop to use that (#54926).
- Added C-level monkey patching of torch.randn to capture it during tracing (#54060).
- Added a new API replace_input_with to Node (#55887).
- Added net splitter and net minimizer utilities (#56201).
- Added PyTree support to FX through concrete_args (#55888).
- Added support for proxy-able classes (#56737).
ONNX
- Support onnxifi interface for set/get options (#52388).
- Support --onnxifi_min_ops in AOT flow (#52380).
- Redesign onnx pass to enable shape type dependent pattern conversion - cont (#51795) (#53304).
- Support inplace operations on inplace indexing (#52063) (#53306).
- Symbolic shape inference (#51481) (#53307).
- Support repeat_interleave symbolic (#52855) (#53312).
- Support primitive type input/outputs and attributes (#53550) (#54864).
- Support outer export to onnx (#53603) (#54869).
- Support hardsigmoid symbolic in opset 9 #49649 (#54193).
- Support for hann_window operator (#54587) (#56163).
- Enable tensordot symbolic function (#55654) (#56166).
- Support for prim::min (#55259) (#56168).
- Support mv op (#55470) (#56169).
- Support .item() export & NumberType to tensor conversion (#55697) (#57594).
- Support a new operator for fill_() function (#56859) (#57596).
- Support index_add_ function (#56867) (#57830).
- Support tensor.to(device) (#56857) (#57599).
- Support registering custom export for prim::PythonOp from torch.autograd.Function (#55630) (#57600).
Vulkan
- Added the hardswish and hardsigmoid activation functions (#53362).
- Added the reflection_pad2d op (#53604).
- Added an implementation of Winograd convolutions (#54639).
- Added the sigmoid activation function (#57867).
Misc
- Android packages are now published to maven central (#53568).
- Kineto is now supported on Windows (#56323).
- Added a Gloo TCP_TLS transport (#56442).
- Added the ability to collect minidumps after a crash (#59236).
Improvements
Python API
- Added nondeterministic alert for index_put_ when accumulate=False (#55827).
- Added deterministic path for torch.index_add on CUDA (#56521).
- Added deterministic path for torch.index_copy on CPU (#56900).
- Removed beta warning for use_deterministic_algorithms (#58074).
- Updated torch.Tensor.unflatten to be able to infer size value in sizes from -1 (#51955).
- Added a safe cast and copy for out= input tensor for torch.tensordot (#56286).
- Added cross-device check for out and input tensors for torch.cat (#53004).
- Modified the order of asserts to correct the error message when nan appears in torch.multinomial on CUDA (#53288).
- Converted a few more checks for unsupported device to raise NotImplementedError (#53610).
- Made shared cache thread-safe for torch.multiprocessing (#53750).
- Added support for torch.int32 indices in torch.repeat_interleave (#55102).
- Added a check to give a clear error message when a binary function is called for non-complex inputs with complex valued alpha (#54964).
- Propagate error message from torch_shm_manager when running torch.multiprocessing (#57307, #57310).
- Enabled deterministic path for index_copy_cuda with index_put (#58144).
- Added support for uppercase letters in torch.einsum (#56475).
- Added CUDA support for torch.orgqr (#51348) and torch.ormqr (#57316).
- Added support for batched as well as complex inputs for torch.geqrf on both CPU and CUDA (#56249, #56251).
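The following short example shows two of these changes, the -1 size in unflatten and uppercase subscripts in einsum. The shapes are arbitrary.

```python
import torch

t = torch.arange(24).reshape(2, 12)
# One of the new sizes may be -1; it is inferred from the dimension being split.
print(t.unflatten(1, (3, -1)).shape)  # torch.Size([2, 3, 4])

a = torch.randn(2, 3)
b = torch.randn(3, 4)
# Uppercase subscript letters are accepted alongside lowercase ones.
out = torch.einsum("iJ,Jk->ik", a, b)
print(torch.allclose(out, a @ b))  # True
```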
Complex Numbers
- Fixed torch.{linspace, logspace} to correctly infer complex type and return a complex tensor when the start and (or) end values are complex numbers, and the dtype value is None (#38875).
Autograd
- Added support for single tensor in inputs argument for .backward() (#53827).
- Added support for C++ optional arguments in autograd custom functions (#54270).
- Added autograd support to torch.orgqr (#52637), torch.segment_reduce (#56792).
- Added deterministic backward for torch.gather for dim=1 (#55573).
- Make detach return an alias even under inference mode (#59633).
torch.nn
- Add 3D depthwise separable convolution (#51027)
- Make bias in lazy modules lazy and avoid creating empty tensors (#52212).
- BFloat16: enable inference with prepacked weights (#48922).
- Enable mkldnn conv2d backward to support mkldnn tensor input (#48994).
- Add OneDNN pooling backward (#49454).
- Add 64bit indexing support for softmax (#52713).
- nn.init._calculate_fan_in_and_fan_out: Support usage with __torch_function__ (#53522).
- nn.Transformer/nn.MultiheadAttention: Add batch_first argument (#55285).
- nn.Transformer: Add layer_norm_eps arg (#54494).
- nn.AvgPool2d: Add channels_last support on CPU (#48918).
- clip_grad_norm_: Add error_if_nonfinite flag (#53843, #55169).
- Module.train: Raise nicer error when called with invalid modes (#58247).
- nn.Linear: Support 0 in_features (#56505).
- nn.EmbeddingBag: Support mix of int32 and int64 offsets/indices (#55189).
- xnnpack::linear: Handle 1D input (#54986).
- nn.Module: Add allow_duplicate flag to named_modules() (#54812).
- nn.Module: Add to_empty() function for moving to a device without copying storage (#56610).
- Make pad_sequence callable from C++ API (#57868).
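The sketch below exercises three of the torch.nn additions: batch_first, error_if_nonfinite and to_empty(). The sizes are arbitrary. Building the layer on the "meta" device relies on the device= module argument available in current releases.

```python
import math

import torch

# batch_first=True: inputs and outputs are (batch, seq, embed).
mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.randn(4, 5, 8)
out, weights = mha(x, x, x)
print(out.shape)      # torch.Size([4, 5, 8])
print(weights.shape)  # torch.Size([4, 5, 5]), averaged over heads

# error_if_nonfinite=True raises instead of clipping with a non-finite norm.
p = torch.nn.Parameter(torch.ones(2))
p.grad = torch.tensor([math.inf, 1.0])
raised = False
try:
    torch.nn.utils.clip_grad_norm_([p], max_norm=1.0, error_if_nonfinite=True)
except RuntimeError:
    raised = True
print(raised)  # True

# to_empty() moves parameters to a device without copying their storage,
# so the resulting values are uninitialized.
lin = torch.nn.Linear(2, 2, device="meta")
lin = lin.to_empty(device="cpu")
print(lin.weight.device)  # cpu
```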
Dataloader
- Added generate_state for NumPy seeding (#56797).
- Modified construct_time_validation to argument_validation (#55836).
- Added mode to LoadFilesFromDisk (#57056).
- Added the ability to override reduce_ex function of DataPipe (#52858).
- Added lambda support to MapIterDataPipe (#52856).
- Added functional way of stacking DataPipes (#52885).
C++ API
- Suppressed unsigned comparison warning (#52653).
- Fixed constexpr host warning (#52702).
- Introduced a fluent API to construct tensors from external data (#54530).
AMD
- Allow PYTORCH_ROCM_ARCH in cpp_extension (#54341).
- Added support for torch.half dtype RNNs with MIOpen (#52475).
- Added support for the new hiprtc precompiler feature (#54350).
- Improved reliability of hipfft and rocfft detection for ROCm build (#53408).
CUDA
- Improved warning message when old GPU is detected (#56621)
- Made torch.cuda.amp.GradScaler scale updates in-place for better composability with graph capture (#55562).
- Add USE_MAGMA build flag (#55994).
- Change link order for BUILD_SPLIT_CUDA option (#58437).
- Improve CUDA-11.X binary builds (#58459).
- Move CUDA async warning to suffix (#59467).
torch.fx
- Make torch.fx.map_arg require a callable (#51907).
- Generalize dict key check in torch.fx.Tracer.create_arg (#51927).
- Customize traceback for calls to symbolically-traced code (#51648).
- Allow Transformer to accept output result that is not Proxy (#52473).
- Make TracerBase._find_user_frame private (#53654).
- Improve buffer registration during GraphModule init (#53444).
- Garbage collect values in Interpreter (#54726).
- Improve placeholder matching in subgraph rewriter (#54958).
- Record stride on Node during ShapeProp pass (#55108).
- Record memory_format on Node during ShapeProp pass (#55815).
- Put tensor metadata into a NamedTuple in ShapeProp (#55930).
- Preserve node meta info in split_module (#56212).
- Make shape_prop handle targets with aggregate outputs (#56221).
- Make arg normalization a method on Node and not a pass (also augment tests to be exhaustive) (#55992).
- Allow for args to be left as args in NormalizeArgs (#55995).
- Maintain submodule references during subgraph rewriting (#55463).
- Changes in order to move PythonKey out of tree (#57427).
- Handle cases in GraphDrawer when shape, type or stride are not present (#57845).
- Handle the case when output consumes get_attr directly in split_by_tags (#57844).
- Let submodules be collected as args/kwargs in symbolic tracing (#57840).
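The following sketch reads the tensor metadata that ShapeProp records on each node. It assumes the torch.fx.passes.shape_prop module and the "tensor_meta" key used by current releases. TwoLayer is a hypothetical module.

```python
import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp


class TwoLayer(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 5)

    def forward(self, x):
        return torch.relu(self.fc(x))


gm = fx.symbolic_trace(TwoLayer())
# Run the graph once on a sample input. ShapeProp stores a TensorMetadata
# NamedTuple (shape, dtype, stride, memory_format, ...) in each node's meta dict.
ShapeProp(gm).propagate(torch.randn(2, 3))

metas = {n.name: n.meta["tensor_meta"] for n in gm.graph.nodes if "tensor_meta" in n.meta}
for name, meta in metas.items():
    print(name, tuple(meta.shape), meta.dtype, meta.stride)
```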
Profiler
- Expanded Kineto platform support (#56323).
- Added profiler fallback (#57612).
- Added CUDA event fallback (#58133).
TorchScript
- Added a flag to enable CPU fusion in benchmarks (#48612).
- Updated fusion to handle loops that have the same bounds as expressions (#55997).
- Updated normalization transformation to be in-place (#56158).
- Added check to only lower float conv2ds (#56289).
- Added more python bindings for loopnest (#56213).
- Updated fuseLoops API to return bool flag and not throw any exceptions (#56353).
- Added unroll and flatten APIs which do not require return stmt pointer (#56420).
- Updated Buf on mutation instead of creating a new one (#57513).
- Updated flatten transformation to be in-place (#56629).
- Added missing python bindings for NNC Stmts (#55570).
- Allowed backend preprocessing to take place outside of the backend interface (#51757).
- Added an error message for the case when a with-statement item is not an object (#52335).
- Enabled ModuleList non-literal indexing (#53410).
- Added recursive scripting for class type module attributes (#55124).
- Added support for mypy ignore annotation with particular rule specified (#51675).
- Added support for comparing two bool variables (#51844).
- Added MKLDNN GELU function (#53615).
- Added hardtanh(0,6) to the set of MKLDNN fusible ops for mobilenetv2 (#56203).
- Captured argument names for traced functions and modules (#51775).
- Improved has_bf16_support (#57408).
- Walk Python AST to check for unsupported attribute type annotations (#51805).
- Added out version for sum (#52225).
- Added logic to trace torch.nn.Linear as aten::linear (#51897).
- Made is_tracing scriptable (#49853).
- Added support for builtin sum (#52188).
- Fused clip_ranges and gather_ranges (#52461).
- Added support for features from to_backend for the Lite Interpreter (#52870).
- Added a filter to remove mutation (#51923).
- Added logic to functionalize ops to be included in MKLDNN group (#51924).
- Extended subgraph utils to cover merging a node following a subgraph (#52513).
- Included max pool in fusion groups (#52613).
- Registered both TupleConstruct and ListConstruct as out variants (#52684).
- Added Alias analysis to Memory Management/Planning (#50060).
- Added property binding in TorchBind (#50670).
- Registered pow out variant (#52454).
- Made torch.load() aware of import path changes (#53139).
- Added aten::to copy out variant (#52343).
- Added more variants to create_empty_from (#53333).
- Added support for parsing Ellipsis in JIT frontend (#53576).
- Added a bool is_available() method to the backend contract (#53068).
- Added parallel support for the LLVM backend (#53243) / Resubmit: Add parallel support for the LLVM backend (#54122).
- Rewrote functional.tensordot to be TorchScript-able (#53672).
- Added python bindings for missing loop transformations in LoopNest (#54355).
- Added support for list insertion for mutation removal (#54271).
- Added support for torch.bfloat16 in the fuser (#54571).
- Added some functions for manipulating MKLDNN tensors to TORCH_API (#56954).
- Merged CUDA Streams and Events (#53902).
- Added python bindings for TensorExprKernel (#54450).
- Added support for dtype-specific tensor subclasses (e.g. LongTensor) (#54817).
- Added support for tuple add operator (#52292).
- Disambiguated error message for working with not fully refined tuple types (#55745).
- Allowed unpacking tuple and assigning unpacked values to SELECT-type expressions (#55268).
- Made NoneType annotation_str emit NoneType instead of None (#54746).
- Added CUDA device synchronization support in JIT (#55469).
- Added optimize_graph_output_memory flag (#55811).
- Added support for refinement for torch.jit.Future (#56148).
- Added implicit conversion from null tensor to NoneType (#55823).
- Added aten::matmul ops to TE fuser (#54605).
- Put explicit error message on class attribute accesses (#55723).
- Added support for constant tensors in tensorexpr kernel (#56319).
- Added native support for aten::getitem (#55310).
- Added stricter check for function schemas with varargs (#56509).
- Added graceful failure handling of DataPtr extraction in CUDAFuture (#56511).
- Enabled forward/backward compatibility in TS mobile (#56079).
- Added binding for aten::clamp_min_out (#56635), aten::argmin_out (#56638), and aten::norm_out (#56636).
- Enhanced error message for Future.setErrorIfNeeded (#56631).
- Added type inference support for nn.Module methods using PDT (#57165).
- Disabled conv-add-relu fusion for cuDNN7 when model uses torch.float16 (#56579).
- Enabled conv-add-relu fusion as a part of frozen graph optimization (#56580).
- Reduced inline autodiff threshold to enable the capture of smaller fusions (#57062).
- Added static runtime support for aten::matmul (#57291).
- Added device() method to c10::Event (#57293).
- Added support for normalization of is op (#57862).
- Enabled cat without conditionals iff CPU (#58026).
- Added LowerSimpleTuples for freeze tuples (#57915).
- Added support for striding for list slicing (#49352).
- Wrapped torch::deploy API functions in safe rethrow macros (#58192).
- Added binding for aten::div_out (#56653).
- Added binding for aten::sub_out (#56656).
- Supported clamp.Tensor (#58191).
- Added an out version for aten::repeat (#57683).
- Added default arguments to CUDA stream and events (#53025).
- Added support for linear in MKLDNN fusion (#51484).
- Handled MKLDNN broadcasting in MKLDNN fuser (#51736).
- Added 0-dim support for binary MKLDNN ops (#51921).
- Added OneDNN relu backward and reshape backward (#49455).
- Added OneDNN batch_norm backward (#50460).
- Added support for hardshrink (#57749).
- Added non mutator bundled inputs method (#58408).
- Added support to compare devices (#53045).
- Added support for memory_arg in aten::clone (#58100).
- Implemented aten::cat without conditionals (#53128).
- Added external function bindings (#53420).
- Added out variant of sigrid_transforms_torch_bind and ListUnpack (#54761).
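The following is a minimal sketch of two of these TorchScript additions: list slicing with a step, and comparing two bool values. The function names are hypothetical. Run it from a file, because torch.jit.script reads the function source.

```python
from typing import List

import torch


@torch.jit.script
def every_other(xs: List[int]) -> List[int]:
    # List slicing with a step is supported inside TorchScript.
    return xs[::2]


@torch.jit.script
def same_flag(a: bool, b: bool) -> bool:
    # Two bool values can be compared directly.
    return a == b


print(every_other([0, 1, 2, 3, 4, 5]))  # [0, 2, 4]
print(same_flag(True, False))  # False
```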
torch.package
- Added a reliable method for determining if a file is part of Python’s standard library (#51694).
- Made package code more composable with other parts of PyTorch (package GraphModule, load non-code files from package) (#51674, #51976).
- Improved debugging facilities (allow_empty flag, zip file viewer, deny instruction, dependency tracing, query if object is from a package) (#53232,#53233, #52176, #55167, #56190, #56238, #56729).
- Allow save_module to accept module as arg (#55996).
- Follow dependencies created by __import__ calls (#55153).
- Added hooks to exporters’ mock and extern calls to take action when a module is matched (#58000)
- Turn the default behavior of packaging into an ‘intern’ action so that it can be ordered with repeat to mock, extern, and deny actions (#57341).
Quantization
- Added support for keeping output quantized for list and dict (#56391).
- Added torch.float16 and torch.float64 support to fake_quantize_per_channel (#56894).
- Support preserving attributes in deepcopy of observed/quantized graphmodule (#56550).
- Added support for packed params in state_dict (#51639).
- Added support for fusing Conv3d + BatchNorm3d + ReLU operations (#50003).
- Change back to multiple_outputs_gpu_kernel for learnable fake per-channel quantization (#52017).
- Added torch.float16 and torch.float32 support to fake_quantize_per_tensor (#52612).
- Support batched embeddings for 8 Bit embedding bag quantization (#55343).
- Expose nbins and ratio for quantized::embedding_bag_4bit_prepack (#50398).
Mobile
- Removed caching of inflated bundled inputs (#55181).
- Improved exception reporting for Lite interpreter (#54284, #55062, #55252).
- Improved forward/backward compatibility in Lite interpreter when adding new optional arguments to ops (#56845).
- Added model size to logged metadata when loading a Lite interpreter model (#53578).
- Benchmarking binary speed_benchmark_torch now supports Lite interpreter (#55402).
Distributed
torch.distributed.Store
- Update compare_set for other Store implementations to be the same as TCPStore (#57175)
- torch.distributed.Store: Expose C++ compare_set API to python (#57191)
- torch.distributed.Store: Add timeout, host, port to TCPStore’s python API as accessors (#52784)
- Allow world_size and is_master to be optional when constructing TCPStore (#51809)
- Add wait_for_workers param to TCPStore’s Python API (#52888)
torch.distributed.rpc
- Allow RRef to be created with a specified set of CUDA devices (#57085)
- Correctness fixes for CUDA support in RPC framework (#54024)
- Refactor RPC agent to use Store to collect and verify name (#53209, #53202)
DistributedDataParallel
- Make unused parameter search show up in profiler output (#57376)
- Update DDP communication hooks to divide by world size before all_reduce to avoid overflow (#57410)
- Stabilize torch.distributed.GradBucket interface for gradient compression (#53010, #53098, #53102, #53009, #53099)
- Skip CPU to GPU input copy if input is already on the right device (#55624)
- Record forward pass of DistributedDataParallel and DataParallel in profiler (#55578)
- Make orthogonalization_epsilon flag configurable in torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState (#55738)
- Set default value of start_powerSGD_iter to 1K iterations in torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook (#55272)
- Add a minimum compression rate threshold parameter for torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook (#52541)
- Report compression rate for batched PowerSGD hook (#55103)
- Enable gradient compression hook testing on ROCm (#52403)
- Enhance warning for unused parameters in DistributedDataParallel (#52385)
- Enhance error messages when crashing with unused parameters in DistributedDataParallel (#52391)
torch.distributed
- Add rank information on NCCL communicator abort (#57974)
- Enhance exception logging in NCCL (#54557, #54558, #54117)
torch.distributed.nn.RemoteModule
- Create a separate remote module template when moving CPU tensors to a cuda device is not enabled (#57413)
- Allow passing RemoteModule as an argument over RPC (#57695, #58345)
- Support async instantiation of RemoteModule (#58052)
- Place inputs on the appropriate devices in RemoteModule (#56943)
torch.futures.Future
- Enable torch.futures.Future to be created with CUDA support (#56517)
- torch.futures: Improve error propagation when using then API (#54475)
torch.nn.SyncBatchNorm
- Migrate apex.parallel.SyncBatchNorm channels_last to PyTorch implementation (#46906)
- Fix SyncBatchNorm’s forward pass to handle optional weight (#54568)
torch.distributed.pipeline
- torch.distributed.pipeline: Merge pipeline partitions that are on the same device (#55973)
Added new torch.distributed.elastic module that upstreams pytorch/elastic
- Rename torch.distributed.elastic_launch to torch.distributed.run (#56831)
- Make process failure init error non-fatal (#56739)
- Reorder type definitions in dynamic_rendezvous.py (#56534)
- Revise the rendezvous handler registry logic. (#55466)
- Set error code in reply file when child process is terminated by signals. (f665a7f8a1)
- Make sure torchelastic mp wait for queue to be drained before finishing the process (#55412)
- Revise the rendezvous exception types. (#54803)
- Expose a stderr parameter in EtcdServer (#54805)
- Improve the implementation of the utility functions and add their unit tests (#54804)
- Improve the implementation of RendezvousParameters and add its unit tests (#54807)
torch.distributed.optim.ZeroRedundancyOptimizer
- Add an option for buckets to be views of tensors and consolidate public interface (#52987)
- Make state dict for ZeroRedundancyOptimizer world size independent (#52960)
- Combine backtrace print into one string to avoid interleaving (#56961)
- Raise exception rather than crash if GLOO_DEVICE_TRANSPORT is set to unknown value (#58518)
ONNX
- Updated fuseLogSoftmaxNllLoss function to handle autocasting (#51729) (#52349).
- Added support for sequence of tensor mutations in blocks (#51577) (#52347).
- Updated LayerNorm symbolic to handle autocasting (#52199) (#52350).
- Restored fast path in OnnxifiOp::adjustOutputBatchSize (#52498).
- Improved index_put symbolic to handle singular Bool updates (#53690) (#54863).
- Replaced decomposeLinear pre process pass with a symbolic (#53077) (#54866).
- Improved assign input shape for tuple inputs & primitive type inputs (#54112) (#56164).
- Updated repeat_interleave symbolic (#54312) (#56165).
- Enabled word_language_model GRU and LSTM scripting (#54310) (#56170).
- Added standardOps match more input type in ORT (#53813) (#56172).
- Redesigned in-place conversion (#55033) (#56173).
- Handled PackedParams inputs for _propagate_and_assign_input_shapes (#56449) (#57079).
- Added a warning for the case when len is used to calculate tensor shape (#55151) (#57595).
- Added special post processing for onnx::Cast and onnx::ConstantOfShape shape type inference (#55962) (#57597).
- Handled NoneType in Assign Output Shapes (#54623) (#57602).
- ListUnpack on dynamic tensor list (#56592) (#57603).
- Handled mixed mask, index input for index_put (#57604).
- Handled incorrect format for example_outputs (#55802) (#57829).
- Enabled several script unit tests using new jit passes (#51722) (#53309).
Vulkan
- Enabled broadcasting for arithmetic ops (add, sub, mul, and div) (#52842).
- Reduced size of compiled shaders by using the -Os flag when calling glslc (#57199).
- The vulkan optimization JIT pass now adds an optimized_for_vulkan attribute to the model (#56414).
Benchmark
Misc
- Auto-detect ccache to speed up developer builds (#49389).
- Catch and ignore tracebacks for compilation errors (#55986).
- Register DefaultBackend implementations for functional/inplace structured operators (#53037).
- Improved support for oneDNN on AArch64 when building from src (#55913).
Bug fixes
Python API
- Updated torch.lerp to make weights tensor broadcast-able (#52319).
- Fixed print for negative torch.int8 tensors on ARM64 (#52616).
- Fixed type annotation for as_tuple to clearly determine what torch.nonzero will resolve to (#51635).
- Fixed torch.logcumsumexp to correctly handle infs and nans (#52947).
- Fixed torch.topk for k=0 on CUDA by skipping the kernel launch in this case (#58086).
- Fixed a bug for optimizers to have the hyper parameters be still defined when all parameters have no grad (#52944).
- Fixed type promotion issue for torch.pow (#54085).
- Fixed torch.min() and torch.max() to work on a non-empty dimension for tensors with 0 elements (#52565).
- Fixed the upper bound computation for torch.randperm (#56967).
- Allowed std=0 in torch.normal, and added checks to consistently error out if std<0 (#51317).
- Fixed torch.index_fill to output 0-dim tensor for a 0-dim input tensor (#52209).
- Fixed mul_() to correctly work for Mkldnn tensors (#51758).
- Fixed temp file/bind race condition in torch_shm_manager for torch.multiprocessing (#57309).
- Fixed tempfile address binding in torch_shm_manager to be destructed correctly for torch.multiprocessing (#57566).
- Fixed torch.multinomial to never select an element with 0 weight for torch.half (already works correctly for other datatypes) (#53480).
- Fixed a bug in assertRaises NotImplemented handling when no exception is thrown (#54126).
- Fixed override for __iter__ (#54702).
- Fixed segmentation fault for torch.floor_divide when compiling on ARM64 (#55608).
- Fixed torch.digamma’s inconsistency with SciPy’s digamma (#56689).
- Fixed torch.cat to return correct result for non-contiguous tensors (#57177).
- Fixed distributions for torch.distributions.log_prob which don't properly honor validate_args=False (#53600).
- De-prioritized Dimname and DimnameList in python overload resolution (#51350).
- Fixed the handling of scalar and zero-dimensional inputs to torch.take() and torch.Tensor.put_ on both CPU and CUDA (#53356).
- Fixed a bug to not rebuild extensions for every import (#56015).
- Fixed error message for torch.as_strided (#53198).
- Added correct handling for tensor allocation for large tensors when using torch.resize on CUDA (#52672).
- Fixed an illegal memory access that could happen when computing the inverse of a batch of matrices on CUDA (#53064).
- Fixed a bug where torch.sparse.addmm would compute the wrong results for CUDA inputs when beta was not zero or one (#56160).
- Fixed a bug where torch.sparse.sparse_coo_tensor’s gradient could be calculated incorrectly (#50361).
- pow: Fixed a bug for mixed cpu/cuda input tensors (#53669).
- sub: Fixed a sub.Scalar bug (#53679).
- Fixed torch.unique for discontiguous inputs (#59003).
- Fixed torch.randperm on CUDA (#59352).
- Fix torch.reciprocal for torch.float32 on ARMv8 (#59361).
- Disable overloading of std::max & std::min for inputs of different types, which could cause accuracy loss (#55638)
Complex Numbers
- Added custom implementation for sqrt and acos to be used if libc++ is used to reduce numerical error for edge cases (#52018, #54820, #52287).
Autograd
- Fixed:
  - `torch.autograd.gradgradcheck` when outputs are independent of the inputs (#58049).
  - `torch.utils.checkpoint` to behave properly when an error happens during forward (#51746).
  - autograd’s graph discovery when output is a leaf that requires gradients (#51940).
  - some cases where `torch.autograd.gradcheck` did not return the correct value when `raise_exception=False` (#53916).
  - thread local state not being properly propagated for some operations during the backward pass (#56174).
  - `torch.index_fill_` formula to support duplicate indices (#57101).
  - derivative of `torch.sinc` around `x=0` (#56763, #56986).
  - `torch.cdist` backward formula to correctly support broadcasting (#56605) and empty inputs (#56606).
  - view creation metadata for functions that return multiple views in `no_grad` or inference mode (#57842).
  - `autograd.functional.*` functions to work in no_grad mode (#47543).
  - rare deadlocks on exit due to autograd worker threads (#53170).
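For context on the `torch.sinc` item: `torch.sinc(x)` is the normalized sinc, sin(pi*x)/(pi*x) with sinc(0) = 1. Its derivative, (cos(pi*x) - sinc(x))/x, is a 0/0 expression at x = 0 even though the true value there is 0. The pure-Python sketch below (not PyTorch's derivative formula) fills in that removable singularity explicitly and cross-checks the result against a central finite difference.

```python
import math


def sinc(x: float) -> float:
    """Normalized sinc: sin(pi*x) / (pi*x), with sinc(0) = 1."""
    if x == 0.0:
        return 1.0
    return math.sin(math.pi * x) / (math.pi * x)


def sinc_grad(x: float) -> float:
    """d/dx sinc(x) = (cos(pi*x) - sinc(x)) / x, with the value at 0 filled in."""
    if x == 0.0:
        # The closed form is 0/0 here; the limit is 0 because sinc is an even function.
        return 0.0
    return (math.cos(math.pi * x) - sinc(x)) / x


def numeric_grad(f, x: float, h: float = 1e-5) -> float:
    """Central finite difference, used only to cross-check sinc_grad."""
    return (f(x + h) - f(x - h)) / (2 * h)


for x in (0.0, 0.25, 1.0, -0.7):
    print(x, sinc_grad(x), numeric_grad(sinc, x))
```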
torch.nn
- `nn.AdaptiveAveragePooling`: Fix crash for integral inputs (#51443).
- `F.normalize`: Fix to make it properly scriptable (#51909).
- `nn.parallel.scatter_gather.gather`: Fix to handle `NamedTuple`s and moving output to CPU (#51104).
- `fractional_max_pool{2/3}d`: Fix segfaults for incorrect `kernel_size` and `output_size` (#51626).
- `nn.CosineEmbeddingLoss`: Validate target has correct shape (#53110).
- Fix multiprocessing serialization for integer parameters on CUDA (#56529).
- `nn.Softplus`: Fix backwards computation by comparing `input` against `beta * threshold` (#56484).
- `addmm_`: Add check to disallow resizing the input tensor for the in-place variation on CPU (#56452).
- `nn.InstanceNorm*d`: Fix to perform correct input size check (#56659).
- `nn.CTCLoss`: Fix backward pass regression on cuDNN (#56639).
- `nn.ConvTranspose*d`: Fix regression that broke padding with a list of values (#54911).
- `F.max_pool3d`: Fix illegal memory access for large inputs on CUDA by doing multiplication in `int64` (#52828).
- `F.embedding`: Support `__torch_function__` (#54478).
- `nn.ChannelShuffle`: Remove `NamedTensor` warnings (#55911).
- `mkldnn_linear`: Fix incorrect results for non-contiguous inputs (#51713).
- `nn.ModuleList` / `nn.ModuleDict`: Raise `NotImplementedError` for `forward()` (#48785).
- Change `maybe_resize_storage_cpu` `new_size` arg to unsigned (#52671).
- `nn.LSTM`: Fix regression that broke loading older serialized modules (#57558).
- `F.reflection_pad2d`: Fix CUDA launch error (#56451).
- Fix wrong detection of depthwise convolution on NEON (#55794).
- Re-enable fast Winograd convolution on iOS (#56021).
- `gaussian_nll_loss`: Fix incorrect `reduction='none'` behavior (#56469).
- Fix misaligned access #56325 (#56403).
- Use native CTC loss for target length 256 (#53557).
- `register_full_backward_hook`: Fix crash when first argument doesn't require a gradient (#57945).
- Remove asserts of Tensor type and ignore mypy checks to support `__torch_function__` usage (#57458).
- Handle stride > 1 with im2col in CUDA thnn conv2d (#54080).
- Add device id to ConvolutionParams (#50892).
- Enable OneDNN for group convolution (#54890).
- `nn.AdaptiveAveragePooling3d`: Add `AccumulateType` for CUDA (#53607).
- Do not use depthwise3x3 conv in grad mode for ARM (#56889).
- Fix type annotations for `state_dict()` override (#55704).
- Pass contiguous weight to NNPACK convolution (#56569).
- `nn.EmbeddingBag`: Mark backward as non-deterministic for max mode rather than all reducing modes (#55574).
- `nn.EmbeddingBag`: Initialize `bag_size` output with zeros to make it deterministic (#56661).
- `nn.EmbeddingBag`: Support the empty bag case on CPU (#57446).
- Fix `nn.MHA` + `quantized` scriptability (#58727).
- Fix cuDNN performance on A100 (#58287, #59721, #59744, #59802).
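The `nn.Softplus` fix concerns which branch the backward pass takes. `nn.Softplus` computes log(1 + exp(beta * x)) / beta and, per its documentation, reverts to the linear function when `input * beta > threshold`; the gradient is sigmoid(beta * x) on the smooth branch and 1 on the linear one, so backward has to switch at the same point as forward. A pure-Python sketch of that rule, following the documented forward criterion (it is not a transcription of the kernel changed in #56484):

```python
import math


def softplus(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    """log(1 + exp(beta*x)) / beta, reverting to the identity when beta*x > threshold."""
    if beta * x > threshold:
        return x
    return math.log1p(math.exp(beta * x)) / beta


def softplus_grad(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    """Derivative of softplus; the branch test mirrors the forward pass so the
    gradient describes the function that was actually computed."""
    if beta * x > threshold:
        return 1.0  # slope of the linear branch
    z = math.exp(beta * x)
    return z / (1.0 + z)  # sigmoid(beta * x)


# With beta=2 and threshold=20 the switch happens at x = 10, not at x = 20.
print(softplus(9.0, beta=2.0), softplus_grad(9.0, beta=2.0))
print(softplus(11.0, beta=2.0), softplus_grad(11.0, beta=2.0))
```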
Dataloader
- Fixed type hints of the callable DataLoader arguments (#52924).
- Added a keyword arg to meta and support `abc` for typing (#58450).
- Fixed a bug to use `generator` instead of `self.generator` in the `RandomSampler` (#52956).
C++ API
- Fixed the lifetime of `PyTensorType` (#51649).
- Fixed linker failure with ambiguous namespaces (#45736).
- Fix Scalar output formatting (#53229)
- Fix printing of optional string arguments in schemas (#55196)
AMD
CUDA
- Added `torch.scatter_add` to the `torch.cuda.amp` promote list (#52133).
- Fixed segfault in distributed process group due to IPC (#53080).
- Fixed multinomial CUDA misalignment and non-deterministic behavior (#55364).
- Replaced raw cudaMalloc in `torch.sparse` code (#57083).
- [CUDA graphs] Added proper sync after replay (#57556).
- Fixed NVRTC versioning for CUDA 11.X (X>=3), CUDA 12 and later (#57204).
- Fixed a correctness issue of CUDA channels-last `nn.SyncBatchNorm` (#57077).
- Fixed CUDA caching allocator when trying to allocate ~2^64 memory (#57571).
- Fixed raw_deleter() bug with PYTORCH_NO_CUDA_MEMORY_CACHING=1 (#54775).
- Fixed undefined symbol for CUDA 11.1 Windows (#52506).
- Automatically set BUILD_SPLIT_CUDA for cpp extensions (#52503).
- Added `grid_sampler` to the list of operations that can autocast to `torch.float32` (#58679).
Dispatcher
- Fix boxing/unboxing for `Scalar` bool values (#53228).
- Fix inaccurate dispatch table for `fill_` (#53611).
- Fix inaccurate dispatch tables (#54127).
- Fix issue with dispatch key `AutogradXPU` (#56336).
- Modify `DispatchKeyExtractor` to also work for optional Tensors (#58283).
- Extract dispatch keys from optional Tensors (unboxed) (#58296).
torch.fx
- Preserve leaf modules in `Transformer` (#51998).
- Fix tuple type annotations in FX codebase (#52010).
- Fix type correctness on `GraphModule.graph` (#54305).
- Remove `forward` from `forward.__globals__` to facilitate retracing (#54011).
- Fix `ScriptMethod` dispatch on `__torch_function__` (#56103).
- Fix `type_matches` for `Optional[List[int]]` arguments to make `NormalizeArgs` more permissive (#56790).
- Fix `NormalizeArgs` issues with lists of tensors (#57004).
- Changed parametric type error in `NormalizeArgs` to a warning (#57183).
- Make `NormalizeArgs` not save output node in the `node_map` (#58058).
Profiler
- Fixed intermittent CUDA activity flush issue (https://github.com/pytorch/kineto/pull/95).
- Handled empty trace (#58013).
- Added CUDA synchronization points (#56651).
- Removed usage of onEachDevice from legacy profiler (#54125).
- Fixed double printing of FLOPs (#56974).
TorchScript
- Fixed `jit.trace` mishandling of InterfaceType (#53052).
- Made `reshape`/`flatten` deterministic (#54353).
- Added logic to use `is_buffer` in `BufferPolicy::valid` (#49588).
- Updated NNC to sanitize input names (#52786).
- Handled ExternalCalls in LoadStore analysis and Inliner (#52628).
- Fixed output restriding of size-1 dimensions (#58256).
- Handled non-literal constant bounds in Unroll (#53029).
- Fixed a case where inlining wouldn't work because dim-size was 1 (#53254).
- Removed cached argv from LLVMCodeGen to fix race condition (#54286).
- Lowered scalar constants as doubles/longs (#54824).
- Added a check to not try to vectorize kernels that use float16 (#55970).
- Added a check to not fuse `torch.float16` on CPU (#56119).
- Fixed `float->bool` conversion on CPU (#57798).
- Fixed handling of the arguments of `aten::to` (#58028).
- Don’t error on 0-dim in convolution (#51922).
- Allow `__exit__` to have a return value (#52336).
- Added metacompile of ternary if (#51789).
- Keep alive graph when creating iterators from it (#51951).
- Fixed return value of `IValue::to` for Tensor/String (#51463).
- Added function to check for memory leak (#52342).
- Ignore user annotated ignored attributes (#52367).
- Fixed tracing support for TorchBind (#52884).
- Use correct warning type for tracer warnings (#53460).
- Removed the assumption that `forward` exists in freeze_module (#52918).
- Removed notion of "level" from `Module::dump_to_str` (#52539).
- Made `IValue::toTensor()` inline-able (#53213).
- Consider `normal_` as a special operation in the remove mutation pass (#52175).
- Updated `set_stream` API to change the device (#53741).
- Only run `ReplaceWithCopy` pass when `enable_out_variant` is true (#54111).
- Disable fusion group that is not supported by XPU device (#54239).
- Don’t require same-sized `src`/`dest` in `reshape_copy` (#54467).
- Fixed `TupleType.annotation_str` to conform to `typing` module syntax for empty tuple type (#54641).
- Made NoneType `annotation_str` emit `NoneType` instead of `None` (#54642).
- Made sure the copy version of the op exists in `ReplaceWithCopy` (#55337).
- Included `conv3d` in `conv-add-relu` fusion (#54772).
- Added `cond-add-relu` matching pattern to cover in-place ops (#55458).
- Fixed `TupleType.annotation_str` to conform to `typing` module syntax for empty tuple type (#54745).
- Fixed `Optional[Tensor]` type in autodiff (#55565).
- Raise TypeErrors when `IValue::getSubValues` fails (#56510).
- Fixed num args for `to_copy` (#56441).
- Fixed error in JIT CUDA on ROCm (#55243).
- Fixed a bug in `emitUse` to drop all values that are marked as drop (#56652).
- Fixed default dtype for `randperm` and `triu`/`tril_indices` inside TorchScript (#57105).
- Don't allow create() on singleton types (#56807).
- Fix GIL multithreading issue exposed by `torch::jit::toIValue()` (#57688).
- Fold `NaiveSyncBatchNorm` when folding batch norm (#57823).
- Fix UB in `LoopNest::distribute` (#57883).
- Fix a condition when we use a native depthwise `conv2d` lowering (#57906).
- Ensure `torch.save()` has deterministic output (#57536).
- Fixed `hasattr` support type (#57950).
- Return nullptr if the number of input args doesn't match (#58018).
- Added fix for missing ops `aten::sorted.str` (#58339).
- Fixed deadlock in `Future` due to lock inversion with GIL (#58382).
- Added logic to prevent lock inversions with GIL in `Future` (#58391).
- Fixed `MKLDNN_add` in-place behavior (#51687).
- Use MKLDNN copy for `copy_` when self and src are MKLDNN layout (#54248).
- Fixed default to align with documentation in `fuser.py` (#53457).
- Fixed upcoming changes that are part of ROCm 4.2 and affect PyTorch JIT (#57400).
- Fix for improper mobile and torch.package serialization (#59642).
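On the TorchScript `__exit__` item (#52336): in Python, the return value of `__exit__` is meaningful, because a truthy value tells the interpreter that an exception raised inside the `with` block has been handled and should be suppressed. The plain-Python example below (the class and function names are hypothetical) shows those semantics; these notes do not say how TorchScript itself treats the returned value.

```python
class SuppressValueError:
    """Context manager whose __exit__ returns a value: True swallows ValueError."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        # A truthy return value tells Python the exception was handled.
        return exc_type is not None and issubclass(exc_type, ValueError)


def parse_or_default(text: str, default: int = 0) -> int:
    result = default
    with SuppressValueError():
        result = int(text)  # a ValueError here is suppressed by __exit__
    return result


print(parse_or_default("42"), parse_or_default("not a number"))  # 42 0
```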
torch.package
- Add cpython as a dependency for torch_python_obj (#56740).
- Catch exceptions where dependency resolution gets invalid imports (#58573).
- Simplifications to broken dependency handling (#58572).
Quantization
- Fixed conv packed param serialization in `state_dict` (#52787).
- Fixed `torch.float16` dynamic quant for functional linear (#52369).
- Fixed prepacking for `F.conv1d` (#55311).
- MHA tensor assignment fix (#53031).
- Fixed `conv` transpose with `qconfig == None` (#52844).
- Quant norm layers: move scale + zp to buffers (#52861).
- Handled the case when observed node has no users (#53210).
- Only insert observers for fixed qparam ops (#53330).
- Fixed a condition check for `CopyNode` (#53585).
- Fix for `x.ndim` followed by `sub` (#53120).
- Fixed using size of quant layer in `torch._assert` (#53187).
- Fixed fx quant for `quant_layer -> stack -> sum` (#53196).
- Fixed `deepcopy` on quantized `ConvNd` (#56154).
- Fixed `getitem` for unmatched nodes (#57173).
- Made quantizeable MHA work with `torch.jit.script` (#57774).
- Fixed `quantize_per_tensor` on CUDA (#57703).
- Fixed a bug to handle bias in rowwise quantization of FC (#58022).
- Skipped inserting observer for boolean Tensors (#57375).
- Fixed `torch.float16` reference patterns for linear (#55727).
- FX Quant:
- Fixed overflow issue in quantized instance_norm/layer_norm/group_norm (#54872).
- Fixed zero_point rounding for _fake_quantize_learnable_per_channel_affine (#52290).
- Bug fix to update requantization and zp parameters of input (#52797).
- Fix embedding bag bug accessing unaligned memory (#53300).
- Fix out variant for 4bit embedding bag (#55096).
- Avoid tensor refcount bumps on embedding bag (#55023).
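The zero-point rounding fix for `_fake_quantize_learnable_per_channel_affine` touches the affine fake-quantization arithmetic: each value is divided by the scale, rounded, shifted by the zero point, clamped to [quant_min, quant_max], and mapped back as (q - zero_point) * scale. Because a learnable zero point is held as a float, it has to become an integer before it is used. The sketch below is plain Python over nested lists; the function name is hypothetical, and clamping the rounded zero point to the quantized range is an assumption of this sketch, not a statement about the private kernel.

```python
from typing import List


def fake_quantize_per_channel(x: List[List[float]], scales: List[float],
                              zero_points: List[float],
                              quant_min: int = 0, quant_max: int = 255) -> List[List[float]]:
    """Quantize-then-dequantize each row (channel) of x with its own scale and zero point."""
    out = []
    for row, scale, zp_float in zip(x, scales, zero_points):
        # Round the float zero point to an integer and keep it inside the
        # quantized range (the clamping is an assumption of this sketch).
        zp = min(max(round(zp_float), quant_min), quant_max)
        new_row = []
        for v in row:
            q = round(v / scale) + zp               # quantize
            q = min(max(q, quant_min), quant_max)   # clamp to the integer range
            new_row.append((q - zp) * scale)        # dequantize
        out.append(new_row)
    return out


print(fake_quantize_per_channel([[1.0, -5.0, 200.0]], scales=[0.5], zero_points=[0.6]))
# [[1.0, -0.5, 127.0]]
```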
Mobile
- Fixed some bugs in the implementation of various functions on iOS GPU:
- Removed duplication of constant tensors in model when using Lite interpreter (#58182, #56002).
- Banned mutating operators in mobile GPU models (#56070).
- Use lite interpreter as default and bump model version (#58630)
Distributed
torch.distributed.Store
- Fix flag specifying whether there is more data for `TCPStore` delete key (#53886).
- Properly enforce timeout for `PrefixStore` (#53928).
- Fix `TCPStore` `wait` hang when key is previously set (#53860).
- Properly order `TCPStore`’s `compare_set` parameters in Python API (#52696).
- Fix resource leak bug in TCPStore constructor (#52860).
torch.distributed.rpc
- Several fixes for CUDA support in the RPC framework (#57926, #57432, #57394, #57443, #57487, #58384, #51820, #57792, #56895, #54932)
- Fix possible reference cycle by passing reference to parent future in RPC callbacks (#57635)
- Fix RPC `get_worker_info` for rank 0 (#52804).
- Fix crash when TensorPipe agent tries to double-set errors (#52837).
torch.distributed
- Fix path handling on Windows during rendezvous process (#57000)
- Fix and re-enable `ProcessGroupMPITest` (#56709).
DistributedDataParallel
- Correct the usage of min_compression_rate in gradient compression communication hooks (#52979)
- Fix mapping of parameter to parameter names when certain parameters don’t require gradient (#57771)
- Skip rebuild buckets in `DistributedDataParallel` when running under `no_grad` mode (#54159).
- Fix a race condition in `DistributedDataParallel` when all parameters are used but running with `find_unused_parameters=True` (#53160).
- In `DistributedDataParallel`, pass in `process_group` argument into `dist.get_rank` calls (#53793).
- Fix `DistributedDataParallel`’s process for verifying model consistency during initialization (#52887).
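The `min_compression_rate` fix (#52979) concerns the PowerSGD hook's decision of whether a gradient is worth compressing. A rank-r approximation of an n x m gradient matrix sends (n + m) * r elements (the two factors) instead of n * m, so a tensor should only be compressed when that saving is at least the configured rate. The sketch below states the rule under the assumption that it is applied as `compressed_size * min_compression_rate < uncompressed_size`; the function name is illustrative, so check `torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook` for the exact condition.

```python
def should_compress(num_rows: int, num_cols: int, rank: int,
                    min_compression_rate: float) -> bool:
    """True if a rank-`rank` approximation of a num_rows x num_cols matrix
    shrinks the payload by at least `min_compression_rate`."""
    uncompressed_size = num_rows * num_cols
    # The low-rank factors (num_rows x rank and num_cols x rank) replace the full matrix.
    compressed_size = (num_rows + num_cols) * rank
    return compressed_size * min_compression_rate < uncompressed_size


# 1024 x 1024 at rank 4: 8192 elements instead of 1048576, a 128x saving.
print(should_compress(1024, 1024, 4, min_compression_rate=2))  # True
# 16 x 16 at rank 4: 128 elements instead of 256, only a 2x saving.
print(should_compress(16, 16, 4, min_compression_rate=2))      # False
```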
torch.distributed
- Check vector boundaries in `torch::cuda::scatter` (#53057).
- Release GIL before destructing ProcessGroup classes (#56381).
torch.distributed.pipeline
- Fix hang in `pipeline` destructor by removing `join_workers` (#53433).
torch.distributed.elastic
- Resolve bug around incorrect rendezvous handler resolution (#56386)
torch.nn.SyncBatchNorm
- Ensure `SyncBatchNorm` behaves like a regular `BatchNorm` layer in eval mode (#56982).
torch.distributed.optim.ZeroRedundancyOptimizer
- Typing fixes (#53165).
- Fix monitored_barrier with wait_all_ranks (#58702).
ONNX
- Removed the last Cast in pow symbolic_opset9 (#52646) (#53305).
- Fixed export of `copy_` operator (#53046) (#53310) (#51938) (#54870).
- Fixed export of embedding with `padding_idx` (#53053) (#53530).
- Fixed ONNX warning message (#54371).
- Improved error message during Glow ONNXIFI (#58069).
- Fixed if output shape mismatch error & graph input directly used as output (#53219) (#54865).
- Fixed ComputeShapeFromReshape when `input_shape_size < reshape_size` (#56171).
- Fixed -Wrange-loop-construct in onnx_exporter.cc (#56759).
- Print `onnxifi` failed status code in readable format (#53648).
Vulkan
- Fixed kernel registration errors in Vulkan test and benchmark binaries by adding `nonVarTypeModeGuard` (#52535).
- Fixed the `glslc` path in CMake for desktop builds (#56507).
- Fixed build failures caused by `warnings-treated-as-error` for Linux builds (#52781).
- Remove constant duplication for Vulkan optimize_for_mobile (#59276).
Benchmark
- Fix timer overflow on small, fast snippets (#55200)
Misc
- [memory format] Fixed channels last bug in upsample kernels to now correctly pass `memory_format` information from the input to the output tensors (#53535).
- [memory format] Fixed silent correctness bug for CUDA upsample kernels to correctly handle `torch.channels_last` contiguous tensors (#54744).
- Work around intermittent gcc-7.5 ICE in cpp tests (#57016).
- Improve build quality on Windows (#52729, #53562, #54132, #55275).
- Search for static OpenBLAS compiled with OpenMP (#59428).
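Background for the two `[memory format]` fixes: `torch.channels_last` keeps a 4-D tensor's logical NCHW shape but lays its memory out as NHWC, so the difference is visible only in the strides, and an output that silently falls back to the default layout has different strides from its input. The pure-Python sketch below computes both stride patterns (the helper names are hypothetical). For shape (2, 3, 4, 5), PyTorch reports the same values from `x.stride()` for a default-contiguous tensor and for one converted with `memory_format=torch.channels_last`.

```python
from typing import Tuple

Shape4 = Tuple[int, int, int, int]


def contiguous_strides(shape: Shape4) -> Shape4:
    """Element strides of a row-major (NCHW) 4-D tensor."""
    n, c, h, w = shape
    return (c * h * w, h * w, w, 1)


def channels_last_strides(shape: Shape4) -> Shape4:
    """Element strides for the same logical NCHW shape stored as NHWC."""
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)


shape = (2, 3, 4, 5)
print(contiguous_strides(shape))     # (60, 20, 5, 1)
print(channels_last_strides(shape))  # (60, 1, 15, 3)
```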
Performance
Python API
- Optimized memory usage for `out=` version of `torch.logsumexp` (#51239).
- Added vectorization for `torch.floor_divide` (#55380).
- Reimplemented `torch.flip()` using advanced indexing (#56713).
- Improved performance for `torch.take()` and `torch.Tensor.put_` on both CPU and CUDA (#53356).
- Generic performance improvement for operations performed on non-contiguous 2-dimensional tensors (#53613).
- Added vectorization for `torch.copysign` on CPU (#51792).
- Improved performance for bilinear interpolation on CPU (#51653).
- Improved performance for backward computations on `torch.cumsum` and `torch.cumprod` on both CPU and CUDA (#53711).
- Improved performance for `torch.Tensor.copy_` when performing copies between small tensors of `torch.float` and `torch.half` data types (#53800).
- Enabled vectorization for `torch.Tensor.copy_` and `torch.cat` for BFloat16 tensors (#54671, #54674).
- Added a fast path for a common case for `torch.addmm` on CUDA (#55026).
- In collaboration with NVIDIA, the CUDA performance of many linear algebra operations has been improved by increasing use of the cuSOLVER and cuBLAS libraries:
  - Added cuBLAS support for `torch.triangular_solve` (#53147) and batched `torch.geqrf` (#56253).
  - Added cuSOLVER support for `torch.linalg.eigh/eigvalsh` (#53040), `torch.cholesky_solve` (#54315), `torch.cholesky_inverse` (#54676), and `torch.linalg.qr` (#56256).
  - Added cuBLAS and cuSOLVER support for `torch.linalg.lstsq` (#57317).
- Improved performance for `torch.nonzero` (#58468).
- Removed device check from a few indexing methods (#58800).
Complex Numbers
- Added a faster path for `torch.is_complex()` by skipping unnecessary dispatch (#50054).
Autograd
- Sped up autograd’s graph discovery algorithm by skipping some nodes using sequence number (#52180, #52057).
- Added a new fast gradcheck (#54480).
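The new fast gradcheck (#54480) avoids building the full Jacobian. The idea, sketched here from that description rather than copied from `torch.autograd.gradcheck`, is to compare the single scalar v^T J u two ways: the numerical side needs just one central difference along a random direction u, and the analytical side needs one vector-Jacobian product with a random v. The helpers, the toy function `f`, and the tolerances are all hypothetical.

```python
import math
import random
from typing import Callable, List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def fast_gradcheck(f: Callable[[Vec], Vec], vjp: Callable[[Vec, Vec], Vec], x: Vec,
                   eps: float = 1e-6, atol: float = 1e-5, seed: int = 0) -> bool:
    """Compare v^T J u computed numerically (via J u) and analytically (via J^T v)."""
    rng = random.Random(seed)
    u = [rng.uniform(-1, 1) for _ in x]
    v = [rng.uniform(-1, 1) for _ in f(x)]
    plus = f([xi + eps * ui for xi, ui in zip(x, u)])
    minus = f([xi - eps * ui for xi, ui in zip(x, u)])
    jac_u = [(p - m) / (2 * eps) for p, m in zip(plus, minus)]  # J u, one evaluation pair
    numerical = dot(v, jac_u)
    analytical = dot(vjp(x, v), u)                              # (J^T v) . u
    return abs(numerical - analytical) <= atol


# f(x0, x1) = (x0 * x1, sin(x0)); its Jacobian is [[x1, x0], [cos(x0), 0]].
def f(x: Vec) -> Vec:
    return [x[0] * x[1], math.sin(x[0])]


def good_vjp(x: Vec, v: Vec) -> Vec:
    return [v[0] * x[1] + v[1] * math.cos(x[0]), v[0] * x[0]]


def bad_vjp(x: Vec, v: Vec) -> Vec:
    return [v[0] * x[1], v[0] * x[0]]  # drops the contribution of sin(x0)


print(fast_gradcheck(f, good_vjp, [0.3, -1.2]))  # True
print(fast_gradcheck(f, bad_vjp, [0.3, -1.2]))   # False
```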
torch.nn
- `Module.forward`: Add fast path for the case of no hooks (#52576).
- Fix `mkldnn` heuristic for multithreaded convolution (#52909).
- `linear`: Remove one refcount bump (#54936).
- Improve `native_batch_norm_backward` performance on CUDA (#58240).
- `nll_loss`: Use cascade summation on CPU (#55841).
- `nn.BatchNorm1d`: Improve training performance on CPU (#57033).
- Simplify convolution double backward gradInput formulas (#54840).
- Move RNN cell size check to cpp (#51964).
- Remove syncs in `one_hot` (#57902).
- Enable and enhance bf16 threshold (#54384).
- `nn.Conv3d`: Enable `channels_last_3d` for cuDNN (#48430).
- Increase token count threshold for calling thrust sort in embedding backward (#49913).
- CPU convolution benchmark harness for some popular models (#56455).
- Improved performance for `torch.nn.BatchNorm1d` on both CPU and CUDA (#57033, #57786).
- Added optimized generic interpolation for `torch.nn.functional.{upsample_nearest, upsample_bicubic}` and speed up for channels first and last cases (#54500).
- Added shape documentation for CosineEmbeddingLoss (#58403).
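Cascade (pairwise) summation, now used by `nll_loss` on CPU, limits rounding error by adding partial sums of similar magnitude instead of feeding every element into one running total: each element passes through O(log n) additions rather than O(n). Below is the textbook pairwise version in plain Python, compared with `math.fsum` on an input chosen to expose the drift. PyTorch's CPU kernel is a vectorized multi-level variant, so treat this only as an illustration of the idea.

```python
import math
from typing import Optional, Sequence


def naive_sum(xs: Sequence[float]) -> float:
    total = 0.0
    for x in xs:
        total += x  # rounding error can accumulate linearly with len(xs)
    return total


def cascade_sum(xs: Sequence[float], lo: int = 0, hi: Optional[int] = None,
                block: int = 8) -> float:
    """Pairwise summation: sum small blocks directly, then add halves recursively."""
    if hi is None:
        hi = len(xs)
    if hi - lo <= block:
        return naive_sum(xs[lo:hi])
    mid = (lo + hi) // 2
    return cascade_sum(xs, lo, mid, block) + cascade_sum(xs, mid, hi, block)


# One large value followed by many tiny ones: a running total drops every tiny term.
values = [1.0] + [1e-16] * (1 << 16)
exact = math.fsum(values)
print(abs(naive_sum(values) - exact), abs(cascade_sum(values) - exact))
```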
C++ API
- Fixed nested OpenMP performance bug in `thnn_conv2d` (#52577).
- Added c10::MaybeOwned and Tensor::expect_contiguous (#53317).
- Added DimVector variant of infer_size (#54882).
- Added logic to use `DimVector` for inputs to `as_strided` that don't grow dim (#55016).
- Reduce ref-counting by borrowing in/out Tensors in TensorIterator (#55690).
- Reduce ref-counting by migrating add operators to borrow Tensors in TensorIteratorBase (#55691).
- Reduce ref-counting by migrating copy_ operators to borrow input/output Tensors (#56031).
- Added logic to use `expect_contiguous` in `layer_norm` (#58067).
CUDA
- Construct only necessary elements in OffsetCalculator (#55107).
- Migrated `torch.index_put` to use cub instead of thrust (#55693).
- Added cuSOLVER `potrf` and `potrfBatched` to the backend of `torch.cholesky_decomposition` (#53104).
- Implemented `torch.sort` with cub::DeviceSegmentedRadixSort (#56821).
- Added cuSOLVER path for `torch.geqrf` (#56252).
- Enabled cuSOLVER `torch.potrf` batched for Cholesky decomposition when CUDA >= 11.3 (#57788).
- Fewer CUDA syncs in unique by using cub instead of thrust (#57323).
- Removed sync for `randperm` on small tensors (#54113).
- Simplify convolution double backward gradInput formulas (#54840).
Composability
- We’ve landed lots of performance optimizations for 1.9, both large and small. See individual PRs for details:
  - Inline `tensor.device()` (#50848)
  - Skip a second call to `shouldUseRecordFunction` for BackendSelect ops (#50891)
  - Re-order `TensorImpl` fields to save a word (#50920)
  - Devirtualize `TensorImpl::storage()` (#51050)
  - Reduce template expansion in `call_functor_with_args_from_stack` (build time) (#51313)
  - Eliminate `WrapFunctionIntoRuntimeFunctor` use in CppFunction constructors (#51315)
  - Remove `reference_cast` in `make_boxed_from_unboxed_functor` (build time) (#51319)
  - Debug-gate `static_assert` in `KernelFunction::makeFromUnboxedFunctor` (build time) (#51367)
  - Use real `if constexpr` behind macro in hot template (build time) (#51368, #52420)
  - Outline `DispatchStub::get_call_ptr()` (#51908)
  - Use `torchCheckFail` in `TORCH_INTERNAL_ASSERT` (#52086)
  - Add `Storage::set_data_ptr_noswap` and use where possible (#52244)
  - Make shared empty string static instead of thread_local (#52220)
  - Avoid `std::string` in `TORCH_CHECK` when possible (#52221)
  - Make `c10::str(const char*)` return `const char*` (#52222)
  - Sync `TORCH_INTERNAL_ASSERT` optimizations with `TORCH_CHECK` (#52226)
  - Save a single add instruction in the dispatcher (#52543)
  - Inline `TensorIteratorConfig` setters (#52661)
  - Use `DimVector` for sizes and strides in `view` (#53001)
  - Avoid TLS in `has_names` (#53003)
  - Don't inline `Dispatcher::call` on mobile (binary size) (#53197)
  - Skip dispatch for `is_floating_point` (#53242)
  - Move non-template part of `TensorImpl::Resize` to cpp (binary size, build time) (#53388)
  - Don't copy vector arguments to `Tensor::Resize` (#53389)
  - Skip dispatch trip for CPU in `resize_` (#53575)
  - Pass `Scalar` by reference (#53583)
  - Don't use static for template declarations in headers (binary size) (#53602)
  - Boxing logic forwards arguments to stack (#53624)
  - Speed up Tensor::data_ptr by using static item size (#53723)
  - Skip dispatch for is_signed (#53847)
  - Allow inlining of more Tensor methods (#53905)
  - `Tensor::register_hook`: Avoid wrapping hook in two levels of `std::function` (#53917)
  - Take advantage of string literals in `TORCH_WARN` (#54032)
  - Inline `Tensor` keyset-checking methods & similar getters (#54806)
  - `TensorIterator::output` returns const reference (#54811)
  - Avoid refcount bump in `TensorArg` (#54934)
  - Move `Tensor::has_names` inline (#54965)
  - `OperandInfo` ctor should take rvalue reference (#54972)
  - Don't bother with `SmallVector` in `TensorMaker` (#55125)
  - Eliminate device guard in generic dispatch key kernel wrappers (#55131)
  - Move logic to skip a redispatch directly inside of `resize_output` (#55162)
  - Use `infer_size_dimvector` in `ExpandUtils` (#55180)
  - Don't create intermediate Tensor for `at::result_type` w/ Scalar (#55232)
  - Use `sizes()[x]` instead of `size(x)` in `addr` (#55247)
  - Add & use `inferExpandGeometry_dimvector` (#55316)
  - Mark borrowed case as `C10_LIKELY` in `MaybeOwned` (#55553)
  - Avoid double indirection in `MaybeOwned`'s borrowed state (#55685)
  - Make `VariableVersion::DISABLED` the default constructor for `VariableVersion` (#55572)
  - Don't set `version_counter` on inference tensor for `unsafe_` ops (#55819)
  - Add & document `borrow_from_optional_tensor` (#56647)
  - Migrate hacky wrapper removal to `borrow_from_optional_tensor` (#56648)
  - Optimize `at::repeat` (#56994)
  - Optimize `intrusive_ptr(TTarget*)` ctor (`pybind`) (#57053)
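The `TensorImpl` field re-ordering item relies on the usual C/C++ layout rule: each field is placed at an offset that satisfies its type's alignment requirement, so a small field sitting before a larger one leaves a padding hole, and grouping the small fields together reclaims that space. This `ctypes` sketch uses made-up fields (not `TensorImpl`’s real members) to show the effect. On common 64-bit platforms (x86-64, AArch64) it prints `24 16`, i.e. one 8-byte word saved.

```python
import ctypes


class Unordered(ctypes.Structure):
    # A 1-byte field before an 8-byte field forces padding so `numel` is aligned,
    # and the trailing 1-byte field is padded out to the struct's alignment.
    _fields_ = [
        ("flag_a", ctypes.c_uint8),
        ("numel", ctypes.c_int64),
        ("flag_b", ctypes.c_uint8),
    ]


class Reordered(ctypes.Structure):
    # Same fields, largest first: the two small fields share one padded slot.
    _fields_ = [
        ("numel", ctypes.c_int64),
        ("flag_a", ctypes.c_uint8),
        ("flag_b", ctypes.c_uint8),
    ]


print(ctypes.sizeof(Unordered), ctypes.sizeof(Reordered))
```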
torch.fx
- Use precompiled regex in graph name processing (#52853).
- Optimize module path finding in `Tracer` (#52990).
- Speed up `_Namespace.create_name` (#55580).
Profiler
- Sped up post processing (#58021).
TorchScript
- Generate arithmetic vs logical right shift as appropriate (#51749)
- Introduced likely/unlikely `CompareSelect` hint (#51751).
- Implemented log approximation using the VML approach (#51752).
- Updated `TensorExpr` to use `LLVM` as the default backend (#52314).
- Added support for `aten::hardtanh` (a hot operation in mobilenet v2/v3) (#52394).
- Implemented `hardtanh` (#57750).
- Add `aten::batch_norm` into fuser when in inference mode (#54204).
- NNC
- Added a new API to perform loop fusion (#54461).
- Implemented depthwise `conv2d` (#54920).
- Integrated NNC `conv2d` with fuser (#55213).
- Added logic to use NNC to generate `logit`, `relu` and `tanh` (#52322).
- Use VML-inspired logarithm with NNC, tweak scheduling (#52423).
- Generate `sigmoid` with NNC (#52424).
- Enabled CPU fusion only when `num_threads == 1` (#56120).
- Use NNC's `call_raw` API to reduce call overheads (#57553).
- Started codegen’ing some external calls (#58118).
- Reduce memory use for inference path in `OneDNN MaxPooling` (#52728).
- Removed redundant `gather_ranges` when fusing (#53323).
- Optimized `sigrid_hash` (#53065).
- Updated `create_empty_from` to directly use the native version of `at::empty` (#53216).
- Added a minimum fusion group size (#50217).
- Added CUDNN `Conv-Add-Relu` fusion for Frozen Model Optimization (#52102).
- Avoid dispatch overhead in call to MKLDNN convolution (#52614).
- Added re-inplacing to MKLDNN subgraphs (#53908).
- Set `requires_gradient` to help autodiff prune unneeded gradients (#54374).
- Use type cache in erasing shape information (#55828).
- Added heuristic to avoid perf incompatible MKLDNN formats for binary ops (#56089).
- Added `adaptive_avgpool2d` to the set of fusible ops (#56180).
- Lazily initialize `AliasDb` in `remove_mutation` opt (#55949).
- Made DataPtr extraction in CUDAFuture faster for Python values (#56918).
- Lazily initialize `AliasDb` in DCE (#56649).
- Add explicit checks for in-place ops in `ReplaceWithCopy` (#54657).
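On the first TorchScript item ("Generate arithmetic vs logical right shift as appropriate"): an arithmetic right shift replicates the sign bit, which is what `>>` does on signed integers on the usual targets, while a logical right shift fills with zeros, which is what `>>` does on unsigned integers; LLVM exposes the two as separate `ashr` and `lshr` instructions. Presumably the choice follows the operand's signedness. A small Python illustration of how the two differ on 32-bit values (the function names are hypothetical):

```python
def arithmetic_shift_right(x: int, n: int) -> int:
    """Signed shift: the sign bit is replicated, so -8 >> 1 == -4.
    Python's >> on int already behaves this way."""
    return x >> n


def logical_shift_right(x: int, n: int, bits: int = 32) -> int:
    """Unsigned shift: reinterpret x as a `bits`-wide unsigned value, then fill with zeros."""
    mask = (1 << bits) - 1
    return (x & mask) >> n


print(arithmetic_shift_right(-8, 1), logical_shift_right(-8, 1))  # -4 2147483644
print(arithmetic_shift_right(8, 1), logical_shift_right(8, 1))    # 4 4
```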
Quantization
- Optimized quantized `torch.cat` (#54813).
Mobile
- Enabled `QNNPACK` for Apple Silicon builds (#52308).
- Sped up model loading for per-channel quantized models using `QNNPACK` (#53726).
- Added `XNNPACK` implementations for various operations (hardswish, global average pool) (#56714, #56715, #55791).
- Made various performance improvements for iOS GPU (Metal) (#57664, #57665, #57666, #57667, #57668).
Distributed
torch.distributed
- Avoid 2 extra copies when reducing sparse tensors (#57822)
Vulkan
- Switched to a more performant implementation of matrix multiplication (#49609).
- Updated the version of Vulkan Memory Allocator used (#52938).
- Increased the command buffer submission rate (#57196).
- Updated the Vulkan tensors to use 2D textures whenever possible, instead of always using 3D textures (#57198).
- Updated convolution shaders to receive the bias tensor as a texture as opposed to a buffer (#57201).
Docs
Python API
- Added `torch.testing` docs (#57247).
- Updated docs to mention CUDA support for Future (#50048).
- Included `memory_format`, an already accepted argument, in `torch.empty` doc (#54664).
- Improved the documentation for torch.matrix_exp() (#55626).
- Updated use_deterministic_algorithms docs (#55413).
- Added the `generator` argument to `torch.rand` and `torch.randn` docs (#56242).
- Added an example to show how to use learning rate schedulers in Optimizers (#56705).
- Corrected the torch.ceil formula in docs (#55039)
- Fixed docs to use autosummary on tensors.rst (#55042)
- Improved testing documentation in `CONTRIBUTING.md` (#54904)
- Updated `torch.fft` docs to include `out=` argument (#56732).
- Updated `rounding_mode` documentation to remove `"true"` (#52202).
- Added a note about error handling for non-chained futures (#53212).
- Updated `torch.stft` documentation to clarify output shape (#54877).
- Added an example for `torch.is_tensor` and `torch.is_storage` (#55052).
Autograd
- Added a note describing gradcheck internals (#55966).
- Split up autograd documentation into separate pages (#55672).
- `torch.utils.checkpoint`: Updated docs to state that `input` flag in `.backward()` is disallowed when checkpointing (#51746).
- Added section in autograd mechanics note describing how to use inference/no_grad (#58513).
- Added doc string for `torch.is_inference_mode_enabled` and `torch.is_grad_enabled` (#59047).
- Added no-grad inference mode note (#58513).
- Add docstring for is_inference_mode_enabled (#59047).
torch.nn
- `nn.TripletMarginLoss` / `torch.reciprocal`: Fix formatting in docs (#51650)
- `nn.FractionalMaxPool3d`: Add to pooling layer docs (#52556)
- `F.fractional_max_pool`: Add to `nn.functional` docs (#52557)
- `Module.share_memory`: Add link to `Tensor.share_memory_` in docs (#52561)
- `nn.SiLU`: Mention alternative name of Swish within docs (#53239)
- Remove redundant `hardsigmoid()` in docstring so that the `inplace` parameter shows up (#52559)
- Clarify docs for lazy modules (#53495)
- `torch.nn`: Grammatically update docs (#54370)
- `nn.Sequential`: Expand docs, including comparison with `nn.ModuleList` (#53380)
- `F.embedding_bag`: Fix formatting in docs (#54666)
- `F.group_norm`: Add to docs (#54673)
- Add separate autosummary for flatten layer docs (#54663)
- `LazyModuleMixin`: Add missing attr in docs to improve formatting (#53363)
- `conv1d`: Fix example error in docs (#57356)
- `nn.functional`: Split docs into a table-of-contents page and a sub-page per function (#55038)
- `nn.LSTM` / `nn.RNN` / `nn.GRU`: Clarify `batch_first` behavior (#58809)
- `nn.CosineEmbeddingLoss`: Add shape info to docs (#58403)
- Add doc warnings for default SELU gain (#54057).
- Clarify `batch_first` behavior for `nn.LSTM`, `nn.RNN`, and `nn.GRU` (#58809).
- Add UninitializedBuffer to nn docs (#59021).
- Document factory_kwargs in nn.Quantize + remove Attributes section (#59025).
Dataloader
- Added DataPipes Typing Doc (#54773).
- Added docs to document the default NumPy seed for DataLoader workers (#56528).
AMD
- Added HIP semantics doc (#57871).
CUDA
torch.fx
- Make some modifications to limitation section (#51928)
- Added docstring for concrete_args on `Tracer.trace` (#53151).
- Change Dynamic Control Flow example to a more dynamic version (#53250).
- Render inherited methods in fx.Tracer API reference (#53630).
- Add docs for `ShapeProp` (#54554).
- Hide module paths leaking in the documentation (#54585).
Profiler
- Updated profiler recipe doc (https://github.com/pytorch/tutorials/pull/1528).
TorchScript
- Added NNC IR specification (#52912).
- Added starter content for new TorchScript language reference (#53837).
- Added documentation for `torch.jit.Attribute` and `torch.jit.annotate` (#54485).
- Updated TorchScript language reference section for types (#53673).
- Documented the TorchScript type system (#53244).
- Added language reference for Python builtin functions, statements, and values in TorchScript (#52847, #52830).
- Added `torch.*` API section for TorchScript language reference (#53236).
- Added “Conditionals in TE” doc (#56949).
torch.package
- Added API reference (#55812, #56547).
- Add explanation, tutorial, and preamble sections for `torch.package` (#59833, #59503, #59499, #59491, #59842, #59843, #59602).
- Add pickle security warning to package docs (#59959).
Quantization
- Added docs for storage and tensors for quantized Tensor (#51817).
- Fixed FX Graph Mode Quantization tutorial link (#54715).
- Added fx graph mode quant api doc (#55306).
- FX Graph Mode Quantization - fixed preamble (#52192).
- Fixed broken link to fx graph quant guide in quantization.rst (#56776).
Mobile
- Added doc string for lite interpreter related API in Android (#53136).
- Improved `export_opnames` documentation (#52333).
Distributed
torch.distributed.Store
- Documentation for TCPStore’s `compare_set` API (#57203)
torch.distributed.optim
- Update distributed optimizer documentation (#58084)
- Update and expose ZeroRedundancyOptimizer docs (#53112, #53113)
torch.distributed.elastic
- Upstream `torchelastic` documentation to PyTorch (#56811)
- Revise the note section of RendezvousHandler doc (#57723)
- Update the rendezvous documentation (#57973)
DistributedDataParallel
- Add register_comm_hook API to DDP communication hooks documentation page (#51846, #51986)
- Enhance documentation around `DistributedDataParallel` uneven input support (#57448)
- Enhance communication hook documentation (#58170, #58168, #53253, #53855, #53596, #53955, #54052, #55031)
torch.distributed.rpc
- Add a disclaimer about limited CUDA support in RPC (#58023)
- `torch.distributed.rpc`: Add a link to the tutorial in RemoteModule docstring (#57875)
- `torch.distributed.rpc`: Mentioned `RemoteModule` in RPC documentation (#57876)
torch.distributed.nn.RemoteModule
- Add RemoteModule to master RPC docs. (#53084)
- Add `remote_parameters` and `get_module_rref` to RemoteModule docs (#54645)
torch.distributed.pipeline
- Enhance Pipe docs to explicitly mention RPC initialization. (#55187)
- Add tutorials to pipeline docs. (#55209)
torch.distributed
- Update documentation for `get_future` support (#58107)
- Mention distributed profiling in documentation (#58286)
- Update distributed doc table for `alltoall` (#54277)
- Fix docstring signature in `all_reduce_multigpu` (#54665)
- `torch.distributed`: Improve dist.new_group doc (#55660)