v2.4.0
Release date: 2024-02-23 21:12:22
This release introduces numerous notable features that are well worth learning about!
Install this version with
pip install sentence-transformers==2.4.0
MatryoshkaLoss (#2485)
Dense embedding models typically produce embeddings with a fixed size, such as 768 or 1024. All further computations (clustering, classification, semantic search, retrieval, reranking, etc.) must then be done on these full embeddings. Matryoshka Representation Learning revisits this idea, and proposes a solution to train embedding models whose embeddings are still useful after truncation to much smaller sizes. This allows for considerably faster (bulk) processing.
Training
Training using Matryoshka Representation Learning (MRL) is quite elementary: rather than applying some loss function on only the full-size embeddings, we also apply that same loss function on truncated portions of the embeddings. For example, if a model has an embedding dimension of 768 by default, it can now be trained on 768, 512, 256, 128, 64 and 32. Each of these losses will be added together, optionally with some weight:
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, MatryoshkaLoss

model = SentenceTransformer("microsoft/mpnet-base")
base_loss = CoSENTLoss(model=model)
# Apply the base loss at each of these embedding dimensionalities and sum the results
loss = MatryoshkaLoss(model=model, loss=base_loss, matryoshka_dims=[768, 512, 256, 128, 64])
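To actually train with it, the wrapped loss plugs into the usual v2.x fit loop just like any other loss. Below is a minimal, hedged sketch continuing the snippet above; the toy sentence pairs are purely illustrative:

from torch.utils.data import DataLoader
from sentence_transformers import InputExample

# Toy similarity pairs, just to illustrate the training call
train_examples = [
    InputExample(texts=["A cat sits on the mat", "A feline rests on a rug"], label=0.9),
    InputExample(texts=["A cat sits on the mat", "The stock market fell today"], label=0.1),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)

# The MatryoshkaLoss defined above is passed in place of the base loss
model.fit(train_objectives=[(train_dataloader, loss)], epochs=1, warmup_steps=10)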
- Reference: MatryoshkaLoss
Inference
After a model has been trained using a Matryoshka loss, you can run inference with it using SentenceTransformer.encode. You can then truncate the resulting embeddings to a smaller dimension, and it is recommended to renormalize them.
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
matryoshka_dim = 64
embeddings = model.encode(
[
"search_query: What is TSNE?",
"search_document: t-distributed stochastic neighbor embedding (t-SNE) is a statistical method for visualizing high-dimensional data by giving each datapoint a location in a two or three-dimensional map.",
"search_document: Amelia Mary Earhart was an American aviation pioneer and writer.",
]
)
embeddings = embeddings[..., :matryoshka_dim] # Shrink the embedding dimensions
similarities = cos_sim(embeddings[0], embeddings[1:])
# => tensor([[0.7839, 0.4933]])
As you can see, the similarity between the search query and the correct document is much higher than that of an unrelated document, despite the very small Matryoshka dimension applied. Feel free to copy this script locally, modify matryoshka_dim, and observe the difference in similarities.
Note: Despite the embeddings being smaller, training and inference with a Matryoshka model are not faster, not more memory-efficient, and the model itself is not smaller. Only the processing and storage of the resulting embeddings will be faster and cheaper.
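To make the storage point concrete, here is a small, purely illustrative sketch (the corpus size and dimensions are assumptions): truncating 100k float32 embeddings from 768 to 64 dimensions shrinks them by 12x.

import numpy as np

# Illustrative only: 100k float32 embeddings at full vs. truncated dimensionality
full = np.random.rand(100_000, 768).astype(np.float32)
truncated = full[:, :64].copy()  # copy so only the truncated values are stored

print(full.nbytes / 1e6)       # ~307 MB
print(truncated.nbytes / 1e6)  # ~25.6 MB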
Extra information:
Example training scripts:
CoSENTLoss (#2454)
CoSENTLoss was introduced by Jianlin Su (2022) as a drop-in replacement for CosineSimilarityLoss. Experiments have shown that it produces a stronger learning signal than CosineSimilarityLoss.
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample

model = SentenceTransformer('bert-base-uncased')
train_batch_size = 16  # example batch size

train_examples = [
    InputExample(texts=['My first sentence', 'My second sentence'], label=1.0),
    InputExample(texts=['My third sentence', 'Unrelated sentence'], label=0.3),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CoSENTLoss(model=model)
You can update training_stsbenchmark.py by replacing CosineSimilarityLoss with CoSENTLoss and observe the improved performance.
AnglELoss (#2471)
AnglELoss was introduced in Li and Li (2023). It is an adaptation of CoSENTLoss, and it also acts as a strong drop-in replacement for CosineSimilarityLoss. Compared to CoSENTLoss, AnglELoss uses a different similarity function which aims to avoid vanishing gradients. Like CoSENTLoss, you can use it just like CosineSimilarityLoss.
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample

model = SentenceTransformer('bert-base-uncased')
train_batch_size = 16  # example batch size

train_examples = [
    InputExample(texts=['My first sentence', 'My second sentence'], label=1.0),
    InputExample(texts=['My third sentence', 'Unrelated sentence'], label=0.3),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.AnglELoss(model=model)
You can update training_stsbenchmark.py by replacing CosineSimilarityLoss with AnglELoss and observe the improved performance.
Prompt Templates (#2477)
Some models require using specific text prompts to achieve optimal performance. For example, with intfloat/multilingual-e5-large you should prefix all queries with "query: " and all passages with "passage: ". Another example is BAAI/bge-large-en-v1.5, which performs best for retrieval when the input texts are prefixed with "Represent this sentence for searching relevant passages: ".
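Without prompt templates, you would prepend these prefixes to the input texts yourself. A minimal, hedged sketch assuming intfloat/multilingual-e5-large and illustrative texts:

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

model = SentenceTransformer("intfloat/multilingual-e5-large")

# Manually prepend the prefixes this model was trained with
query_embedding = model.encode("query: how do yams store food?")
passage_embeddings = model.encode([
    "passage: Yams are perennial vines cultivated for their starchy tubers.",
    "passage: Amelia Earhart was an American aviation pioneer and writer.",
])
print(cos_sim(query_embedding, passage_embeddings))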
Sentence Transformer models can now be initialized with prompts and default_prompt_name parameters:
- prompts is an optional argument that accepts a dictionary mapping prompt names to prompt texts. The prompt will be prepended to the input text during inference. For example:
model = SentenceTransformer(
    "intfloat/multilingual-e5-large",
    prompts={
        "classification": "Classify the following text: ",
        "retrieval": "Retrieve semantically similar text: ",
        "clustering": "Identify the topic or theme based on the text: ",
    },
)
# or
model.prompts = {
    "classification": "Classify the following text: ",
    "retrieval": "Retrieve semantically similar text: ",
    "clustering": "Identify the topic or theme based on the text: ",
}
- default_prompt_name is an optional argument that determines the default prompt to be used. It has to correspond with a prompt name from prompts. If None, then no prompt is used by default. For example:
model = SentenceTransformer(
    "intfloat/multilingual-e5-large",
    prompts={
        "classification": "Classify the following text: ",
        "retrieval": "Retrieve semantically similar text: ",
        "clustering": "Identify the topic or theme based on the text: ",
    },
    default_prompt_name="retrieval",
)
# or
model.default_prompt_name = "retrieval"
Both of these parameters can also be specified in the config_sentence_transformers.json file of a saved model. That way, you won't have to specify these options manually when loading. When you save a Sentence Transformer model, these options will be automatically saved as well.
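As a quick illustration, here is a hedged sketch of that round trip; the local save path and the single prompt are placeholders:

from sentence_transformers import SentenceTransformer

# Configure prompts at initialization (placeholder prompt text)
model = SentenceTransformer(
    "intfloat/multilingual-e5-large",
    prompts={"retrieval": "Retrieve semantically similar text: "},
    default_prompt_name="retrieval",
)

# Saving also writes the prompt options into config_sentence_transformers.json
model.save("my-e5-with-prompts")

# Reloading picks the prompts up again; no need to pass them manually
reloaded = SentenceTransformer("my-e5-with-prompts")
print(reloaded.prompts, reloaded.default_prompt_name)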
During inference, prompts can be applied in a few different ways. All of these scenarios result in identical texts being embedded:
- Explicitly using the prompt option in SentenceTransformer.encode:
embeddings = model.encode("How to bake a strawberry cake", prompt="Retrieve semantically similar text: ")
- Explicitly using the prompt_name option in SentenceTransformer.encode, relying on the prompts loaded from a) initialization or b) the model config:
embeddings = model.encode("How to bake a strawberry cake", prompt_name="retrieval")
- If neither prompt nor prompt_name is specified in SentenceTransformer.encode, then the prompt specified by default_prompt_name will be applied. If it is None, then no prompt will be applied:
embeddings = model.encode("How to bake a strawberry cake")
Instructor support (#2477)
Some INSTRUCTOR models, such as hkunlp/instructor-large, are natively supported in Sentence Transformers. These models are special, as they are trained with instructions in mind. Notably, the primary difference between normal Sentence Transformer models and Instructor models is that the latter do not include the instructions themselves in the pooling step.
The following models work out of the box:
- hkunlp/instructor-base
- hkunlp/instructor-large
- hkunlp/instructor-xl
You can use these models like so:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("hkunlp/instructor-large")
embeddings = model.encode(
[
"Dynamical Scalar Degree of Freedom in Horava-Lifshitz Gravity",
"Comparison of Atmospheric Neutrino Flux Calculations at Low Energies",
"Fermion Bags in the Massive Gross-Neveu Model",
"QCD corrections to Associated t-tbar-H production at the Tevatron",
],
prompt="Represent the Medicine sentence for clustering: ",
)
print(embeddings.shape)
# => (4, 768)
Information Retrieval usage
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
model = SentenceTransformer("hkunlp/instructor-large")
query = "where is the food stored in a yam plant"
query_instruction = (
"Represent the Wikipedia question for retrieving supporting documents: "
)
corpus = [
'Yams are perennial herbaceous vines native to Africa, Asia, and the Americas and cultivated for the consumption of their starchy tubers in many temperate and tropical regions. The tubers themselves, also called "yams", come in a variety of forms owing to numerous cultivars and related species.',
"The disparate impact theory is especially controversial under the Fair Housing Act because the Act regulates many activities relating to housing, insurance, and mortgage loans—and some scholars have argued that the theory's use under the Fair Housing Act, combined with extensions of the Community Reinvestment Act, contributed to rise of sub-prime lending and the crash of the U.S. housing market and ensuing global economic recession",
"Disparate impact in United States labor law refers to practices in employment, housing, and other areas that adversely affect one group of people of a protected characteristic more than another, even though rules applied by employers or landlords are formally neutral. Although the protected classes vary by statute, most federal civil rights laws protect based on race, color, religion, national origin, and sex as protected traits, and some laws include disability status and other traits as well.",
]
corpus_instruction = "Represent the Wikipedia document for retrieval: "
query_embedding = model.encode(query, prompt=query_instruction)
corpus_embeddings = model.encode(corpus, prompt=corpus_instruction)
similarities = cos_sim(query_embedding, corpus_embeddings)
print(similarities)
# => tensor([[0.8835, 0.7037, 0.6970]])
All other Instructor models either 1) will not load, as they refer to InstructorEmbedding in their modules.json, or 2) require calling model.set_pooling_include_prompt(include_prompt=False) after loading.
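For case 2, a minimal hedged sketch; the model name below is a hypothetical placeholder for an Instructor-style checkpoint that loads but still needs the pooling adjustment:

from sentence_transformers import SentenceTransformer

# Hypothetical Instructor-style checkpoint; substitute the actual model you want to use
model = SentenceTransformer("some-org/other-instructor-model")

# Exclude the instruction/prompt tokens from pooling, as Instructor models expect
model.set_pooling_include_prompt(include_prompt=False)

embeddings = model.encode(
    "where is the food stored in a yam plant",
    prompt="Represent the Wikipedia question for retrieving supporting documents: ",
)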
Reduce Dependencies (#2476)
Sentence Transformers no longer uses sentencepiece or nltk as mandatory dependencies. This should make Sentence Transformers 1) lighter to install, 2) quicker to import, and 3) less likely to result in dependency issues.
New Loss Overview documentation (#2447, #2496)
The documentation has been upgraded with a Loss Overview: https://sbert.net/docs/training/loss_overview.html. This section contains tables with loss functions and their required data formats, which should help you narrow down which loss functions might suit your use cases.
Additionally, each loss function now has extended documentation, including references, requirements, relations to other loss functions, inputs, a code snippet with example usage, and/or links to other documentation/examples using that loss function.
All changes
- [hotfix] Don't require loading files for Normalize by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2460
- Created CoSENTLoss.py by @johneckberg in https://github.com/UKPLab/sentence-transformers/pull/2454
- [fix] Avoid sets and dicts in BinaryClassificationEvaluator when sentences are not hashable by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2462
- [docs] Refactor & improve the loss documentation by @ir2718 in https://github.com/UKPLab/sentence-transformers/pull/2447
- CrossEncoder device by @milistu in https://github.com/UKPLab/sentence-transformers/pull/2463
- Update README.md by @akshaydevml in https://github.com/UKPLab/sentence-transformers/pull/2464
- Add support for Ascend NPU by @statelesshz in https://github.com/UKPLab/sentence-transformers/pull/2466
- Add revision parameter to CrossEncoder. by @fkdosilovic in https://github.com/UKPLab/sentence-transformers/pull/2479
- Add NDCG metric to CERerankingEvaluator by @milistu in https://github.com/UKPLab/sentence-transformers/pull/2478
- [deps] Remove the sentencepiece & nltk dependencies by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2476
- AnglE loss by @johneckberg in https://github.com/UKPLab/sentence-transformers/pull/2471
- [ci] On Ubuntu CI runner, use temporary directories as cache folders for some models by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2481
- [docs] Slight improvements to docs phrasing by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2486
- Ensure dtype consistency in Pooling forward method by @EliasKassapis in https://github.com/UKPLab/sentence-transformers/pull/2492
- [feat] Add Matryoshka loss + examples + docs by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2485
- [feat] Add prompt templates by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2477
- [docs] Move loss overview to "main" documentation by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2496
- [feat] Allow saving a model to the Hub without providing a user + Upload Matryoshka models after training by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2497
- Add F1 score evaluator for CrossEncoder. by @fkdosilovic in https://github.com/UKPLab/sentence-transformers/pull/2493
- [docs] Address some small mistakes by @tomaarsen in https://github.com/UKPLab/sentence-transformers/pull/2498
New Contributors
- @johneckberg made their first contribution in https://github.com/UKPLab/sentence-transformers/pull/2454
- @ir2718 made their first contribution in https://github.com/UKPLab/sentence-transformers/pull/2447
- @akshaydevml made their first contribution in https://github.com/UKPLab/sentence-transformers/pull/2464
- @fkdosilovic made their first contribution in https://github.com/UKPLab/sentence-transformers/pull/2479
- @EliasKassapis made their first contribution in https://github.com/UKPLab/sentence-transformers/pull/2492
I especially want to thank @ir2718, @johneckberg & @SeanLee97 for their valuable contributions in this release, and @fkdosilovic and @milistu for their valuable improvements to the CrossEncoder.
Full Changelog: https://github.com/UKPLab/sentence-transformers/compare/v2.3.1...v2.4.0