Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add test to compare with late chunking in the API #8

Merged
merged 9 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
name: Run Tests

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
types: [opened, synchronize, reopened]
push:
branches:
- main

env:
JINA_API_TOKEN: ${{ secrets.JINA_API_TOKEN }}

jobs:
test:
Expand Down
100 changes: 100 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import numpy as np
from transformers import AutoModel, AutoTokenizer

from chunked_pooling import chunked_pooling
from chunked_pooling.wrappers import load_model
from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval

MODEL_NAME = 'jinaai/jina-embeddings-v3'

# Define Text and Chunk
CHUNKS = ["Organic skincare", "for sensitive skin", "with aloe vera and chamomile"]
FULL_TEXT = ' '.join(CHUNKS)


def load_api_results():
    """Fetch late-chunking embeddings for CHUNKS from the Jina embeddings API.

    Requires the ``JINA_API_TOKEN`` environment variable (raises ``KeyError``
    otherwise). Returns a list of 1-D numpy arrays, one embedding per entry
    in ``CHUNKS``, in the order the API returns them.

    Raises:
        requests.HTTPError: if the API responds with a non-2xx status.
        requests.Timeout: if the request exceeds the 30s timeout.
    """
    import requests

    url = 'https://api.jina.ai/v1/embeddings'
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {os.environ["JINA_API_TOKEN"]}',
    }
    payload = {
        "model": "jina-embeddings-v3",
        "task": "retrieval.passage",
        "dimensions": 1024,
        "late_chunking": True,
        "embedding_type": "float",
        "input": CHUNKS,
    }
    # Fail fast on network hangs (no default timeout in requests) and surface
    # HTTP errors explicitly instead of a confusing KeyError on 'data' below.
    response = requests.post(url, headers=headers, json=payload, timeout=30)
    response.raise_for_status()
    data = response.json()
    return [np.array(x['embedding']) for x in data['data']]


def calculate_annotations(model, boundary_cues, model_has_instructions, tokenizer):
    """Extend raw chunk boundary cues to account for special tokens.

    When the model prepends an instruction, its token count is measured with
    the tokenizer and every span is shifted by that amount; prefix and
    separator tokens are always included in the extension.

    Returns a list with one extended annotation list per entry in
    ``boundary_cues``.
    """
    n_instruction_tokens = 0
    if model_has_instructions:
        # Second instruction is the passage/retrieval one (see caller).
        instruction = model.get_instructions()[1]
        tokenized = tokenizer(instruction, add_special_tokens=False)
        n_instruction_tokens = len(tokenized[0])

    extended_annotations = []
    for cues in boundary_cues:
        extended_annotations.append(
            AbsTaskChunkedRetrieval._extend_special_tokens(
                cues,
                n_instruction_tokens=n_instruction_tokens,
                include_prefix=True,
                include_sep=True,
            )
        )
    return extended_annotations


def test_compare_v3_api_embeddings():
    """Compare locally computed late-chunking embeddings with the Jina API.

    Runs the model locally with chunked pooling over FULL_TEXT and checks
    that each chunk embedding matches the API result with late_chunking=True.
    Requires network access, a JINA_API_TOKEN, and a model download.
    """
    # Load Model
    model, has_instr = load_model(MODEL_NAME, use_flash_attn=False)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Determine Boundary Cues
    # Walk the character offset mapping of the tokenized FULL_TEXT to find,
    # for each entry in CHUNKS, its (start_token, end_token) span.
    tokenization = tokenizer(
        FULL_TEXT, return_offsets_mapping=True, add_special_tokens=False
    )
    boundary_cues = []
    chunk_i = 0
    last_cue = 0  # token index where the current chunk's span starts
    last_end = 0  # character offset where the previous chunk's last token ended
    for i, (start, end) in enumerate(tokenization.offset_mapping):
        # The current chunk is complete once this token's end offset reaches
        # the character position where the chunk's text ends. NOTE(review):
        # this relies on CHUNKS being joined with single spaces in FULL_TEXT
        # so that offsets line up — revisit if CHUNKS/FULL_TEXT changes.
        if end >= (last_end + len(CHUNKS[chunk_i])):
            boundary_cues.append((last_cue, i + 1))
            chunk_i += 1
            last_cue = i + 1
            last_end = end
    # Shift the spans to account for instruction and special tokens that are
    # added to the model input below.
    extended_boundary_cues = calculate_annotations(
        model, [boundary_cues], has_instr, tokenizer
    )

    # Append Instruction for Retrieval Task
    # Index 1 presumably selects the passage-side instruction — matches the
    # 'retrieval.passage' task sent to the API in load_api_results.
    instr = model.get_instructions()[1]
    text_inputs = [instr + FULL_TEXT]
    model_inputs = tokenizer(
        text_inputs,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=8192,
    )
    model_outputs = model(**model_inputs)

    # Apply Late Chunking
    # chunked_pooling returns one list of chunk embeddings per input text;
    # [0] selects the embeddings for our single input.
    output_embs = chunked_pooling(
        model_outputs, extended_boundary_cues, max_length=8192
    )[0]
    api_embs = load_api_results()
    # Compare normalized embeddings pairwise: elementwise closeness with a
    # loose tolerance, plus cosine similarity within 1e-3 of 1.0.
    for local_emb, api_emb in zip(output_embs, api_embs):
        local_emb_norm = local_emb / np.linalg.norm(local_emb)
        api_emb_norm = api_emb / np.linalg.norm(api_emb)
        assert np.allclose(local_emb_norm, api_emb_norm, rtol=1e-02, atol=1e-02)
        assert 1.0 - np.dot(local_emb_norm, api_emb_norm) < 1e-3
Loading