
0.4.3

kyegomez/LongNet

Release date: 2023-08-11 02:30:43

Latest kyegomez/LongNet release: 0.4.8 (2023-08-11 03:04:14)

Changelog:

  1. Tensor Shape Adjustments:

    • Ensured the consistent shape of tensors across all operations.

    • Squeezed a_indices to 2D to match dimensions of att_denom_sums.

      a_indices = a_indices[:, :, 0].squeeze(-1).squeeze(-1)
      
    • Sliced a_indices to the unpadded sequence length before scattering.

      a_indices = a_indices[:, :unpadded_seq_len]
      
  2. Scatter and Gather Operations:

    • Scattered with the squeezed 2D a_indices, then gathered the sparse sums with the same indices.

      att_denom_sums.scatter_add_(1, a_indices, a_denoms)
      sparse_att_denom_sum = torch.gather(att_denom_sums, 1, a_indices)
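
The shape adjustments and the scatter/gather step above can be tried on toy data. This is a minimal sketch, not the LongNet code itself: the tensor sizes, the `positions` values, and the 4D layout of `a_indices` are assumptions chosen only so the shapes line up. With this layout a single `squeeze(-1)` is enough.

```python
import torch

batch, seq_len = 2, 8
padded_len, unpadded_len = 6, 4  # hypothetical padded / unpadded sparse lengths

# Hypothetical sparse positions; the last two entries stand in for padding.
positions = torch.tensor([0, 2, 4, 6, 0, 0])  # int64 by default
# Assumed raw layout: (batch, padded_len, 3, 1)
a_indices = positions.view(1, padded_len, 1, 1).repeat(batch, 1, 3, 1)

# Reduce to 2D so the index matches att_denom_sums: (batch, padded_len)
a_indices = a_indices[:, :, 0].squeeze(-1)
# Drop the padding entries; otherwise they would add extra mass to position 0.
a_indices = a_indices[:, :unpadded_len]

a_denoms = torch.ones(batch, unpadded_len)      # per-position softmax denominators
att_denom_sums = torch.zeros(batch, seq_len)    # accumulator over full positions

# Accumulate the denominators at their original positions, then read them back.
att_denom_sums.scatter_add_(1, a_indices, a_denoms)
sparse_att_denom_sum = torch.gather(att_denom_sums, 1, a_indices)

print(a_indices.shape, att_denom_sums[0], sparse_att_denom_sum.shape)
```

Because `scatter_add_` requires the index to have the same number of dimensions as the target, leaving `a_indices` in its 4D form would raise an error here.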
      
  3. DataType Handling:

    • Converted the 'sparse indices' tensors to torch.int64 (torch.long is an alias for the same dtype) to ensure compatibility with PyTorch's indexing operations.
    • Retained the torch.float16 dtype for the 'X' tensor to make it memory-efficient.
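
The dtype handling above can be sketched as follows. The tensor names, sizes, and the starting int32 dtype are assumptions for illustration; the point is only that `torch.gather` expects an int64 index while the data tensor can stay in half precision.

```python
import torch

# Hypothetical 'X' tensor kept in float16: (batch, seq_len, dim)
x = torch.randn(2, 8, 16).to(torch.float16)

# Hypothetical sparse indices that arrive as int32.
sparse_indices = torch.tensor([[0, 2, 4, 6], [1, 3, 5, 7]], dtype=torch.int32)

# gather/scatter need int64 indices, so convert only the index tensor.
sparse_indices = sparse_indices.long()

# Select the sparse positions along the sequence axis; x keeps its dtype.
index = sparse_indices.unsqueeze(-1).expand(-1, -1, x.size(-1))
x_sparse = torch.gather(x, 1, index)

print(sparse_indices.dtype, x_sparse.dtype, x_sparse.shape)
```
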
  4. Code Cleaning:

    • Removed repeated lines that print the shape and datatype of "sparse indices" to declutter the code.
    • Standardized debug print statements to have a consistent format.
    • Printed the shapes of tensors before scattering to verify that the dimensions match.
    • Added comments explaining dimension squeezing, slicing, and other adjustments for clarity.
  5. Validation Checks:

    • Added checks to ensure tensors are on the same device (either all on CPU or all on CUDA).
    • Checked whether the size of the tensor 'X' matches the expected shape before operations.
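
Checks like the two above could look roughly like this. `check_inputs` is a hypothetical helper, not a function from the LongNet repository, and the error wording is made up for the example.

```python
import torch

def check_inputs(x, expected_shape, *others):
    """Hypothetical helper: validate the shape of x and device agreement."""
    if tuple(x.shape) != tuple(expected_shape):
        raise ValueError(
            f"X has shape {tuple(x.shape)}, expected {tuple(expected_shape)}"
        )
    # Collect the devices of all tensors; more than one means a CPU/CUDA mix.
    devices = {str(t.device) for t in (x, *others)}
    if len(devices) > 1:
        raise ValueError(f"tensors are on different devices: {sorted(devices)}")

x = torch.zeros(2, 8, 16)
a_indices = torch.zeros(2, 4, dtype=torch.long)
check_inputs(x, (2, 8, 16), a_indices)  # passes silently
```

Failing early with a message that names the offending shape or devices is usually easier to debug than the error raised later inside a scatter or gather call.
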
  6. Enhanced Error Messages:

    • Improved the debug error messages to be more descriptive.
  7. Optimizations:

    • Removed unnecessary tensor operations that don't contribute to the final result.
    • Optimized tensor slicing and indexing operations to be more memory efficient.
  8. Edge Case Handling:

    • Handled the edge case of negative head_idx.
  9. Other Minor Fixes:

    • Ensured that the code uses math or memory-efficient attention only if the input tensor is on CUDA and a non-A100 GPU is detected.
    • Made sure tensor operations are consistent with PyTorch best practices.
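
The backend rule in item 9 can be written as a small decision function. This is a sketch under assumptions: the changelog only states the non-A100 CUDA case, so the A100 branch (flash attention) and the CPU branch (returning None) are guesses, and `SDPConfig` and `pick_attention_config` are hypothetical names. The A100 is identified by CUDA compute capability 8.0.

```python
from collections import namedtuple

# Hypothetical container for the three scaled-dot-product backends.
SDPConfig = namedtuple(
    "SDPConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
)

def pick_attention_config(is_cuda, compute_capability):
    """Return a backend config, or None when the tensor is not on CUDA."""
    if not is_cuda:
        return None  # assumption: CPU inputs use the default code path
    if tuple(compute_capability) == (8, 0):
        # A100: assumed to use flash attention (not stated in the changelog)
        return SDPConfig(True, False, False)
    # Non-A100 CUDA device: math or memory-efficient attention, per item 9
    return SDPConfig(False, True, True)

print(pick_attention_config(True, (7, 5)))
```

In real code the capability tuple could come from `torch.cuda.get_device_capability()`, which returns `(major, minor)` for the current device.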
  10. Documentation:
