Merge branch 'main' into v1.0-release
tomaarsen committed Jan 11, 2024
2 parents f4d9f5e + f387387 commit 9729aa0
Showing 10 changed files with 287 additions and 28 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -1,7 +1,7 @@
<img src="https://raw.githubusercontent.com/huggingface/setfit/main/assets/setfit.png">

<p align="center">
🤗 <a href="https://huggingface.co/setfit" target="_blank">Models & Datasets</a> | 📕 <a href="https://huggingface.co/docs/setfit" target="_blank">Documentation</a> | 📖 <a href="https://huggingface.co/blog/setfit" target="_blank">Blog</a> | 📃 <a href="https://arxiv.org/abs/2209.11055" target="_blank">Paper</a>
🤗 <a href="https://huggingface.co/models?library=setfit" target="_blank">Models</a> | 📊 <a href="https://huggingface.co/setfit" target="_blank">Datasets</a> | 📕 <a href="https://huggingface.co/docs/setfit" target="_blank">Documentation</a> | 📖 <a href="https://huggingface.co/blog/setfit" target="_blank">Blog</a> | 📃 <a href="https://arxiv.org/abs/2209.11055" target="_blank">Paper</a>
</p>

# SetFit - Efficient Few-shot Learning with Sentence Transformers
@@ -61,7 +61,10 @@ eval_dataset = dataset["validation"].select(range(100))
test_dataset = dataset["validation"].select(range(100, len(dataset["validation"])))

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-mpnet-base-v2",
labels=["negative", "positive"],
)

args = TrainingArguments(
batch_size=16,
@@ -94,7 +97,7 @@ model = SetFitModel.from_pretrained("tomaarsen/setfit-paraphrase-mpnet-base-v2-s
# Run inference
preds = model.predict(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
print(preds)
# tensor([1, 0], dtype=torch.int32)
# ["positive", "negative"]
```


3 changes: 2 additions & 1 deletion setup.py
@@ -16,6 +16,7 @@
"evaluate>=0.3.0",
"huggingface_hub>=0.13.0",
"scikit-learn",
"packaging",
]
ABSA_REQUIRE = ["spacy"]
QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"]
@@ -53,7 +54,7 @@ def combine_requirements(base_keys):

setup(
name="setfit",
version="1.0.1",
version="1.0.2",
description="Efficient few-shot learning with Sentence Transformers",
long_description=README_TEXT,
long_description_content_type="text/markdown",
2 changes: 1 addition & 1 deletion src/setfit/__init__.py
@@ -1,4 +1,4 @@
__version__ = "1.0.1"
__version__ = "1.0.2"

import importlib
import os
57 changes: 51 additions & 6 deletions src/setfit/modeling.py
Expand Up @@ -19,7 +19,10 @@
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from sentence_transformers import SentenceTransformer, models
from packaging.version import Version, parse
from sentence_transformers import SentenceTransformer
from sentence_transformers import __version__ as sentence_transformers_version
from sentence_transformers import models
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
@@ -215,6 +218,7 @@ class SetFitModel(PyTorchModelHubMixin):
normalize_embeddings: bool = False
labels: Optional[List[str]] = None
model_card_data: Optional[SetFitModelCardData] = field(default_factory=SetFitModelCardData)
sentence_transformers_kwargs: Dict = field(default_factory=dict, repr=False)

attributes_to_save: Set[str] = field(
init=False, repr=False, default_factory=lambda: {"normalize_embeddings", "labels"}
@@ -501,6 +505,11 @@ def predict_proba(
inputs = [inputs]
embeddings = self.encode(inputs, batch_size=batch_size, show_progress_bar=show_progress_bar)
probs = self.model_head.predict_proba(embeddings)
if isinstance(probs, list):
if self.has_differentiable_head:
probs = torch.stack(probs, axis=1)
else:
probs = np.stack(probs, axis=1)
outputs = self._output_type_conversion(probs, as_numpy=as_numpy)
return outputs[0] if is_singular else outputs
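The stacking added to `predict_proba` above handles multi-output heads (e.g. scikit-learn's `MultiOutputClassifier`, imported earlier in this file), whose `predict_proba` returns one probability array per label column rather than a single array. A minimal standalone sketch of that situation, with made-up data that is not part of this commit:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

rng = np.random.RandomState(0)
X = rng.rand(32, 8)                  # toy features
y = rng.randint(0, 2, size=(32, 3))  # three binary label columns

clf = MultiOutputClassifier(LogisticRegression()).fit(X, y)

probs = clf.predict_proba(X[:4])   # list of 3 arrays, one per label, each of shape (4, 2)
stacked = np.stack(probs, axis=1)  # single array of shape (4, 3, 2)
print(len(probs), stacked.shape)
```

With a differentiable (torch) head, the same idea applies with `torch.stack`.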

@@ -600,6 +609,9 @@ def device(self) -> torch.device:
Returns:
torch.device: The device that the model is on.
"""
# SentenceTransformer.device is reliable from 2.3.0 onwards
if parse(sentence_transformers_version) >= Version("2.3.0"):
return self.model_body.device
return self.model_body._target_device

def to(self, device: Union[str, torch.device]) -> "SetFitModel":
@@ -617,9 +629,10 @@ def to(self, device: Union[str, torch.device]) -> "SetFitModel":
Returns:
SetFitModel: Returns the original model, but now on the desired device.
"""
# Note that we must also set _target_device, or any SentenceTransformer.fit() call will reset
# the body location
self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
# Note that we must also set _target_device with sentence-transformers <2.3.0,
# or any SentenceTransformer.fit() call will reset the body location
if parse(sentence_transformers_version) < Version("2.3.0"):
self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
self.model_body = self.model_body.to(device)

if self.has_differentiable_head:
@@ -696,10 +709,37 @@ def _from_pretrained(
multi_target_strategy: Optional[str] = None,
use_differentiable_head: bool = False,
device: Optional[Union[torch.device, str]] = None,
trust_remote_code: bool = False,
**model_kwargs,
) -> "SetFitModel":
model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token, device=device)
device = model_body._target_device
sentence_transformers_kwargs = {
"cache_folder": cache_dir,
"use_auth_token": token,
"device": device,
"trust_remote_code": trust_remote_code,
}
if parse(sentence_transformers_version) >= Version("2.3.0"):
sentence_transformers_kwargs = {
"cache_folder": cache_dir,
"token": token,
"device": device,
"trust_remote_code": trust_remote_code,
}
else:
if trust_remote_code:
raise ValueError(
"The `trust_remote_code` argument is only supported for `sentence-transformers` >= 2.3.0."
)
sentence_transformers_kwargs = {
"cache_folder": cache_dir,
"use_auth_token": token,
"device": device,
}
model_body = SentenceTransformer(model_id, **sentence_transformers_kwargs)
if parse(sentence_transformers_version) >= Version("2.3.0"):
device = model_body.device
else:
device = model_body._target_device
model_body.to(device) # put `model_body` on the target device

# Try to load a SetFit config file
@@ -822,6 +862,7 @@ def _from_pretrained(
model_head=model_head,
multi_target_strategy=multi_target_strategy,
model_card_data=model_card_data,
sentence_transformers_kwargs=sentence_transformers_kwargs,
**model_kwargs,
)

@@ -846,6 +887,10 @@ def _from_pretrained(
Whether to apply normalization on the embeddings produced by the Sentence Transformer body.
device (`Union[torch.device, str]`, *optional*):
The device on which to load the SetFit model, e.g. `"cuda:0"`, `"mps"` or `torch.device("cuda")`.
trust_remote_code (`bool`, defaults to `False`): Whether or not to allow custom Sentence Transformers
models defined on the Hub in their own modeling files. This option should only be set to True for
repositories you trust and in which you have read the code, as it will execute code present on
the Hub on your local machine.
Example::
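The `Example::` block above is collapsed in this view. Purely as an illustrative sketch (not the collapsed docstring), the new `trust_remote_code` flag is forwarded to `SentenceTransformer`, so it requires sentence-transformers >= 2.3.0; the same `packaging` pattern used throughout this file can gate it. The model id below is a placeholder:

```python
from packaging.version import Version, parse
from sentence_transformers import __version__ as st_version
from setfit import SetFitModel

model_id = "sentence-transformers/paraphrase-mpnet-base-v2"  # placeholder model id

if parse(st_version) >= Version("2.3.0"):
    # Only enable trust_remote_code for repositories whose code you have read and trust.
    model = SetFitModel.from_pretrained(model_id, trust_remote_code=True)
else:
    # On older sentence-transformers, passing trust_remote_code raises a ValueError.
    model = SetFitModel.from_pretrained(model_id)
```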
97 changes: 96 additions & 1 deletion src/setfit/span/modeling.py
@@ -1,12 +1,15 @@
import copy
import os
import re
import tempfile
import types
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import torch
from datasets import Dataset
from huggingface_hub.utils import SoftTemporaryDirectory

from setfit.utils import set_docstring
@@ -148,7 +151,99 @@ class AbsaModel:
aspect_model: AspectModel
polarity_model: PolarityModel

def predict(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
def gold_aspect_spans_to_aspects_list(self, inputs: Dataset) -> Tuple[List, List[List[slice]], List[int]]:
# First group inputs by text
grouped_data = defaultdict(list)
for sample in inputs:
text = sample.pop("text")
grouped_data[text].append(sample)

# Get the spaCy docs
docs, _ = self.aspect_extractor(grouped_data.keys())

# Get the aspect spans for each doc by matching gold spans to the spaCy tokens
aspects_list = []
index = -1
skipped_indices = []
for doc, samples in zip(docs, grouped_data.values()):
aspects_list.append([])
for sample in samples:
index += 1
match_objects = re.finditer(re.escape(sample["span"]), doc.text)
for i, match in enumerate(match_objects):
if i == sample["ordinal"]:
char_idx_start = match.start()
char_idx_end = match.end()
span = doc.char_span(char_idx_start, char_idx_end)
if span is None:
logger.warning(
f"Aspect term {sample['span']!r} with ordinal {sample['ordinal']}, isn't a token in {doc.text!r} according to spaCy. "
"Skipping this sample."
)
skipped_indices.append(index)
continue
aspects_list[-1].append(slice(span.start, span.end))
return docs, aspects_list, skipped_indices

def predict_dataset(self, inputs: Dataset) -> Dataset:
if set(inputs.column_names) >= {"text", "span", "ordinal"}:
pass
elif set(inputs.column_names) >= {"text", "span"}:
inputs = inputs.add_column("ordinal", [0] * len(inputs))
else:
raise ValueError(
"`inputs` must be either a `str`, a `List[str]`, or a `datasets.Dataset` with columns `text` and `span` and optionally `ordinal`. "
f"Found a dataset with these columns: {inputs.column_names}."
)
if "pred_polarity" in inputs.column_names:
raise ValueError(
"`predict_dataset` wants to add a `pred_polarity` column, but the input dataset already contains that column."
)
docs, aspects_list, skipped_indices = self.gold_aspect_spans_to_aspects_list(inputs)
polarity_list = sum(self.polarity_model(docs, aspects_list), [])
for index in skipped_indices:
polarity_list.insert(index, None)
return inputs.add_column("pred_polarity", polarity_list)

def predict(self, inputs: Union[str, List[str], Dataset]) -> Union[List[Dict[str, Any]], Dataset]:
"""Predicts aspects & their polarities of the given inputs.
Example::
>>> from setfit import AbsaModel
>>> model = AbsaModel.from_pretrained(
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
... )
>>> model.predict("The food and wine are just exquisite.")
[{'span': 'food', 'polarity': 'positive'}, {'span': 'wine', 'polarity': 'positive'}]
>>> from setfit import AbsaModel
>>> from datasets import load_dataset
>>> model = AbsaModel.from_pretrained(
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
... )
>>> dataset = load_dataset("tomaarsen/setfit-absa-semeval-restaurants", split="train")
>>> model.predict(dataset)
Dataset({
features: ['text', 'span', 'label', 'ordinal', 'pred_polarity'],
num_rows: 3693
})
Args:
inputs (Union[str, List[str], Dataset]): Either a sentence, a list of sentences,
or a dataset with columns `text` and `span` and optionally `ordinal`. This dataset
contains gold aspects, and we only predict the polarities for them.
Returns:
Union[List[Dict[str, Any]], Dataset]: Either a list of dictionaries with keys `span`
and `polarity` if the input was a sentence or a list of sentences, or a dataset with
columns `text`, `span`, `ordinal`, and `pred_polarity`.
"""
if isinstance(inputs, Dataset):
return self.predict_dataset(inputs)

is_str = isinstance(inputs, str)
inputs_list = [inputs] if is_str else inputs
docs, aspects_list = self.aspect_extractor(inputs_list)
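As a side note on `gold_aspect_spans_to_aspects_list` above: the `ordinal` column selects the n-th occurrence of an aspect term within its sentence, matched via `re.finditer`. A tiny standalone sketch of that idea, with an invented sentence:

```python
import re

text = "The pizza was great, but the second pizza arrived cold."
span, ordinal = "pizza", 1  # ordinal 1 means the second occurrence

for i, match in enumerate(re.finditer(re.escape(span), text)):
    if i == ordinal:
        print(match.start(), match.end())  # character offsets of the second "pizza"
        break
```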
22 changes: 11 additions & 11 deletions src/setfit/trainer.py
Expand Up @@ -507,10 +507,8 @@ def get_dataloader(
args.sampling_strategy,
max_pairs=max_pairs,
)
# shuffle_sampler = True can be dropped in for further 'randomising'
shuffle_sampler = True if args.sampling_strategy == "unique" else False
batch_size = min(args.embedding_batch_size, len(data_sampler))
dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=False)
loss = args.loss(self.model.model_body)

return dataloader, loss, batch_size
@@ -576,8 +574,8 @@ def _train_sentence_transformer(
if args.use_amp:
scaler = torch.cuda.amp.GradScaler()

model_body.to(model_body._target_device)
loss_func.to(model_body._target_device)
model_body.to(self.model.device)
loss_func.to(self.model.device)

# Use smart batching
train_dataloader.collate_fn = model_body.smart_batching_collate
@@ -625,8 +623,8 @@ def _train_sentence_transformer(
data = next(data_iterator)

features, labels = data
labels = labels.to(model_body._target_device)
features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features))
labels = labels.to(self.model.device)
features = list(map(lambda batch: batch_to_device(batch, self.model.device), features))

if args.use_amp:
with autocast():
@@ -673,10 +671,12 @@ def _train_sentence_transformer(
step_to_load = dir_name[5:]
logger.info(f"Loading best SentenceTransformer model from step {step_to_load}.")
self.model.model_card_data.set_best_model_step(int(step_to_load))
sentence_transformer_kwargs = self.model.sentence_transformers_kwargs
sentence_transformer_kwargs["device"] = self.model.device
self.model.model_body = SentenceTransformer(
self.state.best_model_checkpoint, device=model_body._target_device
self.state.best_model_checkpoint, **sentence_transformer_kwargs
)
self.model.model_body.to(model_body._target_device)
self.model.model_body.to(self.model.device)

# Ensure logging the speed metrics
num_train_samples = self.state.max_steps * args.embedding_batch_size # * args.gradient_accumulation_steps
@@ -736,8 +736,8 @@ def _evaluate_with_loss(
tqdm(iter(eval_dataloader), total=eval_steps, leave=False, disable=not args.show_progress_bar), start=1
):
features, labels = data
labels = labels.to(model_body._target_device)
features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features))
labels = labels.to(self.model.device)
features = list(map(lambda batch: batch_to_device(batch, self.model.device), features))

if args.use_amp:
with autocast():
4 changes: 1 addition & 3 deletions src/setfit/trainer_distillation.py
@@ -93,10 +93,8 @@ def get_dataloader(
data_sampler = ContrastiveDistillationDataset(
input_data, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs
)
# shuffle_sampler = True can be dropped in for further 'randomising'
shuffle_sampler = True if args.sampling_strategy == "unique" else False
batch_size = min(args.embedding_batch_size, len(data_sampler))
dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=False)
loss = args.loss(self.model.model_body)
return dataloader, loss, batch_size

8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -14,6 +14,14 @@ def absa_model() -> AbsaModel:
return AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm")


@pytest.fixture()
def trained_absa_model() -> AbsaModel:
return AbsaModel.from_pretrained(
"tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
"tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
)


@pytest.fixture()
def absa_dataset() -> Dataset:
texts = [
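For illustration only, a test built on the new `trained_absa_model` fixture might look like the following; the test name and assertions are hypothetical, not part of this commit:

```python
# Hypothetical pytest test using the trained_absa_model fixture from conftest.py
def test_trained_absa_model_predict(trained_absa_model):
    preds = trained_absa_model.predict("The food and wine are just exquisite.")
    assert isinstance(preds, list)
    assert all({"span", "polarity"} <= set(pred) for pred in preds)
```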