
onnx_implementation added #43

Open · wants to merge 43 commits into main

Conversation

@polodealvarado (Author) commented Nov 11, 2023

Hello Tom!

Here is the first draft of the ONNX implementation.

I propose the following to-do list to get SpanMarkerModel into ONNX format, based mainly on Hugging Face's Optimum documentation (a rough sketch of the first two items follows the list):

  • Create the SpanMarkerOnnxConfig class
  • Create the SpanMarkerDummyTextInputGenerator class
  • Compare results with the original model.
  • Register the ONNX configuration in the TasksManager.
  • Create the SpanMarkerONNXPipeline class
  • Compare execution times for different batch sizes.
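
For reference, here is a rough sketch of what the first two items could look like, following Optimum's custom-export pattern. The class names come from the list above; the opset, normalized config, and input set are assumptions that would need adjusting to SpanMarker's real inputs (e.g. a custom generate() for position_ids):

from typing import Dict

from optimum.exporters.onnx.base import OnnxConfig
from optimum.utils import DummyTextInputGenerator, NormalizedTextConfig


class SpanMarkerDummyTextInputGenerator(DummyTextInputGenerator):
    # Extend the default dummy text inputs with the position IDs that SpanMarker expects.
    SUPPORTED_INPUT_NAMES = ("input_ids", "attention_mask", "position_ids")


class SpanMarkerOnnxConfig(OnnxConfig):
    DEFAULT_ONNX_OPSET = 14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (SpanMarkerDummyTextInputGenerator,)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Mark batch size and sequence length as dynamic axes.
        dynamic_axis = {0: "batch_size", 1: "sequence_length"}
        return {
            "input_ids": dynamic_axis,
            "attention_mask": dynamic_axis,
            "position_ids": dynamic_axis,
        }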

Requirements

  • optimum==1.13.2

How to get the ONNX model

Running the script onnx_implementation.py exports the model to ONNX format.
The model is generated correctly, but we still need to fix a problem with the input/output variable names.

Error response

There is a problem with the names of the generated input/output nodes because ONNX requires each node name to be unique. In the forward function, the following input names reappear in the outputs when the SpanMarkerOutput object is created: "start_marker_indices", "num_marker_pairs", "num_words", "document_ids", "sentence_ids". By default, ONNX disambiguates the corresponding input names by appending '.1' (you can check the error response at the bottom of the script).

Reviewing the code, I noticed that the 'forward' function includes post-processing stages after the encoder output is obtained. One option is to split this code: keep 'forward' solely for the execution of the encoder, and apply the post-processing steps afterwards so the input/output names are not affected. This way the input/output names stay unique.

def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Forward call of the SpanMarkerModel.
        Args:
            input_ids (~torch.Tensor): Input IDs including start/end markers.
            attention_mask (~torch.Tensor): Attention mask matrix including one-directional attention for markers.
            position_ids (~torch.Tensor): Position IDs including start/end markers.

        Returns:
            outputs: Encoder outputs
        """
        token_type_ids = torch.zeros_like(input_ids)
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        return outputs

def postprocessing_encoder(
        self,
        encoder_outputs: Dict[str, torch.Tensor],
        start_marker_indices: torch.Tensor,
        num_marker_pairs: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        num_words: Optional[torch.Tensor] = None,
        document_ids: Optional[torch.Tensor] = None,
        sentence_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SpanMarkerOutput:
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.dropout(last_hidden_state)

        batch_size = last_hidden_state.size(0)
        sequence_length = last_hidden_state.size(1)

        # Get the indices where the end markers start
        end_marker_indices = start_marker_indices + num_marker_pairs

        # The start marker embeddings concatenated with the end marker embeddings.
        # This is kind of breaking the cardinal rule of GPU-based ML, as this is processing
        # the batch iteratively per sample, but every sample produces a different shape matrix
        # and this is the most convenient way to recombine them into a matrix.
        embeddings = []
        for i in range(batch_size):
            embeddings.append(
                torch.cat(
                    (
                        last_hidden_state[i, start_marker_indices[i] : end_marker_indices[i]],
                        last_hidden_state[i, end_marker_indices[i] : end_marker_indices[i] + num_marker_pairs[i]],
                    ),
                    dim=-1,
                )
            )
        padded_embeddings = [
            F.pad(embedding, (0, 0, 0, sequence_length // 2 - embedding.shape[0])) for embedding in embeddings
        ]
        feature_vector = torch.stack(padded_embeddings)

        # NOTE: This was wrong in the older tests
        feature_vector = self.dropout(feature_vector)
        logits = self.classifier(feature_vector)

        if labels is not None:
            loss = self.loss_func(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SpanMarkerOutput(
            loss=loss if labels is not None else None,
            logits=logits,
            *encoder_outputs[2:],
            num_marker_pairs=num_marker_pairs,
            num_words=num_words,
            document_ids=document_ids,
            sentence_ids=sentence_ids,
        )

@polodealvarado (Author) commented Nov 12, 2023

Running onnx_implementation.py now works.
The ONNX validation passes with an absolute tolerance below 0.0001.

================ Diagnostic Run torch.onnx.export version 2.0.1 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Optimum version: 1.13.2
Validating ONNX model spanmarker_model.onnx...
        -[✓] ONNX model output names match reference model (last_hidden_state, pooler_output)
        - Validating ONNX Model output "last_hidden_state":
                -[✓] (2, 512, 768) matches (2, 512, 768)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "pooler_output":
                -[✓] (2, 768) matches (2, 768)
                -[✓] all values close (atol: 0.0001)

I have split the forward function into two parts: forward and postprocessing_encoder_output:

   def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Forward call of the SpanMarkerModel.
        Args:
            input_ids (~torch.Tensor): Input IDs including start/end markers.
            attention_mask (~torch.Tensor): Attention mask matrix including one-directional attention for markers.
            position_ids (~torch.Tensor): Position IDs including start/end markers.

        Returns:
            outputs: Encoder outputs
        """
        token_type_ids = torch.zeros_like(input_ids)
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        return outputs

    def postprocessing_encoder_output(
        self,
        encoder_outputs: Dict[str, torch.Tensor],
        start_marker_indices: torch.Tensor,
        num_marker_pairs: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        num_words: Optional[torch.Tensor] = None,
        document_ids: Optional[torch.Tensor] = None,
        sentence_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SpanMarkerOutput:
        """
        Post-processes the output from the encoder to produce SpanMarkerOutput.
        Args:
            encoder_outputs (Dict[str, torch.Tensor]): Outputs from the encoder, typically including the last hidden state.
            start_marker_indices (torch.Tensor): Indices in the input sequence where start markers are located.
            num_marker_pairs (torch.Tensor): Number of start-end marker pairs in each example in the batch.
            labels (Optional[torch.Tensor]): Labels for the input data, used for supervised learning. Default is None.
            num_words (Optional[torch.Tensor]): Number of words in each input sequence. Default is None.
            document_ids (Optional[torch.Tensor]): Identifiers for the documents in the batch. Default is None.
            sentence_ids (Optional[torch.Tensor]): Identifiers for the sentences in the batch. Default is None.

        Returns:
            SpanMarkerOutput: Processed features and optionally the loss, if labels are provided.
        """
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.dropout(last_hidden_state)

        batch_size = last_hidden_state.size(0)
        sequence_length = last_hidden_state.size(1)

        # Get the indices where the end markers start
        end_marker_indices = start_marker_indices + num_marker_pairs

        # The start marker embeddings concatenated with the end marker embeddings.
        # This is kind of breaking the cardinal rule of GPU-based ML, as this is processing
        # the batch iteratively per sample, but every sample produces a different shape matrix
        # and this is the most convenient way to recombine them into a matrix.
        embeddings = []
        for i in range(batch_size):
            embeddings.append(
                torch.cat(
                    (
                        last_hidden_state[i, start_marker_indices[i] : end_marker_indices[i]],
                        last_hidden_state[i, end_marker_indices[i] : end_marker_indices[i] + num_marker_pairs[i]],
                    ),
                    dim=-1,
                )
            )
        padded_embeddings = [
            F.pad(embedding, (0, 0, 0, sequence_length // 2 - embedding.shape[0])) for embedding in embeddings
        ]
        feature_vector = torch.stack(padded_embeddings)

        # NOTE: This was wrong in the older tests
        feature_vector = self.dropout(feature_vector)
        logits = self.classifier(feature_vector)

        if labels is not None:
            loss = self.loss_func(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SpanMarkerOutput(
            loss=loss if labels is not None else None,
            logits=logits,
            *encoder_outputs[2:],
            num_marker_pairs=num_marker_pairs,
            num_words=num_words,
            document_ids=document_ids,
            sentence_ids=sentence_ids,
        )

@tomaarsen (Owner) commented Nov 12, 2023

Heya @polodealvarado!

This is looking awesome already! I did some experiments of my own with your code, and I can reproduce your good validation results. Regarding the naming issue - I am also open to e.g. renaming the inputs/outputs, for example prepending them with out_.... What do you think of this? For context, these output variables are not widely used: I only have them for evaluation.py, where e.g. the number of words in the sample is necessary for evaluations. If it means that more of the code is inside of the faster ONNX environment, then I am in favor of that.

I've tested it out briefly, and I get:

Optimum version: 1.13.2
Validating ONNX model spanmarker_model.onnx...
        -[✓] ONNX model output names match reference model (out_num_words, out_num_marker_pairs, logits, out_document_ids, out_sentence_ids)
        - Validating ONNX Model output "logits":
                -[✓] (2, 256, 16) matches (2, 256, 16)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "out_num_marker_pairs":
                -[✓] (2,) matches (2,)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "out_num_words":
                -[✓] (2,) matches (2,)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "out_document_ids":
                -[✓] (2,) matches (2,)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "out_sentence_ids":
                -[✓] (2,) matches (2,)
                -[✓] all values close (atol: 0.0001)
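
For illustration only, in a SpanMarkerOnnxConfig the renaming could be expressed in the outputs property roughly like this (output names taken from the validation log above; the dynamic axis names are assumptions):

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        # The "out_" prefix keeps these pass-through tensors from clashing with the input names.
        return {
            "logits": {0: "batch_size", 1: "num_spans"},
            "out_num_marker_pairs": {0: "batch_size"},
            "out_num_words": {0: "batch_size"},
            "out_document_ids": {0: "batch_size"},
            "out_sentence_ids": {0: "batch_size"},
        }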

Also, I must admit that I haven't yet figured out how to actually run the produced spanmarker_model.onnx. I'm very curious to see the performance impact!

Thanks for working on this, I appreciate it! This is definitely heading in the right direction.

  • Tom Aarsen

@polodealvarado (Author)

Given that "num_words", "num_marker_pairs", ... etc are not computed by the onnx graph I think it does not affect to the perfomance, but I will run some tests just in case.

Great! I like your approach. Let's use "out_" for these variables.
I am going to get the SpanMarkerOnnxPipeline ready to have all the spanmarker process.

@polodealvarado (Author)

The first draft of the SpanMarkerOnnxPipeline is ready.
I am facing issues with the input batch size for the ONNX model.

For some reason the ONNX model only processes one batch when the input batch size is different from the dummy input batch size used during the ONNX conversion.

# Load ONNX Pipeline
onnx_path = Path("spanmarker_model.onnx")
repo_id = "lxyuan/span-marker-bert-base-multilingual-uncased-multinerd"
onnx_pipe = SpanMarkerOnnxPipeline(onnx_path=onnx_path, repo_id=repo_id)

sample = ["Pedro is working in Alicante"]  # It works
sample = [
    "Pedro is working in Alicante. Pedro is working in Alicante. Pedro is working in Alicante.Pedro is working in Alicante. Pedro is working in Alicante. Pedro is working in Alicante.Pedro is working in Alicante. Pedro is working in Alicante. Pedro is working in Alicante",
]  # It doesn't work
start_time = time.time()
onnx_pipe(sample)
end_time = time.time()
print(f"Execution time: {end_time-start_time}")

@polodealvarado (Author) commented Nov 13, 2023

Hmm, I have found a similar issue here:
https://discuss.pytorch.org/t/dynamic-axes-doesnt-work-for-torch-onnx-export-when-torch-cat-is-present/149501

Indeed, we have torch.cat inside the forward function.
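
For context, here is a sketch of how dynamic batch/sequence axes are declared with torch.onnx.export; the wrapper, input names, and dummy shapes are assumptions for illustration rather than the exact export call from the script. Per the linked thread, torch.cat inside the exported forward can still end up pinning these axes to the dummy shapes:

import torch
from span_marker import SpanMarkerModel


class EncoderOnly(torch.nn.Module):
    # Wraps the encoder-only forward from earlier, so the export avoids the
    # data-dependent post-processing.
    def __init__(self, span_marker_model):
        super().__init__()
        self.encoder = span_marker_model.encoder

    def forward(self, input_ids, attention_mask, position_ids):
        token_type_ids = torch.zeros_like(input_ids)
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        return outputs.last_hidden_state, outputs.pooler_output


model = SpanMarkerModel.from_pretrained("lxyuan/span-marker-bert-base-multilingual-uncased-multinerd")
wrapper = EncoderOnly(model)

# Dummy inputs; shapes are illustrative (the attention mask is a per-sample matrix in SpanMarker).
dummy_input_ids = torch.zeros(2, 512, dtype=torch.long)
dummy_attention_mask = torch.ones(2, 512, 512, dtype=torch.long)
dummy_position_ids = torch.arange(512).repeat(2, 1)

torch.onnx.export(
    wrapper,
    (dummy_input_ids, dummy_attention_mask, dummy_position_ids),
    "spanmarker_model.onnx",
    input_names=["input_ids", "attention_mask", "position_ids"],
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length", 2: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
        "pooler_output": {0: "batch_size"},
    },
    opset_version=14,
)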

@tomaarsen (Owner)

Hmm, that is odd. I can also reproduce your findings, i.e. that it works for the short sample, but not for the longer one.

@polodealvarado (Author)

Hello @tomaarsen.

I have figured out how to make it work. Tomorrow I will share with you how to generate the ONNX models. Now I am facing the challenge of how to improve the inference time with ONNX because it is a bit slower than Torch (at least on my M1 Pro).

@tomaarsen (Owner)

Awesome! I also noticed when running the still-flawed ONNX model that it was roughly as fast on GPU (on my RTX 3090) and roughly 2x as fast on CPU. I'd be very curious to see your findings, and to observe the ONNX graph. Perhaps there are some model inefficiencies, for example the ugly for-loop that I used in post-processing. I tried to refactor that a few times, but all my torch-only solutions were slightly slower.

Thank you for all of your work on this!

  • Tom Aarsen

@polodealvarado (Author) commented Nov 16, 2023

Here is a first solution; however, it doesn't improve the inference time compared with the original torch model. As we suspected regarding that "ugly" for-loop, we were right: this kind of loop hurts ONNX performance.

To avoid running the for-loop inside the ONNX model, I propose splitting the whole "forward" process into two ONNX models: one just for the encoder and another for the classifier. The post-processing code between self.encoder() and self.classifier(), where the torch.cat call lives, remains unchanged (see the sketch below).

However, the ONNX encoder model still runs slower than torch. I am still researching possible optimizations.
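
As a rough sketch of what I mean by the split (file names, node names, and the build_feature_vector helper are placeholders, not the actual implementation):

import onnxruntime as ort
import torch

encoder_session = ort.InferenceSession("spanmarker_encoder.onnx", providers=["CPUExecutionProvider"])
classifier_session = ort.InferenceSession("spanmarker_classifier.onnx", providers=["CPUExecutionProvider"])


def onnx_forward(input_ids, attention_mask, position_ids, start_marker_indices, num_marker_pairs):
    # 1) The encoder runs in ONNX Runtime.
    (last_hidden_state,) = encoder_session.run(
        ["last_hidden_state"],
        {
            "input_ids": input_ids.numpy(),
            "attention_mask": attention_mask.numpy(),
            "position_ids": position_ids.numpy(),
        },
    )
    last_hidden_state = torch.from_numpy(last_hidden_state)

    # 2) The data-dependent concatenation/padding (the for-loop with torch.cat) stays in plain torch.
    feature_vector = build_feature_vector(last_hidden_state, start_marker_indices, num_marker_pairs)

    # 3) The classifier runs in ONNX Runtime.
    (logits,) = classifier_session.run(["logits"], {"feature_vector": feature_vector.numpy()})
    return torch.from_numpy(logits)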

@tomaarsen
Would you mind reproducing the results of the script 'onnx_implement' and sharing them?

Thank you !

Results

Batch size: 30
Torch time: 12.3975088596344
ONNX time: 14.977856636047363
Results are the same: True

System Info
Platform: MacOS Sonoma 14.0, M1 Pro
Python 3.11.5

@polodealvarado (Author)

Tomorrow I will share the other solution I have developed. This one tries to improve the forward code.

@tomaarsen (Owner)

I'll run the script in about 10 hours tomorrow morning and report back my results! I'll run it both on CPU and GPU if I can.

@tomaarsen (Owner)

These are my results. I reran everything twice to make sure that there wasn't any random fluke:

CPU:
1st time:

Time results:
Batch size: 30
Torch time: 17.610121965408325
ONNX time: 10.279107809066772
Results are the same: True

2nd time:

Time results:
Batch size: 30
Torch time: 17.337737321853638
ONNX time: 10.79595136642456
Results are the same: True

CUDA:

Time results:
Batch size: 30
Torch time: 1.2919535636901855
ONNX time: 0.9842772483825684
Results are the same: False
Time results:
Batch size: 30
Torch time: 1.0053291320800781
ONNX time: 0.7737681865692139
Results are the same: False

Not the same results :(

Okay, I've narrowed it down. The difference was a slight change in the "score", and nothing more. I've made the following function:

def strip_score_from_results(results):
    return [[{key: value for key, value in ent.items() if key != "score"} for ent in ents] for ents in results]

and then used:

print(f"Results are the same: {strip_score_from_results(torch_result)==strip_score_from_results(onnx_result)}")

And now it says:

Time results:
Batch size: 30
Torch time: 0.9915874004364014
ONNX time: 0.77559494972229
Results are the same: True

Another thing to consider is that the normal model.predict has a batch_size parameter which defaults to 4. Setting it higher (e.g. to 30) will speed up processing. Then for CUDA we get e.g.:

Time results:
Batch size: 30
Torch time: 0.8699021339416504
ONNX time: 0.6862277984619141
Results are the same: True

and for CPU we get:

Time results:
Batch size: 30
Torch time: 17.311928272247314
ONNX time: 9.293296575546265
Results are the same: True

Another interesting quirk: my ONNX model is sometimes a bit slower if I set two providers, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], than if I just set ['CUDAExecutionProvider']. Maybe you also get better performance by only setting the primary provider that you want to use?
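
For example, only requesting the CUDA provider (a sketch, assuming the session is created directly with onnxruntime):

import onnxruntime as ort

# Requesting only the provider you actually want avoids any silent CPU fallback overhead.
session = ort.InferenceSession("spanmarker_model.onnx", providers=["CUDAExecutionProvider"])
print(session.get_providers())  # confirm which providers are actually active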

  • Tom Aarsen

@polodealvarado (Author)

Interesting. I have been reading some issues in the onnxruntime repo, and I found several reports of problems running the onnxruntime engine on Apple Silicon.

I will run the same test on my Linux GPU server.

On the other hand, based on your results we got an improvement, at least for this first test, of:
~72% for CPU
~28% for CUDA

@polodealvarado (Author) commented Nov 17, 2023

@tomaarsen
On my Linux machine, I have gotten results closer to yours.
I have also discovered that document-level context affects the performance:

Setting batch_size=30

Results

Without document-level context

batch = [
        "Pedro is working in Alicante. Pedro is working in Alicante. Pedro is working in Alicante.Pedro is working in Alicante. Pedro is working in Alicante. Pedro is working in Alicante.Pedro is working in Alicante. Pedro is working in Alicante. Pedro is working in Alicante",
    ] * 30

CPU
Time results:
Batch size: 30
Torch time: 24.0738263130188
ONNX time: 12.592668056488037
Results are the same: True

CUDA
Time results:
Batch size: 30
Torch time: 0.91231443252
ONNX time: 0.59696873378
Results are the same: True

With document-level context

batch = [[
        "Pedro is working in Alicante. Pedro is working in Alicante. Pedro is working in Alicante.Pedro is working in Alicante. Pedro is working in Alicante. Pedro is working in Alicante.Pedro is working in Alicante. Pedro is working in Alicante. Pedro is working in Alicante",
    ]]* 30

CPU
Time results:
Batch size: 30
Torch time: 13.35103726387024
ONNX time: 6.461100339889526
Results are the same: True

CUDA
Time results:
Batch size: 30
Torch time: 0.8129876
ONNX time: 0.5063415
Results are the same: True

System Info

Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 12
On-line CPU(s) list: 0-11
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) CPU @ 2.00GHz
Stepping: 3
CPU MHz: 2000.164
BogoMIPS: 4000.32
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 192 KiB
L1i cache: 192 KiB
L2 cache: 6 MiB
L3 cache: 38.5 MiB
NUMA node0 CPU(s): 0-11

Python libraries

Package Version


aiohttp 3.8.6
aiosignal 1.3.1
async-timeout 4.0.3
attrs 23.1.0
certifi 2023.7.22
charset-normalizer 3.3.2
coloredlogs 15.0.1
datasets 2.15.0
dill 0.3.7
evaluate 0.4.1
filelock 3.13.1
flatbuffers 23.5.26
frozenlist 1.4.0
fsspec 2023.10.0
huggingface-hub 0.19.4
humanfriendly 10.0
idna 3.4
Jinja2 3.1.2
joblib 1.3.2
MarkupSafe 2.1.3
mpmath 1.3.0
multidict 6.0.4
multiprocess 0.70.15
networkx 3.2.1
numpy 1.26.2
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu12 12.1.105
onnx 1.15.0
onnxruntime-gpu 1.16.2
optimum 1.14.1
packaging 23.2
pandas 2.1.3
pip 23.3
protobuf 4.25.1
pyarrow 14.0.1
pyarrow-hotfix 0.5
python-dateutil 2.8.2
pytz 2023.3.post1
PyYAML 6.0.1
regex 2023.10.3
requests 2.31.0
responses 0.18.0
safetensors 0.4.0
scikit-learn 1.3.2
scipy 1.11.3
sentencepiece 0.1.99
setuptools 68.0.0
six 1.16.0
sympy 1.12
threadpoolctl 3.2.0
tokenizers 0.15.0
torch 2.1.1
tqdm 4.66.1
transformers 4.35.2
triton 2.1.0
typing_extensions 4.8.0
tzdata 2023.3
urllib3 2.1.0
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.2

@tomaarsen (Owner)

Oh that is very interesting! I would not have expected a change in performance between those two cases.

@tomaarsen (Owner)

Hello!

No worries! No major updates from my side, though I did inspect the third script that you created. I see that it doesn't use optimum, for example. I also tried to run it, though I get an error: 'Optional[Tensor]' object has no attribute or method 'size'. Perhaps you also encounter this.

  • Tom

@polodealvarado (Author)

Hello Tom. I am back.

Tomorrow I will tell you about the latest progress I made before I was away these past two weeks.

On the other hand, could you run the scripts "onnx_implementation_with_torch.py" and "onnx_implementation_with_optimum.py" and share your results?

  • Antonio Polo

@tomaarsen (Owner) commented Dec 4, 2023

Hello!

Here are my outputs when running your scripts, and also when modifying the scripts to use CUDA (My RTX 3090) instead:

  • onnx_implementation_with_torch.py:
Time results:
Batch size: 30
Torch times: [17.401074647903442, 17.729020357131958, 16.87925696372986, 17.491692066192627, 16.47899341583252]
ONNX CPU times: [9.490059852600098, 9.385780572891235, 9.485393285751343, 9.465813636779785, 9.746570348739624]
Avg Torch time: 17.19600749015808
Avg ONNX CPU time: 9.514723539352417
Results are the same: True
Time results:
Batch size: 30
Torch times: [0.8015215396881104, 0.6366519927978516, 0.6443564891815186, 0.6469173431396484, 0.6466293334960938]
ONNX CUDA times: [0.6336193084716797, 0.5893862247467041, 0.6037154197692871, 0.583064079284668, 0.5859053134918213]
Avg Torch time: 0.6752153396606445
Avg ONNX CUDA time: 0.5991380691528321
Results are the same: True
  • onnx_implementation_with_optimum.py:
Time results:
Batch size: 30
Torch times: [16.555065631866455, 16.842355966567993, 16.788350820541382, 16.53908395767212, 16.712769269943237]
ONNX CPU times: [9.623130559921265, 9.403372764587402, 9.554663181304932, 9.52034330368042, 9.548219680786133]
Avg Torch time: 16.68752512931824
Avg ONNX CPU time: 9.529945898056031
Results are the same: True
Time results:
Batch size: 30
Torch times: [0.8601093292236328, 0.638655424118042, 0.6409933567047119, 0.6481106281280518, 0.6381371021270752]
ONNX CUDA times: [0.7780861854553223, 0.5775661468505859, 0.5794253349304199, 0.5771944522857666, 0.5816450119018555]
Avg Torch time: 0.6852011680603027
Avg ONNX CUDA time: 0.61878342628479
Results are the same: True

One notable consideration is that with CUDA, there seems to be some "warmup" necessary: the first inference is always slower than the others. In practice, that means just running inference on some dummy text before doing the measurements.
So, if I ignore the first time, then the averages become:
Torch: 0.6425
ONNX with Torch: 0.59051
ONNX with Optimum: 0.5789

So that's about an 11% speedup at this batch size, and an 80% speedup for CPU at these settings!
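
The "warmup" I mean is just something like this (a sketch, reusing the onnx_pipe and sample objects from the earlier snippet):

import time

# Throwaway predictions so CUDA kernels / ONNX Runtime sessions are initialized before timing.
for _ in range(2):
    onnx_pipe(["warmup text"])

start = time.time()
onnx_pipe(sample)
print(f"Execution time: {time.time() - start:.3f}s")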

  • Tom Aarsen

@polodealvarado (Author) commented Dec 6, 2023

Hello @tomaarsen. I just uploaded the SpanMarkerOnnx class to the onnx folder within span_marker.

I've added comments to the class and functions; when you can, take a look and let me know.

Here is a brief summary:
The SpanMarkerOnnx class consists of 4 fundamental elements: the onnx_encoder, the onnx_classifier, the config, and the tokenizer.

The tokenizer, as in any ONNX model pipeline, is necessary.
The config is needed to load the variables required by predict, which is almost identical to that of the base model.
The encoder and the classifier are not merged into a single ONNX model due to the for-loop that we have already discussed:

for i in range(batch_size):
    feature_vector[
        i, : end_marker_indices[i] - start_marker_indices[i], : last_hidden_state.shape[-1]
    ] = last_hidden_state[i, start_marker_indices[i] : end_marker_indices[i]]
    feature_vector[
        i, : end_marker_indices[i] - start_marker_indices[i], last_hidden_state.shape[-1] :
    ] = last_hidden_state[i, end_marker_indices[i] : end_marker_indices[i] + num_marker_pairs[i]]

When exporting the for-loop we get the following error: Dynamic slicing on data-dependent value is not supported (you can reproduce it by running the file onnx_with_torchdynamo.py).

I have tried different solutions but none have worked for me. Until we find one, I think we could launch this first version with the encoder and classifier kept separate, since the execution times compared to the base model are noticeably better.

I have left a file "test_onnx.py" so you can run several tests comparing results and execution times between an ONNX model and a base model (by passing provider=["CUDAExecutionProvider"] you can run it on CUDA). A hypothetical usage sketch follows below.

Another thing I would like to do is run these tests on different operating systems, especially Windows and Linux.
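
For reference, here is a hypothetical usage sketch of what test_onnx.py boils down to; the constructor arguments are inferred from the description above and may not match the real signature:

from span_marker.onnx import SpanMarkerOnnx  # assumed module path

onnx_model = SpanMarkerOnnx(
    onnx_encoder="spanmarker_encoder.onnx",
    onnx_classifier="spanmarker_classifier.onnx",
    config=config,        # SpanMarker config loaded elsewhere
    tokenizer=tokenizer,  # SpanMarker tokenizer loaded elsewhere
    provider=["CUDAExecutionProvider"],  # omit to run on CPU
)
entities = onnx_model.predict(["Pedro is working in Alicante"])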

  • Antonio Polo

embeddings = []
sequence_length_last_hidden_state = last_hidden_state.size(2) * 2
# Pre-allocates the necessary space for feature_vector
feature_vector = torch.zeros(batch_size, sequence_length // 2, sequence_length_last_hidden_state)
@tomaarsen (Owner)

Suggested change
feature_vector = torch.zeros(batch_size, sequence_length // 2, sequence_length_last_hidden_state)
feature_vector = torch.zeros(batch_size, sequence_length // 2, sequence_length_last_hidden_state, device=self.device)

This is required for it to run well on non-CPU devices.

@tomaarsen (Owner)

Perhaps even also dtype=last_hidden_state.dtype, e.g. if someone loads the SpanMarker model in lower precision:

SpanMarkerModel.from_pretrained("...", torch_dtype=torch.float16)

@tomaarsen (Owner)

I'm also not sure if it's possible to run the ONNX model in lower precision, e.g. float16. For regular torch, it makes a very big difference (e.g. 2 or 3x as fast).

@polodealvarado (Author)

device=self.device is done.

On the other hand, it is possible to run ONNX in lower precision. Link

I can include a function that applies that transformation, e.g. along the lines of the sketch below.
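
For example, a conversion along these lines (a sketch using the onnxconverter-common helper and assumed file names, not necessarily what I will push):

import onnx
from onnxconverter_common import float16

# Convert an exported fp32 ONNX model to fp16; keep_io_types leaves the graph inputs/outputs in fp32.
model_fp32 = onnx.load("spanmarker_encoder.onnx")
model_fp16 = float16.convert_float_to_float16(model_fp32, keep_io_types=True)
onnx.save(model_fp16, "spanmarker_encoder_fp16.onnx")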

@polodealvarado (Author) commented Dec 12, 2023

@tomaarsen
I have pushed an implementation with fp16.

There are still some problems with ONNX Runtime when running fp16 models; it mostly depends on the hardware and the ONNX Runtime version being used.

Could you run some tests with higher batch size and reps values using the test_onnx.py file, to check inference time on your hardware?

@tomaarsen (Owner)

I ran some tests:

CUDA

  • fp16, CUDA, batch size 32, reps 3, with "warmup" predictions before the measurements:
Time results:
Batch size: 32
Torch times: [0.1742110252380371, 0.16329622268676758, 0.16433238983154297]
ONNX CPU times: [0.2011733055114746, 0.20147180557250977, 0.2001817226409912]
Avg Torch time: 0.16727987925211588
Avg ONNX CPU time: 0.2009422779083252
Results are the same: True
  • fp16, CUDA, batch size 128, reps 3, with "warmup" predictions before the measurements:
Time results:
Batch size: 128
Torch times: [0.5918703079223633, 0.5955493450164795, 0.5886590480804443]
ONNX CPU times: [0.8146817684173584, 0.7572369575500488, 0.749422550201416]
Avg Torch time: 0.5920262336730957
Avg ONNX CPU time: 0.7737804253896078
Results are the same: True
  • fp16, CUDA, batch size 1, reps 3, with "warmup" predictions before the measurements:
Time results:
Batch size: 1
Torch times: [0.05460929870605469, 0.050606727600097656, 0.05323004722595215]
ONNX CPU times: [0.050538063049316406, 0.04193520545959473, 0.04003095626831055]
Avg Torch time: 0.05281535784403483
Avg ONNX CPU time: 0.04416807492574056
Results are the same: True

CPU

  • fp16, CPU (but fp16 is only with ONNX as fp16 isn't implemented on CPU), batch size 32, reps 3, with "warmup" predictions before the measurements:
Time results:
Batch size: 32
Torch times: [8.036468029022217, 7.972146511077881, 8.011084079742432]
ONNX CPU times: [9.070796966552734, 8.906476736068726, 8.901257753372192]
Avg Torch time: 8.006566206614176
Avg ONNX CPU time: 8.959510485331217
Results are the same: True
  • fp16, CPU (but fp16 is only with ONNX as fp16 isn't implemented on CPU), batch size 8, reps 3, with "warmup" predictions before the measurements:
Time results:
Batch size: 8
Torch times: [2.109138250350952, 2.03529953956604, 2.0052287578582764]
ONNX CPU times: [2.4018442630767822, 2.3435325622558594, 2.2726736068725586]
Avg Torch time: 2.049888849258423
Avg ONNX CPU time: 2.3393501440684
Results are the same: True
  • fp16, CPU (but fp16 is only with ONNX as fp16 isn't implemented on CPU), batch size 1, reps 3, with "warmup" predictions before the measurements:
Time results:
Batch size: 1
Torch times: [0.31539320945739746, 0.31916165351867676, 0.31535863876342773]
ONNX CPU times: [0.2530245780944824, 0.27156901359558105, 0.2717933654785156]
Avg Torch time: 0.3166378339131673
Avg ONNX CPU time: 0.26546231905619305
Results are the same: True

@polodealvarado (Author)

I got results similar to what I expected, based on the cases I found in the ONNX repository.
I think we could keep this option in the backlog until we get it working.

Backlog for future versions:

  1. Merge the encoder and classifier ONNX models.
  2. Include an fp16 dtype option.

What do you think?
