pytorch/pytorch v1.13.0

Released: 2022-10-29

Pytorch 1.13 Release Notes

Highlights

We are excited to announce the release of PyTorch 1.13! This includes stable versions of BetterTransformer. We deprecated CUDA 10.2 and 11.3 and completed migration of CUDA 11.6 and 11.7. Beta features include improved support for Apple M1 chips and functorch, a library that offers composable vmap (vectorization) and autodiff transforms, which is now included in-tree with the PyTorch release. This release is composed of over 3,749 commits and 467 contributors since 1.12.1. We want to sincerely thank our dedicated community for your contributions.
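
As a quick illustration of the in-tree functorch transforms, here is a minimal sketch (the loss function and shapes are made up) that composes grad and vmap to compute per-sample gradients:

import torch
from functorch import grad, vmap  # functorch ships in-tree with PyTorch 1.13

def loss(weights, sample):
    # toy per-sample loss; weights and sample are placeholders
    return ((weights * sample).sum()) ** 2

weights = torch.randn(3)
samples = torch.randn(8, 3)

# grad differentiates with respect to the first argument (weights);
# vmap maps the resulting function over the batch dimension of samples
per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(weights, samples)
print(per_sample_grads.shape)  # torch.Size([8, 3])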

Summary:

Stable:
  • Better Transformer
  • CUDA 10.2 and 11.3 CI/CD Deprecation

Beta:
  • Enable Intel® VTune™ Profiler's Instrumentation and Tracing Technology APIs
  • Extend NNC to support channels last and bf16
  • Functorch now in PyTorch Core Library
  • Beta Support for M1 devices

Prototype:
  • Arm® Compute Library backend support for AWS Graviton
  • CUDA Sanitizer

You can check the blogpost that shows the new features here.

Backwards Incompatible changes

Python API

uint8 and all integer dtype masks are no longer allowed in Transformer (#87106)

Prior to 1.13, key_padding_mask could be set to uint8 or other integer dtypes in TransformerEncoder and MultiheadAttention, which might generate unexpected results. In this release, these dtypes are not allowed for the mask anymore. Please convert them to torch.bool before using.

1.12.1

>>> layer = nn.TransformerEncoderLayer(2, 4, 2)
>>> encoder = nn.TransformerEncoder(layer, 2)
>>> pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.uint8)
>>> inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1)
# works before 1.13
>>> outputs = encoder(inputs, src_key_padding_mask=pad_mask)

1.13

>>> layer = nn.TransformerEncoderLayer(2, 4, 2)
>>> encoder = nn.TransformerEncoder(layer, 2)
>>> pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool)
>>> inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1)
>>> outputs = encoder(inputs, src_key_padding_mask=pad_mask)

Updated torch.floor_divide to perform floor division (#78411)

Prior to 1.13, torch.floor_divide erroneously performed truncation division (i.e. truncated the quotients). In this release, it has been fixed to perform floor division. To replicate the old behavior, use torch.div with rounding_mode='trunc'.

1.12.1

>>> a = torch.tensor([4.0, -3.0])
>>> b = torch.tensor([2.0, 2.0])
>>> torch.floor_divide(a, b)
tensor([ 2., -1.])

1.13

>>> a = torch.tensor([4.0, -3.0])
>>> b = torch.tensor([2.0, 2.0])
>>> torch.floor_divide(a, b)
tensor([ 2., -2.])
# Old behavior can be replicated using torch.div with rounding_mode='trunc'
>>> torch.div(a, b, rounding_mode='trunc')
tensor([ 2., -1.])

Fixed torch.index_select on CPU to error that index is out of bounds when the source tensor is empty (#77881)

Prior to 1.13, torch.index_select on CPU would return an appropriately sized tensor filled with random values if the source tensor was empty. In this release, we have fixed this bug so that it errors out. A consequence of this is that torch.nn.Embedding, which utilizes index_select, will error out rather than returning an empty tensor when embedding_dim=0 and the input contains out-of-bounds indices. The old behavior cannot be reproduced with torch.nn.Embedding; however, since an Embedding layer with embedding_dim=0 is a corner case, this behavior is unlikely to be relied upon.

1.12.1

>>> t = torch.tensor([4], dtype=torch.long)
>>> embedding = torch.nn.Embedding(3, 0)
>>> embedding(t)
tensor([], size=(1, 0), grad_fn=<EmbeddingBackward0>)

1.13

>>> t = torch.tensor([4], dtype=torch.long)
>>> embedding = torch.nn.Embedding(3, 0)
>>> embedding(t)
RuntimeError: INDICES element is out of DATA bounds, id=4 axis_dim=3

Disallow overflows when tensors are constructed from scalars (#82329)

Prior to this PR, overflows during tensor construction from scalars would not throw an error. In 1.13, such cases will error.

1.12.1

>>> torch.tensor(1000, dtype=torch.int8)
tensor(-24, dtype=torch.int8)

1.13

>>> torch.tensor(1000, dtype=torch.int8)
RuntimeError: value cannot be converted to type int8 without overflow

Error on indexing a cpu tensor with non-cpu indices (#69607)

Prior to 1.13, cpu_tensor[cuda_indices] was a valid program that would return a cpu tensor. The original use case for mixed device indexing was for non_cpu_tensor[cpu_indices], and allowing the opposite was unintentional (cpu_tensor[non_cpu_indices]). This behavior appears to be rarely used, and a refactor of our indexing kernels made it difficult to represent an op that takes in (cpu_tensor, non_cpu_tensor) and returns another cpu_tensor, so it is now an error.

To replicate the old behavior for base[indices], you can ensure that either indices lives on the CPU device, or base and indices both live on the same device.

1.12.1

>>> a = torch.tensor([1.0, 2.0, 3.0])
>>> b = torch.tensor([0, 2], device='cuda')
>>> a[b]
tensor([1., 3.])

1.13

>>> a = torch.tensor([1.0, 2.0, 3.0])
>>> b = torch.tensor([0, 2], device='cuda')
>>> a[b]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
# Old behavior can be replicated by moving b to CPU, or a to CUDA
>>> a[b.cpu()]
tensor([1., 3.])
>>> a.cuda()[b]
tensor([1., 3.], device='cuda:0')

Remove deprecated torch.eig, torch.matrix_rank, torch.lstsq (#70982, #70981, #70980)

The deprecation cycle for the above functions has been completed and they have been removed in the 1.13 release.
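
A minimal migration sketch (the tensors here are made up; note that the torch.linalg replacements have slightly different return conventions, e.g. torch.linalg.eig always returns complex eigenvalues and torch.linalg.lstsq returns a named tuple):

import torch

A = torch.randn(3, 3)
B = torch.randn(3, 2)

# torch.eig(A, eigenvectors=True)  ->  torch.linalg.eig(A)
eigenvalues, eigenvectors = torch.linalg.eig(A)

# torch.matrix_rank(A)             ->  torch.linalg.matrix_rank(A)
rank = torch.linalg.matrix_rank(A)

# torch.lstsq(B, A)                ->  torch.linalg.lstsq(A, B)  (note the swapped argument order)
solution = torch.linalg.lstsq(A, B).solution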

torch.nn

Enforce that the bias has the same dtype as input and weight for convolutions on CPU (#83686)

To align with the implementation on other devices, the CPU implementation for convolutions was updated to enforce that the dtype of the bias matches the dtype of the input and weight.

1.12.1

# input and weight are dtype torch.int64
# bias is torch.float32
>>> out = torch.nn.functional.conv2d(input, weight, bias, ...)

1.13

# input and weight are dtype torch.int64
# bias is torch.float32
# the following call now raises an error because the bias dtype does not match
>>> out = torch.nn.functional.conv2d(input, weight, bias, ...)

# Updated code to avoid the error
>>> out = torch.nn.functional.conv2d(input, weight, bias.to(input.dtype), ...)

Autograd

Disallow setting the .data of a tensor that requires_grad=True with an integer tensor (#78436)

Setting the .data of a tensor that requires_grad with an integer tensor now raises an error.

1.12.1

>>> x = torch.randn(2, requires_grad=True)
>>> x.data = torch.randint(1, (2,))
>>> x
tensor([0, 0], requires_grad=True)

1.13

>>> x = torch.randn(2, requires_grad=True)
>>> x.data = torch.randint(1, (2,))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: data set to a tensor that requires gradients must be floating point or complex dtype

Added variable_list support to ExtractVariables struct (#84583)

Prior to this change, a C++ custom autograd Function treated tensors passed in a TensorList as non-tensors for the purposes of recording the backward graph. After this change, custom Functions that receive a TensorList must modify their backward functions to also compute gradients for these additional tensor inputs. Note that this behavior now differs from that of custom autograd Functions in Python.

1.12.1

struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, at::Tensor t, at::TensorList tensors) {
      return 2 * tensors[0] + 3 * t;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {3 * grad_output[0]};
    }
};

1.13

struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, at::Tensor t, at::TensorList tensors) {
      return 2 * tensors[0] + 3 * t;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {3 * grad_output[0], 2 * grad_output[0]};
    }
};

Don't detach when making views; force kernel to detach (#84893)

View operations registered as CompositeExplicitAutograd kernels are no longer allowed to return input tensors as-is. You must explicitly create a new tensor (e.g., using .alias()).

1.12.1

torch::Tensor view_op(const torch::Tensor& self) {
  return self;
}

1.13

torch::Tensor view_op(const torch::Tensor& self) {
  return self.alias();
}

ONNX

torch.onnx.register_custom_op_symbolic now only registers the symbolic function at the specified opset version (#85636)

This updates register_custom_op_symbolic's behavior to only register the symbolic function at a single version. This is more aligned with the semantics of the API signature. Previously, the API registered a symbolic function for all opset versions up to and including the specified version. As a result of this change, users will need to register a symbolic function at the exact version when they want to override an existing symbolic function. Users are not affected if (1) an implementation does not exist for the op, or (2) the symbolic function is already registered at the exact version used for export.

1.12.1

# Assuming an implemented symbolic function `custom_op_function`
torch.onnx.register_custom_op_symbolic("aten::foo", custom_op_function, 16)

1.13

# Assuming an implemented symbolic function `custom_op_function`
# To cover every opset version up to 16, register the function at each version explicitly
for opset in range(1, 17):
    torch.onnx.register_custom_op_symbolic("aten::foo", custom_op_function, opset)

Default ONNX opset is updated to 14 (#83284)

This update is done regularly to stay in sync with ONNX releases. Users can specify opset_version in torch.onnx.export to keep exporting at opset version 13.
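
For example, a minimal sketch of pinning the export opset (the model and input here are placeholders):

import torch

model = MyModel().eval()                      # placeholder model
example_input = torch.randn(1, 3, 224, 224)   # placeholder input

# Explicitly request opset 13 to keep the previous default behavior
torch.onnx.export(model, example_input, "model.onnx", opset_version=13)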

torch.onnx.symbolic_registry is removed (#84382)

We removed the symbolic_registry module and hid it as an internal implementation detail. Users previously relying on the register_op function to register custom symbolic functions should move to use the torch.onnx.register_custom_op_symbolic API.
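
A minimal before/after sketch, assuming an implemented symbolic function `custom_op_function` and a hypothetical aten::foo op (the old call is shown approximately):

import torch

# 1.12.1 (removed): registering through the internal registry
#   torch.onnx.symbolic_registry.register_op("foo", custom_op_function, "aten", 16)

# 1.13: use the public API instead
torch.onnx.register_custom_op_symbolic("aten::foo", custom_op_function, 16)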

ScalarType and global variables in torch.onnx.symbolic_helper are removed (#82995)

The ScalarType class in torch.onnx.symbolic_helper, along with the global variables cast_pytorch_to_onnx, pytorch_name_to_type, scalar_name_to_pytorch, scalar_type_to_onnx and scalar_type_to_pytorch_type are removed from the module. Users previously using these global variables for PyTorch JIT-ONNX type conversion in symbolic functions should move to use the torch.onnx.JitScalarType class.

1.12.1

# 1
torch.onnx.symbolic_helper.scalar_type_to_onnx[
    symbolic_helper.scalar_type_to_pytorch_type.index(x.dtype)
].value

# 2
torch.onnx.symbolic_helper.scalar_name_to_pytorch[element_type] in cast_pytorch_to_onnx.keys()

# 3
torch.onnx.symbolic_helper.cast_pytorch_to_onnx["Long"]

# 4
torch.onnx.symbolic_helper.cast_pytorch_to_onnx[tensor.type().scalarType()]

1.13

# 1
torch.onnx.JitScalarType.from_dtype(x.dtype).onnx_type()

# 2
torch.onnx.JitScalarType.from_name(element_type).onnx_compatible()

# 3
torch.onnx.TensorProtoDataType.INT64

# 4
torch.onnx.JitScalarType.from_name(tensor.type().scalarType()).onnx_type()

Distributed

In c10d collectives, input tensors dtype must now be the same (#84664)

We added a check to validate that all input tensors have the same dtype. Previously, users were allowed to pass in tensors with different dtypes for c10d collectives. Now, passing in tensors with different dtypes will throw a RuntimeError with the following message: “Invalid usage of tensors with different dtypes Found torch.float and torch.half”. Users can use tensor.to(dtype={some_dtype}) to fix this.

1.12.1

# users could pass inputs having different dtypes
>>> tensor = torch.ones(2, 2) * 7
>>> tensor_h = tensor.half()
>>> tensor_list = [torch.zeros(2, 2) for _ in range(4)] # Assume world_size = 4
# Both cases work.
>>> dist.all_gather(tensor_list, tensor)
>>> dist.all_gather(tensor_list, tensor_h)
...

1.13

# all inputs of c10d collectives need to have the same dtype
>>> tensor = torch.ones(2, 2) * 7
>>> tensor_h = tensor.half()
>>> tensor_list = [torch.zeros(2, 2) for _ in range(4)] # Assume world_size = 4
# Only allow same dtype for all input tensors.
>>> dist.all_gather(tensor_list, tensor) # RuntimeError thrown
...

Users doing wildcard imports of torch.distributed.distributed_c10d will no longer get non-public symbols (#84872)

We limit the usage of c10d APIs to public APIs, so if a user does a wildcard import and calls an internal API, it will fail. Please see the example below:

1.12.1

# users could import both public and non-public symbols:
from torch.distributed.distributed_c10d import *
>>> is_nccl_available() # public API
>>> _check_single_tensor(...) # Non-public API
...

1.13

# users can only import public symbols
from torch.distributed.distributed_c10d import *
is_nccl_available() # public API
_check_single_tensor(...) # Non-public API, this will fail now
...

Process Group C++ extensions must use an absolute path when importing ProcessGroup.hpp (#86257), and the ProcessGroup::Work object has been moved out of ProcessGroup into its own Work class (#83680):

Details of the changes and the updated tutorial can be found in the PyTorch tutorial PR #2099

1.12.1

// users use relative path to import C++ headers and Work resides in ProcessGroup class
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
...
class WorkDummy : public ProcessGroup::Work {
    ...
}

1.13

// users must use absolute paths to import C++ headers, and Work is now its own class
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
...
#include <torch/csrc/distributed/c10d/Work.hpp>
class WorkDummy : public Work {
    ...
}

Quantization

Add required example_inputs argument to prepare_fx and prepare_qat_fx (#249) (#77608)

We added an additional required example_inputs argument to the prepare_fx and prepare_qat_fx APIs. This can be used to do type inference to figure out the type information for each fx node in the graph.

1.12.1

m = resnet18(...)
m = prepare_fx(m, qconfig_dict)
# or
m = prepare_qat_fx(m, qconfig_dict)

1.13

m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
# or
m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))

Stop moving models to CPU in quantization convert (#80555)

Previously, we automatically moved the model to CPU in torch.ao.quantization.fx.convert to work around the issue where certain functions called by convert expect CPU arguments. This commit pushes this responsibility to the caller, since it is the user's decision which device to use.

1.12.1

model = resnet18(...)
model = prepare_fx(model, qconfig_mapping, example_inputs)
# calibrate
model = convert_fx(model)

1.13

model = resnet18(...)
model.cpu()  # if needed
model = prepare_fx(model, qconfig_mapping, example_inputs)
# calibrate
model = convert_fx(model)

Replace the is_reference flag of the torch.ao.quantize_fx.convert_fx function with the convert_to_reference function (#80091, #81326)

This PR removes the is_reference flag from the existing convert_fx API and replaces it with a new convert_to_reference function. This separates (1) converting the prepared model to a reference model from (2) lowering the reference model to a quantized model, enabling users to call their custom lowering function for custom backends.

1.12.1

from torch.ao.quantization.quantize_fx import (
    prepare_fx,
    convert_to_reference,
)

prepared = prepare_fx(model, ...)
reference = convert_to_reference(prepared, ...)

1.13

from torch.ao.quantization.quantize_fx import (
    prepare_fx,
    convert_to_reference_fx,
)

prepared = prepare_fx(model, ...)
reference = convert_to_reference_fx(prepared, ...)

Add default configs for fixed qparams ops (#80184)

This commit adds qconfigs with special observers for fixed qparams ops (operators whose corresponding quantized version has fixed quantized parameters for output) like sigmoid in get_default_qconfig_mapping and get_default_qat_qconfig_mapping. For correctness, we also require users to use these special observers if we detect these fixed qparams ops in prepare.

1.12.1 (fails after this PR):

from torch.ao.quantization.quantize_fx import prepare_fx

model = ModelWithFixedQParamsOps()
qconfig_mapping = QConfigMapping()
example_inputs = ...
prepare_fx(model, qconfig_mapping, example_inputs)

1.13

from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx

model = ModelWithFixedQParamsOps()
qconfig_mapping = get_default_qconfig_mapping()
example_inputs = ...
prepare_fx(model, qconfig_mapping, example_inputs)

Replace qconfig_dict with a typed QConfigMapping object (#78452, #79618)

Previously, FX graph mode quantization configurations were specified through a dictionary of qconfigs. However, this API was not in line with other core APIs in PyTorch. This commit replaces this dictionary with a config object that users will create and pass to prepare and convert. This leads to better type safety and better user experience in notebook settings due to improved auto completion.

1.12.1 (deprecated)

from torch.ao.quantization.quantize_fx import prepare_fx

qconfig_dict = {
    "": qconfig,
    "object_type": [
        (torch.nn.Linear, qconfig),
    ],
    "module_name_regex": [
        ("foo.*bar", qconfig),
    ],
    "module_name": [
        ("mod", qconfig),
    ],
}

prepare_fx(model, qconfig_dict)

1.13

from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx

qconfig_mapping = QConfigMapping() \
    .set_global(qconfig) \
    .set_object_type(torch.nn.Linear, qconfig) \
    .set_module_name_regex("foo.*bar", qconfig) \
    .set_module_name("mod", qconfig)

prepare_fx(model, qconfig_mapping)

Replace *custom_config_dict with typed config objects (#79066)

This commit replaces the FX graph mode quantization config dicts with typed Python config objects: prepare_custom_config_dict is replaced by PrepareCustomConfig and convert_custom_config_dict by ConvertCustomConfig. This leads to better type safety and better user experience in notebook settings due to improved auto completion.

1.12.1

from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

prepare_custom_config_dict = {
  "float_to_observed_custom_module_class": {
     "static": {
         FloatClass: ObservedClass
     }
  },
  "non_traceable_module_name": ["mod1", "mod2"],
  "non_traceable_module_class": [class1, class2],
  "input_quantized_idxs": [0, 1],
  "output_quantized_idxs": [0],
  "preserved_attributes": ["attr1", "attr2"],
}

convert_custom_config_dict = {
  "observed_to_quantized_custom_module_class": {
     "static": {
         FloatClass: ObservedClass
     }
  },
  "preserved_attributes": ["attr1", "attr2"],
}

model = prepare_fx(
    model,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config_dict=prepare_custom_config_dict)

model(data)

model = convert_fx(model, convert_custom_config_dict=convert_custom_config_dict)

1.13

from torch.ao.quantization.fx.custom_config import (
    PrepareCustomConfig,
    ConvertCustomConfig,
)
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

prepare_custom_config = PrepareCustomConfig() \
    .set_float_to_observed_mapping(float_class, observed_class) \
    .set_non_traceable_module_names(["mod1", "mod2"]) \
    .set_non_traceable_module_classes([class1, class2]) \
    .set_input_quantized_indexes([0, 1]) \
    .set_output_quantized_indexes([0]) \
    .set_preserved_attributes(["attr1", "attr2"])

convert_custom_config = ConvertCustomConfig() \
    .set_observed_to_quantized_mapping(observed_class, quantized_class) \
    .set_preserved_attributes(["attr1", "attr2"])

model = prepare_fx(
    model,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config=prepare_custom_config)

model(data)

model = convert_fx(model, convert_custom_config=convert_custom_config)

Remove remove_quant_dequant_pairs and fix tests (#84203)

This PR removes some passes in convert_fx and also fixes the way we quantize the layer_norm operator, so the qconfig for the layer_norm op needs to be updated as well.

1.12.1

import torch
from torch.ao.quantization.qconfig_mapping import QConfigMapping, QConfig
from torch.ao.quantization.observer import default_weight_observer
from torch.ao.quantization.backend_config import (
    DTypeConfig,
    ObservationType,
)
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# reuse the activation observer from your existing qconfig; weight uses the default weight observer
qconfig = QConfig(activation=qconfig.activation, weight=default_weight_observer)
qconfig_mapping = QConfigMapping().set_object_type(torch.nn.LayerNorm, qconfig) \
    .set_object_type(torch.nn.functional.layer_norm, qconfig)

# assuming mymodel contains a LayerNorm layer or torch.nn.functional.layer_norm
m = MyModel()
example_inputs = (torch.rand(3, 3),)
m = prepare_fx(m, qconfig_mapping, example_inputs)

1.13

import torch
from torch.ao.quantization.qconfig_mapping import QConfigMapping, QConfig
from torch.ao.quantization.observer import default_placeholder_observer
from torch.ao.quantization.backend_config import (
    DTypeConfig,
    ObservationType,
)
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# reuse the activation observer from your existing qconfig; weight now uses the default placeholder observer
qconfig = QConfig(activation=qconfig.activation, weight=default_placeholder_observer)
qconfig_mapping = QConfigMapping().set_object_type(torch.nn.LayerNorm, qconfig) \
    .set_object_type(torch.nn.functional.layer_norm, qconfig)

# assuming mymodel contains a LayerNorm layer or torch.nn.functional.layer_norm
m = MyModel()
example_inputs = (torch.rand(3, 3),)
m = prepare_fx(m, qconfig_mapping, example_inputs)

Align observer dtype with reference model spec (#85345)

Before this PR, the dtype attribute of observers was not clearly defined. It originally meant interface_dtype in the eager mode workflow, which is how the codebase used it before this PR. In the new reference model spec, the dtype attribute of an observer represents the dtype value that needs to be passed into a quantize function in the reference model spec. This PR aligns the codebase to this definition of dtype.

1.12.1

dynamic_quant_observer = PlaceholderObserver.with_args(
    dtype=torch.float, compute_dtype=torch.quint8)

1.13

dynamic_quant_observer = PlaceholderObserver.with_args(
    dtype=torch.quint8, compute_dtype=torch.quint8)

Composability

Changed the backend C++ kernel representation for some operators that take in lists of tensors (#73350)

If an operator in ATen takes in a list of tensors and is marked as “structured” in native_functions.yaml, then previously TensorList was represented as at::TensorList (i.e. c10::ArrayRef<at::Tensor>). Now, it is represented as a more efficient type: const ITensorListRef&.

1.12.1

at::Tensor cat_kernel(at::TensorList tensors, int64_t dim) {
    ...
}
TORCH_LIBRARY_IMPL(aten, dispatch_key, m) {
    ...
    m.impl("cat", &cat_kernel);
}

1.13

at::Tensor cat_kernel(const at::ITensorListRef& tensors, int64_t dim) {
    ...
}
TORCH_LIBRARY_IMPL(aten, dispatch_key, m) {
    ...
    m.impl("cat", &cat_kernel);
}

C++ API

Lowered randint default dtype to the C++ API (#81410)

Prior to 1.13, the default for the dtype argument of torch.randint (torch.long) was set via a manual Python binding. However, in the C++ API, torch::randint would default to the global default data type, which is usually float. In 1.13 we changed the default for dtype in the C++ API to int64 in order to match the Python API. To reproduce the old behavior, one can set the dtype argument explicitly.

1.12.1

torch::randint(/*low=*/0, /*high=*/10, {2, 3});

1.13

// to replicate the old behavior (the global default dtype, usually float), pass the dtype explicitly
torch::randint(/*low=*/0, /*high=*/10, {2, 3}, torch::kFloat);

Enabled dim=None for torch.{std, var, std_mean, var_mean} (#81845, #82765, #82912)

Prior to 1.13, a C++ API call that has argument types torch::{std, var, std_mean, var_mean}(Tensor, OptionalIntArrayRef, int64_t, bool) used to resolve to the {std, var, std_mean, var_mean}.correction overload. In this release, it resolves to the {std, var, std_mean, var_mean}.dim overload. With the .correction overload, the third argument of type int64_t could be used to pass a correction δN other than 1. In order to call the {std, var, std_mean, var_mean}.correction overload in 1.13, the old int64_t argument can be wrapped in a c10::optional.

1.12.1

// using std as an example
int64_t correction = 2;
torch::std(t, /*dim=*/dim, /*correction=*/correction, /*keepdim=*/true);

1.13

// To replicate in 1.13 using std as an example
auto correction = c10::make_optional<int64_t>(2);
torch::std(t, /*dim=*/dim, /*correction=*/correction, /*keepdim=*/true);

Deprecations

Distributed

We are deprecating the following APIs of c10d: *_coalesced APIs (#85959), *_multigpu APIs (#85961) and ProcessGroupRoundRobin (#85158)

We added warnings when users call c10d’s *_coalesced, *_multigpu and ProcessGroupRoundRobin APIs. Previously, users could use these APIs without any warnings, but now they will see warnings like “torch.distributed.all_reduce_coalesced will be deprecated. If you must use it, please revisit our documentation later at https://pytorch.org/docs/master/distributed.html#collective-functions”. There are still workarounds for the *_coalesced APIs, but no workarounds will be provided for the other two.

1.12.1

# users could use the following APIs with no warnings:
all_reduce_coalesced(...)
all_gather_coalesced(...)
broadcast_multigpu(...)
all_reduce_multigpu(...)
reduce_multigpu(...)
all_gather_multigpu(...)
reduce_scatter_multigpu(...)
...

1.13

# users can still use these APIs but it will come with warnings:
all_reduce_coalesced(...)
# Warnings:
# torch.distributed.all_reduce_coalesced will be deprecated. If you must
# use it, please revisit our documentation later at
# https://pytorch.org/docs/master/distributed.html#collective-functions"

# Potential workaround:
reqs = []
with dist._coalescing_manager(group, reqs):
    reqs.append(dist.all_reduce(tensor1, async_op=True))
    reqs.append(dist.all_reduce(tensor2, async_op=True))
for req in reqs:
    req.wait()
...

We are deprecating passing optim_input into the FSDP optimizer state checkpointing APIs. The user can simply not pass the optim_input argument, and all behavior is preserved; no fix is needed on the user's side for now (see the sketch after the examples below).

1.12.1

# the user can use the following APIs with no warnings
full_optim_state_dict(...)
sharded_optim_state_dict(...)
shard_full_optim_state_dict(...)
flatten_sharded_optim_state_dict(...)
scatter_full_optim_state_dict(...)
rekey_optim_state_dict(...)

1.13

# users can still use these APIs, but they will come with warnings
# The `optim_input` argument is deprecated and will be removed after PyTorch 1.13.
# You may remove it from your code without changing its functionality.
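
As a minimal sketch of the recommended call (model and optim are placeholders for an FSDP-wrapped module and its optimizer), simply omit optim_input:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# model is an FSDP-wrapped module and optim its optimizer (placeholders)
full_osd = FSDP.full_optim_state_dict(model, optim)  # no optim_input argument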

LinAlg

Deprecate torch.lu in favor of linalg.lu_factor (#77636)

The new operation has a cleaner API and better docs. The update rule is as follows:

1.12.1

LU2, pivots2, info = torch.lu(A, compute_pivots, get_infos=True)
LU1, pivots1 = torch.lu(A, compute_pivots)

1.13

LU2, pivots2, info = torch.linalg.lu_factor_ex(A, compute_pivots)
LU1, pivots1 = torch.linalg.lu_factor(A, compute_pivots)

Deprecate torch.lu_solve in favor of linalg.lu_solve (#77637)

The new operation has a notation consistent with linalg.solve, and has an extra parameter adjoint=False. The update rule is as follows:

1.12.1

X = torch.lu_solve(B, LU, pivots)

1.13

X = torch.linalg.lu_solve(LU, pivots, B)

ONNX

Monkey-patched convenience methods on torch._C.Graph, torch._C.Block and torch._C.Node are deprecated (#83006)

Deprecated methods include Graph.op(), Graph.constant(), Graph.at(), Block.op(), and Node.__getitem__(). Previously, these methods were patched into the classes above when users call torch.onnx.export() and are typically used in custom symbolic functions. Users can continue to expect g.op() and g.at() in symbolic functions to work. The g parameter has been substituted by the GraphContext object (#84728), which exposes these methods with APIs unchanged. Users should not rely on the Graph.op(), Graph.constant(), Graph.at(), Block.op(), or Node.__getitem__() methods when directly interacting with the C classes, and should use only the op() and at() methods of the GraphContext object, as other fields in the class will change in future releases.
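
A minimal sketch of a custom symbolic function under the new scheme (the choice of op is illustrative): the first parameter is now a GraphContext rather than a raw Graph, but g.op() is used exactly as before.

import torch

def relu_symbolic(g, input):
    # g is a GraphContext here; g.op() emits an ONNX node as before
    return g.op("Relu", input)

# register at the exact opset version used for export
torch.onnx.register_custom_op_symbolic("aten::relu", relu_symbolic, 14)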

New features

Python API

Build

Complex

torch.nn

torch.optim

BetterTransformer

ForEach

LinAlg

Sparse

torch.fx

JIT

ONNX

AMD

CUDA

Intel

MPS

Profiler

Vulkan

Mobile

Distributed

Distributed Checkpointing (Prototyping)

Distributed(c10d)

DistributedDataParallel

FullyShardedDataParallel

torch.distributed.elastic

Activation Memory Management (Prototyping)

Infra (RelEng)

Improvements

Python API

C++ API

Autograd

Build

torch.nn

torch.optim

Composability

Dataloader

Functorch

LinAlg

Sparse

torch.fx

JIT

Quantization

ONNX

AMD

CUDA

Intel

MPS

Profiler

Vulkan

Mobile

Distributed

Distributed(c10d)

Distributed Optimizer

DistributedDataParallel

FullyShardedDataParallel

torch.distributed.elastic

Infra (RelEng)

Bug fixes

Python API

C++ API

Autograd

Build

Complex

torch.nn

torch.optim

BetterTransformer

Composability

Dataloader

Functorch

LinAlg

Sparse

torch.fx

JIT

Quantization

ONNX

AMD

CUDA

Intel

MPS

Package

Profiler

Visualization

Vulkan

Mobile

Distributed

Distributed(c10d)

DistributedDataParallel

FullyShardedDataParallel

torch.distributed.elastic

torch.distributed.rpc

Infra (RelEng)

Performance

Python API

Autograd

torch.nn

BetterTransformer

Composability

Dataloader

LinAlg

Sparse

JIT

Quantization

CUDA

Intel

MPS

Profiler

Vulkan

Mobile

Documentation

Python API

Autograd

Complex

torch.nn

torch.optim

Composability

Functorch

LinAlg

Sparse

torch.fx

Quantization

ONNX

CUDA

MPS

Package

Distributed

Distributed(c10d)

DistributedDataParallel

FullyShardedDataParallel

torch.distributed.rpc

Infra (RelEng)

Developers

Autograd

Build

Composability

torch.fx

Quantization

ONNX

Intel

MPS

Profiler

Vulkan

Distributed

torch.distributed

torch.distributed.elastic

Distributed(c10d)

Infra (RelEng)

Release asset: pytorch-v1.13.0.tar.gz (223.32 MB)