def map_tokens_to_subtokens(
    subtoken_offsets, token_offsets, verbose: bool = False, subtokens=None, tokens=None
) -> list[Optional[int]]:
    """Assign to every subtoken the index of the token it belongs to.

    Both arguments describe character spans over the same text. For each subtoken span,
    in order, the result holds the index of one token span or ``None``:

    1. A zero-length subtoken (``start == end``) is never mapped and yields ``None``.
    2. Otherwise the first token that covers the subtoken wins. A token covers a subtoken
       if the subtoken ends no later than the token and starts no earlier than one
       character *before* the token, so a subtoken that begins on the single character
       preceding the token (e.g. span ``(12, 14)`` for a token starting at 13) still counts.
    3. If no token covers it, the first token whose start is at or after the subtoken's
       start is used. NOTE(review): this fallback does not check that the token actually
       overlaps the subtoken - confirm that this is intended.
    4. If that fails too, the entry is ``None``.

    Args:
        subtoken_offsets: Sequence of ``(start, end)`` pairs, one per subtoken. May be a list
            of pairs or a 2-d integer tensor/array iterated row by row.
        token_offsets: Sequence of ``(start, end)`` pairs, one per token.
        verbose: If True, print a trace of the matching process.
        subtokens: Optional subtoken strings, only used to label the verbose trace.
        tokens: Optional token strings, only used to label the verbose trace.

    Returns:
        A list with one entry per subtoken: the index into ``token_offsets`` or ``None``.
    """
    # Convert the token spans to plain ints once instead of comparing raw elements inside the
    # nested loops; with tensor inputs every comparison would otherwise be a tensor operation.
    token_spans = [(int(token[0]), int(token[1])) for token in token_offsets]

    mapping: list[Optional[int]] = []
    for subtoken_id, subtoken in enumerate(subtoken_offsets):
        sub_start, sub_end = int(subtoken[0]), int(subtoken[1])

        # subtokens of length 0 should not be mapped to anything
        if sub_start == sub_end:
            mapping.append(None)
            continue

        if verbose and subtokens:
            print(f"trying to match {subtokens[subtoken_id]} ({subtoken})")

        matched: Optional[int] = None

        # first pass: the subtoken is wholly contained within a token (with one character of
        # slack before the token start)
        for token_id, (tok_start, tok_end) in enumerate(token_spans):
            if verbose and tokens:
                print(f" ... does {tokens[token_id]} (#{token_id}, {token_offsets[token_id]}) match?")
            if tok_start - 1 <= sub_start and tok_end >= sub_end:
                if verbose:
                    print(" ... yes!")
                matched = token_id
                break

        # second pass: no token contains the subtoken, so fall back to the first token that
        # starts at or after the start of the subtoken
        if matched is None:
            for token_id, (tok_start, _) in enumerate(token_spans):
                if verbose and tokens:
                    print(f" ... does {tokens[token_id]} (#{token_id}, {token_offsets[token_id]}) partially match?")
                if tok_start >= sub_start:
                    if verbose:
                        print(" ... yes!")
                    matched = token_id
                    break

        # if a subtoken cannot be mapped, ``matched`` is still None at this point
        mapping.append(matched)

    return mapping
