diff --git a/span_marker/trainer.py b/span_marker/trainer.py
index 1fca2f1..bb664b2 100644
--- a/span_marker/trainer.py
+++ b/span_marker/trainer.py
@@ -192,6 +192,8 @@ def preprocess_dataset(
             if column not in dataset.column_names:
                 raise ValueError(f"The {dataset_name} dataset must contain a {column!r} column.")
 
+        dataset_num_proc = self.args.dataloader_num_workers or None
+
         # Drop all unused columns, only keep "tokens", "ner_tags", "document_id", "sentence_id"
         dataset = dataset.remove_columns(
             set(dataset.column_names) - set(self.OPTIONAL_COLUMNS) - set(self.REQUIRED_COLUMNS)
@@ -203,6 +205,7 @@ def preprocess_dataset(
             input_columns=("tokens", "ner_tags"),
             desc=f"Label normalizing the {dataset_name} dataset",
             batched=True,
+            num_proc=dataset_num_proc,
         )
 
         # Setting model card data based on training data
@@ -230,6 +233,7 @@ def preprocess_dataset(
             remove_columns=set(dataset.column_names) - set(self.OPTIONAL_COLUMNS),
             desc=f"Tokenizing the {dataset_name} dataset",
             fn_kwargs={"return_num_words": is_evaluate},
+            num_proc=dataset_num_proc,
         )
         # If "document_id" AND "sentence_id" exist in the training dataset
         if {"document_id", "sentence_id"} <= set(dataset.column_names):
@@ -265,6 +269,7 @@ def preprocess_dataset(
                 "model_max_length": tokenizer.model_max_length,
                 "marker_max_length": self.model.config.marker_max_length,
             },
+            num_proc=dataset_num_proc,
         )
         new_length = len(dataset)
         logger.info(
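
For context, the net effect of this patch is that the `datasets.Dataset.map` calls in `preprocess_dataset` can run with multiple worker processes. Below is a minimal, standalone sketch of that behaviour (not part of the patch): the toy mapping function and the hard-coded `dataloader_num_workers` value are illustrative only.

```python
# Illustrative sketch: how `num_proc` changes a `datasets.Dataset.map` call.
# Passing `num_proc=None` keeps the default single-process behaviour, while an
# integer > 1 spawns that many worker processes.
from datasets import Dataset

dataset = Dataset.from_dict({"tokens": [["Hello", "world"]] * 1000})

def add_num_words(example):
    # Toy mapping function standing in for the real tokenization step.
    example["num_words"] = len(example["tokens"])
    return example

# Mirrors the patch: fall back to None when no workers are configured,
# since `num_proc=0` is not accepted by `Dataset.map`.
dataloader_num_workers = 4  # hypothetical value of `self.args.dataloader_num_workers`
dataset_num_proc = dataloader_num_workers or None

dataset = dataset.map(
    add_num_words,
    desc="Counting words",
    num_proc=dataset_num_proc,  # parallelizes the map across processes
)
```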