__build_transformer_model_inputs( if "bbox" in batch_encoding: model_kwargs["bbox"] = batch_encoding["bbox"].to(device, non_blocking=True) + # If we need a token-level embedding, we need to derive mappings between subtokens and flair tokens if self.token_embedding or self.needs_manual_ocr: assert sentence_lengths is not None # for type checking model_kwargs["token_lengths"] = torch.tensor(sentence_lengths, device=device) if self.tokenizer.is_fast: - word_ids_list = [batch_encoding.word_ids(i) for i in range(input_ids.size()[0])] + + if self.use_raw_text_as_input: + word_ids_list = [] + for sentence_no, sentence_tokens in enumerate(flair_tokens): + tokens = [t.text for t in sentence_tokens] + + subtoken_offsets = batch_encoding["offset_mapping"][sentence_no] + + offset = 0 + token_offsets = [] + for token in sentence_tokens: + token_offsets.append((offset, offset + len(token.text))) + offset += len(token.text) + token.whitespace_after + + verbose = True + subtokens = [] + if verbose: + subtokens = self.tokenizer.convert_ids_to_tokens(input_ids.tolist()[sentence_no]) + # old_mapping = batch_encoding.word_ids(sentence_no) + + mapping = map_tokens_to_subtokens( + subtoken_offsets=subtoken_offsets, + token_offsets=token_offsets, + subtokens=subtokens, + tokens=tokens, + verbose=verbose, + ) + + word_ids_list.append(mapping) + + # we need to find other causes of divergence - other causes? + if mapping.count(None) > 2 or True: + print("---") + print("tokens = ", tokens) + print("subtokens = ", subtokens) + print("subtoken_offsets = ", subtoken_offsets) + print("token_offsets = ", token_offsets) + print(mapping) + # print(old_mapping) # why is the old mapping incorrect? 
def test_token_subtoken_mapping(self):
    """Check that ``map_tokens_to_subtokens`` aligns subtoken spans with token spans.

    Every case lists character offsets of the tokens and of the subtokens over the same
    text and asserts the token index assigned to each subtoken. Each case computes its own
    mapping, so no assertion can accidentally pass on the result of an earlier case.
    """
    ### Test Case 1: Normal text
    # text = "BEST DENTIST EVER -"

    # Token and subtoken offsets
    # tokens = ["[FLERT]", "BEST", "DENTIST", "EVER", "-", "[FLERT]"]
    token_offsets = [(0, 7), (8, 12), (13, 20), (21, 25), (26, 27), (27, 34)]

    # subtokens = ["[CLS]", "[FLERT]", "▁BEST", "▁D", "ENT", "IST", "▁EVER", "▁-", "[FLERT]", "[SEP]", ]
    subtoken_offsets = tensor(
        [[0, 0], [0, 7], [8, 12], [12, 14], [14, 17], [17, 20], [20, 25], [25, 27], [27, 34], [0, 0]]
    )

    mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

    assert [None, 0, 1, 2, 2, 2, 3, 4, 5, None] == mapping

    ### Test Case 2: Differing tokenizations
    # text = "So don't be afraid"

    # Token and subtoken offsets
    # tokens = ["[FLERT]", "So", "do", "n't", "be", "afraid", "[FLERT]"]
    token_offsets = [(0, 7), (8, 10), (11, 13), (13, 16), (17, 19), (20, 26), (26, 33)]

    # subtokens = ["[CLS]", "[FLERT]", "▁So", "▁don", "'", "t", "▁be", "▁afraid", "[FLERT]", "[SEP]"]
    subtoken_offsets = tensor(
        [[0, 0], [0, 7], [8, 10], [10, 14], [14, 15], [15, 16], [16, 19], [19, 26], [26, 33], [0, 0]]
    )

    mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

    assert [None, 0, 1, 2, 3, 3, 4, 5, 6, None] == mapping

    ### Test Case 3: Text with punctuation and no whitespaces
    # text = "this and/or that,"

    # Token and subtoken offsets
    # tokens = ["[FLERT]", "this", "and", "/", "or", "that", ",", "[FLERT]"]
    token_offsets = [(0, 7), (8, 12), (13, 16), (16, 17), (17, 19), (20, 24), (24, 25), (25, 32)]

    # subtokens = ["[CLS]", "[FLERT]", "▁this", "▁and", "/", "or", "▁that", ",", "[FLERT]", "[SEP]"]
    subtoken_offsets = tensor(
        [[0, 0], [0, 7], [8, 12], [12, 16], [16, 17], [17, 19], [19, 24], [24, 25], [25, 32], [0, 0]]
    )

    mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

    assert [None, 0, 1, 2, 3, 4, 5, 6, 7, None] == mapping

    ### Test Case 4: Suboptimal tokenization caused by limited vocabulary without whitespace
    # text = "number of public-diplomacy officers"

    # Token and subtoken offsets
    # tokens = ['number', 'of', 'public', '-', 'diplomacy', 'officers']
    token_offsets = [(0, 6), (7, 9), (10, 16), (16, 17), (17, 26), (27, 35)]

    # new_subtokens = ['[CLS]', '▁number', '▁of', '▁public', '-', 'diploma', 'cy', '▁officers', '[SEP]']
    # old_subtokens = ['[CLS]', '▁number', '▁of', '▁public', '▁-', '▁diplomacy', '▁officers', '[SEP]']
    subtoken_offsets = tensor([[0, 0], [0, 6], [6, 9], [9, 16], [16, 17], [17, 24], [24, 26], [26, 35], [0, 0]])

    mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

    # "diploma" and "cy" both belong to the token "diplomacy" (#4)
    assert [None, 0, 1, 2, 3, 4, 4, 5, None] == mapping

    ### Test Case 5: Suboptimal tokenization in which two tokenizer words become one subtoken ("got" "ta" -> "gotta")
    # text = "I gotta have it"

    # Token and subtoken offsets
    # tokens = ['I', 'got', 'ta', 'have', 'it']
    token_offsets = [(0, 1), (2, 5), (5, 7), (8, 12), (13, 15)]

    # new subtokens = ['[CLS]', '▁I', '▁gotta', '▁have', '▁it', '[SEP]']
    # old subtokens = ['[CLS]', '▁I', '▁got', '▁ta', '▁have', '▁it', '[SEP]']
    subtoken_offsets = tensor([[0, 0], [0, 1], [1, 7], [7, 12], [12, 15], [0, 0]])

    mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

    # "▁gotta" spans the tokens "got" and "ta"; it is assigned to the first of them (#1), so
    # the token "ta" (#2) receives no subtoken at all
    assert [None, 0, 1, 3, 4, None] == mapping
23:03:47 +0900 Subject: [PATCH 3/3] GH-3400: fix mypy errors --- flair/embeddings/transformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index e85a1f2952..19ec900661 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -223,7 +223,7 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: def map_tokens_to_subtokens(subtoken_offsets, token_offsets, verbose: bool = False, subtokens=None, tokens=None): - mapping = [] + mapping: list[Optional[int]] = [] for subtoken_id, subtoken in enumerate(subtoken_offsets): # subtokens of length 0 should not be mapped to anything @@ -702,6 +702,7 @@ def __build_transformer_model_inputs( if self.use_raw_text_as_input: word_ids_list = [] + assert flair_tokens # assert that this is not None for mypy type checking for sentence_no, sentence_tokens in enumerate(flair_tokens): subtoken_offsets = batch_encoding["offset_mapping"][sentence_no] @@ -729,8 +730,6 @@ def __build_transformer_model_inputs( # word_ids is only supported for fast rust tokenizers. Some models like "xlm-mlm-ende-1024" do not have # a fast tokenizer implementation, hence we need to fall back to our own reconstruction of word_ids. - # print(word_ids_list) - if self.token_embedding: assert offsets is not None # for type checking if self.allow_long_sentences: