From d067bb5070cf7cf4c8424191a0d4b3bc2880dfe0 Mon Sep 17 00:00:00 2001
From: inakiLakunza
Date: Mon, 1 Jun 2026 10:19:07 +0200
Subject: [PATCH] [Benchmark] Add support for MaRVL, xGQA and ALM-Bench.

Register the three multilingual benchmarks and resolve generic config entries
to their dataset-specific implementations so they can be evaluated through the
standard VLMEvalKit flow.

Co-authored-by: inigopm <61738961+inigopm@users.noreply.github.com>
---
 run.py                      |  12 ++
 vlmeval/dataset/__init__.py |   6 +
 vlmeval/dataset/almbench.py | 244 ++++++++++++++++++++++++++++++++++++
 vlmeval/dataset/marvl.py    | 187 +++++++++++++++++++++++++++
 vlmeval/dataset/xgqa.py     | 150 ++++++++++++++++++++++
 5 files changed, 599 insertions(+)
 create mode 100644 vlmeval/dataset/almbench.py
 create mode 100644 vlmeval/dataset/marvl.py
 create mode 100644 vlmeval/dataset/xgqa.py

diff --git a/run.py b/run.py
index 01f3a6869..f8d629f87 100644
--- a/run.py
+++ b/run.py
@@ -175,6 +175,18 @@ def build_dataset_from_config(cfg, dataset_name):
         cls = getattr(vlmeval.dataset, cls_name)
         sig = inspect.signature(cls.__init__)
         valid_params = {k: v for k, v in config.items() if k in sig.parameters}
+        dataset_id = valid_params.get('dataset')
+        generic_dataset_classes = {
+            'ImageMCQDataset',
+            'ImageVQADataset',
+            'ImageYORNDataset',
+            'OCRBench',
+        }
+        if dataset_id is not None and cls_name in generic_dataset_classes:
+            dataset_kwargs = {k: v for k, v in valid_params.items() if k != 'dataset'}
+            resolved = build_dataset(dataset_id, **dataset_kwargs)
+            if resolved is not None:
+                return resolved
         if cls.MODALITY == 'VIDEO':
             if valid_params.get('fps', 0) > 0 and valid_params.get('nframe', 0) > 0:
                 raise ValueError('fps and nframe should not be set at the same time')
diff --git a/vlmeval/dataset/__init__.py b/vlmeval/dataset/__init__.py
index 33ae51f98..295ae5eee 100644
--- a/vlmeval/dataset/__init__.py
+++ b/vlmeval/dataset/__init__.py
@@ -6,6 +6,7 @@
 import pandas as pd
 from vlmeval.smp import LMUDataRoot, dump, get_intermediate_file_path, load, localize_df, toliststr
 
+from .almbench import ALMBenchDataset
 from .asclepius import Asclepius
 from .av_speakerbench import AVSpeakerBench
 from .CGAVCounting.cg_av_counting import CGAVCounting
@@ -67,6 +68,7 @@
 from .m3oralbench import M3oralBenchDataset
 from .m4bench import M4Bench
 from .macbench import MaCBench
+from .marvl import MaRVL, MaRVL_id, MaRVL_sw, MaRVL_ta, MaRVL_tr, MaRVL_zh
 from .matbench import MATBench
 from .medq_deg_bench import MedQDEGBenchDataset
 from .medqbench_caption import MedqbenchCaptionDataset
@@ -159,6 +161,7 @@
 from .wildvision import WildVision
 from .worldsense import WorldSense
 from .worldvqa import WorldVQA
+from .xgqa import xGQA, xGQA_bn, xGQA_de, xGQA_en, xGQA_id, xGQA_ko, xGQA_pt, xGQA_ru, xGQA_zh
 from .xstest import XSTestDataset
 
 from .video_dataset_config import supported_video_datasets  # isort: skip
@@ -289,6 +292,9 @@ def evaluate(self, eval_file, **judge_kwargs):
     ZEROBench, SCAM, Omni3DBench, TallyQA, _3DSRBench, BMMR, AffordanceDataset,
     MMEReasoning, GOBenchDataset, SFE, ChartMimic, MMVMBench, XLRSBench,
     OmniEarthMCQBench, VisFactor, OSTDataset, OCRBench_v2, TreeBench, CVQA, M4Bench,
+    MaRVL, MaRVL_id, MaRVL_sw, MaRVL_ta, MaRVL_tr, MaRVL_zh,
+    xGQA, xGQA_bn, xGQA_de, xGQA_en, xGQA_id, xGQA_ko, xGQA_pt, xGQA_ru, xGQA_zh,
+    ALMBenchDataset,
     AyaVisionBench, TopViewRS, VLMBias, MMHELIX, MedqbenchMCQDataset, MathCanvas, MMReason,
     MedqbenchPairedDescriptionDataset, MedqbenchCaptionDataset, MedQDEGBenchDataset,
     ChartMuseum, ChartQAPro, ReasonMap_Plus,
diff --git a/vlmeval/dataset/almbench.py b/vlmeval/dataset/almbench.py
new file mode 100644
index 000000000..63bdd5b17
--- /dev/null
+++ b/vlmeval/dataset/almbench.py
@@ -0,0 +1,244 @@
+"""
+VLMEvalKit dataset class for ALM-Bench.
+"""
+
+import re
+import string
+
+import pandas as pd
+
+from ..smp import load
+from .image_base import ImageBaseDataset
+
+LANGUAGES = [
+    'Afrikaans', 'Albanian', 'Amharic', 'Armenian', 'Assamese', 'Azerbaijani',
+    'Basque', 'Belarusian', 'Bengali', 'Bhojpuri', 'Bosnian', 'Bulgarian',
+    'Catalan', 'Cebuano', 'Chinese_Simplified', 'Chinese_Traditional', 'Croatian',
+    'Czech', 'Danish', 'Dutch', 'Egyptian_Arabic', 'Emirati_Arabic', 'English',
+    'Estonian', 'Filipino', 'Finnish', 'French', 'Galician', 'Georgian',
+    'German', 'Greek', 'Gujarati', 'Hausa', 'Hawaiian', 'Hebrew', 'Hindi',
+    'Hungarian', 'Icelandic', 'Igbo', 'Indonesian', 'Irish', 'Italian',
+    'Japanese', 'Javanese', 'Kannada', 'Kazakh', 'Kinyarwanda', 'Korean',
+    'Kurdish', 'Kyrgyz', 'Lao', 'Latin', 'Latvian', 'Lithuanian',
+    'Luxembourgish', 'Macedonian', 'Malagasy', 'Malay', 'Malayalam', 'Maltese',
+    'Marathi', 'Mongolian', 'Myanmar_Burmese', 'Nepali', 'Norwegian',
+    'Odia_Oriya', 'Pashto', 'Persian', 'Polish', 'Portuguese', 'Punjabi',
+    'Romanian', 'Russian', 'Sanskrit', 'Saudi_Arabic', 'Scots_Gaelic',
+    'Serbian', 'Shona', 'Sindhi', 'Sinhala', 'Slovak', 'Slovenian', 'Somali',
+    'Spanish', 'Sundanese', 'Swahili', 'Swedish', 'Tajik', 'Tamil', 'Telugu',
+    'Thai', 'Turkish', 'Ukrainian', 'Urdu', 'Uyghur', 'Uzbek', 'Vietnamese',
+    'Welsh', 'Yiddish', 'Yoruba',
+]
+
+
+def _make_url_dicts():
+    """Return (url, md5) placeholder dicts keyed by every ALM-Bench dataset name."""
+    names = ['ALMBench'] + [f'ALMBench_{lang}' for lang in LANGUAGES]
+    return {name: '' for name in names}, {name: None for name in names}
+
+
+DATASET_URL, DATASET_MD5 = _make_url_dicts()
+
+
+def _normalise(text: str) -> str:
+    """Lowercase, strip punctuation and extra whitespace."""
+    text = str(text).lower().strip()
+    text = text.translate(str.maketrans('', '', string.punctuation))
+    text = re.sub(r'\s+', ' ', text).strip()
+    return text
+
+
+# Raw question-type spellings per family. They are normalised when the lookup
+# table is built, so aliases containing punctuation ('t/f', 'true/false') match
+# the punctuation-free form that `_normalise` produces for the incoming value.
+_FAMILY_ALIASES = {
+    'tf': ('t/f', 'true/false', 'tf', 'true false question'),
+    'mcq': ('mcqs', 'mcq', 'multiple choice', 'multiple choice questions'),
+    'short': ('svqas', 'svqa', 'short questions', 'short'),
+    'long': ('lvqas', 'lvqa', 'long question', 'long questions', 'long'),
+}
+_FAMILY_BY_ALIAS = {
+    _normalise(alias): family
+    for family, aliases in _FAMILY_ALIASES.items()
+    for alias in aliases
+}
+
+
+def _question_family(question_type: str) -> str:
+    """Map a raw question-type label to 'tf', 'mcq', 'short', 'long' or 'open'."""
+    return _FAMILY_BY_ALIAS.get(_normalise(question_type), 'open')
+
+
+def _extract_tf(text: str):
+    """Extract 'true' / 'false' from free text, or None if neither is present."""
+    norm = _normalise(text)
+    if re.search(r'\btrue\b|\byes\b|\bcorrect\b', norm):
+        return 'true'
+    if re.search(r'\bfalse\b|\bno\b|\bincorrect\b', norm):
+        return 'false'
+    return None
+
+
+def _extract_mcq_answer(answer: str) -> str:
+    """Return the answer text that precedes a parenthesised suffix, if there is one."""
+    text = str(answer).strip()
+    for delimiter in (' (', '\n('):
+        if delimiter in text:
+            return text.split(delimiter, 1)[0].strip()
+    return text
+
+
+def _soft_exact_match(prediction: str, answer: str) -> bool:
+    """Return True when both strings are equal after `_normalise`."""
+    return _normalise(prediction) == _normalise(answer)
+
+
+def _tf_match(prediction: str, answer: str, english_answer: str = '') -> bool:
+    """Score a True/False item.
+
+    The reference label is read from `english_answer` when it yields one,
+    otherwise from `answer`. If either side has no recognisable label the
+    comparison falls back to a normalised exact match.
+    """
+    pred_label = _extract_tf(prediction)
+    ans_label = _extract_tf(english_answer) if str(english_answer).strip() else None
+    if ans_label is None:
+        ans_label = _extract_tf(answer)
+    if pred_label is None or ans_label is None:
+        if english_answer and _soft_exact_match(prediction, english_answer):
+            return True
+        return _soft_exact_match(prediction, answer)
+    return pred_label == ans_label
+
+
+def _accuracy(df: pd.DataFrame) -> float:
+    """Return the mean of the 'correct' column as a percentage (0.0 if empty)."""
+    if len(df) == 0:
+        return 0.0
+    return round(df['correct'].sum() / len(df) * 100, 2)
+
+
+def _evaluate_row(row) -> bool:
+    """Score one prediction row according to its question family."""
+    qtype = _question_family(str(row.get('question_type', '')))
+    prediction = str(row['prediction'])
+    answer = str(row['answer'])
+    # A missing cell is NaN, and str(NaN) == 'nan' would otherwise be treated as
+    # a real reference answer by the True/False fallback.
+    english_answer = row.get('english_answer', '')
+    english_answer = '' if pd.isna(english_answer) else str(english_answer)
+
+    if qtype == 'tf':
+        return _tf_match(prediction, answer, english_answer)
+    if qtype == 'mcq':
+        return _soft_exact_match(prediction, _extract_mcq_answer(answer))
+    return _soft_exact_match(prediction, answer)
+
+
+class ALMBenchDataset(ImageBaseDataset):
+    """ALM-Bench VQA dataset covering the combined and per-language dataset names."""
+
+    TYPE = 'VQA'
+    MODALITY = 'IMAGE'
+    DATASET_URL = DATASET_URL
+    DATASET_MD5 = DATASET_MD5
+
+    def build_prompt(self, line):
+        """Build the image + text message list for one sample (row or row index)."""
+        if isinstance(line, int):
+            line = self.data.iloc[line]
+
+        img_paths = self.dump_image(line)
+        if not isinstance(img_paths, list):
+            img_paths = [img_paths]
+
+        question = str(line['question'])
+        family = _question_family(str(line.get('question_type', '')).strip().lower())
+
+        if family == 'tf':
+            instruction = 'Answer with True or False only.'
+        elif family == 'mcq':
+            instruction = 'Answer using only the text of the correct option.'
+        elif family == 'short':
+            instruction = 'Answer the question using a single word or short phrase.'
+        else:
+            instruction = 'Answer the question as accurately as possible.'
+
+        prompt = f'{question}\n{instruction}'
+        msgs = [dict(type='image', value=p) for p in img_paths]
+        msgs.append(dict(type='text', value=prompt))
+        return msgs
+
+    def evaluate(self, eval_file, **judge_kwargs):
+        """Score `eval_file` and write per-split and overall accuracy to a CSV.
+
+        `judge_kwargs` is accepted for interface compatibility and is not used.
+        """
+        data = load(eval_file)
+        data['correct'] = data.apply(_evaluate_row, axis=1)
+
+        rows = []
+
+        def add_rows(col_name, split_label):
+            if col_name not in data.columns:
+                return
+            for value in sorted(data[col_name].dropna().unique()):
+                sub = data[data[col_name] == value]
+                rows.append({
+                    'dataset': self.dataset_name,
+                    'split_by': split_label,
+                    'value': value,
+                    'total': len(sub),
+                    'correct': int(sub['correct'].sum()),
+                    'accuracy (%)': _accuracy(sub),
+                })
+
+        add_rows('language', 'language')
+        add_rows('category', 'category')
+        add_rows('question_type', 'question_type')
+        rows.append({
+            'dataset': self.dataset_name,
+            'split_by': 'overall',
+            'value': 'all',
+            'total': len(data),
+            'correct': int(data['correct'].sum()),
+            'accuracy (%)': _accuracy(data),
+        })
+
+        result_df = pd.DataFrame(rows)
+        result_path = eval_file.replace('.xlsx', '_ALMBench_results.csv')
+        if result_path == eval_file:
+            result_path = eval_file + '_ALMBench_results.csv'
+        result_df.to_csv(result_path, index=False)
+        print(f'\nALM-Bench results -> {result_path}')
+        print(result_df.to_string(index=False))
+        return result_df
+
+
+def _make_lang_class(lang: str):
+    """Create the single-language subclass named 'ALMBench_<lang>'."""
+    name = f'ALMBench_{lang}'
+    return type(
+        name,
+        (ALMBenchDataset,),
+        {
+            '__doc__': f'ALM-Bench - language: {lang}',
+            'DATASET_URL': {name: DATASET_URL.get(name, '')},
+            'DATASET_MD5': {name: DATASET_MD5.get(name)},
+        },
+    )
+
+
+for _lang in LANGUAGES:
+    globals()[f'ALMBench_{_lang}'] = _make_lang_class(_lang)
+
+
+class ALMBench(ALMBenchDataset):
+    """ALM-Bench restricted to the combined 'ALMBench' dataset name."""
+
+    DATASET_URL = {'ALMBench': DATASET_URL.get('ALMBench', '')}
+    DATASET_MD5 = {'ALMBench': DATASET_MD5.get('ALMBench')}
+
+
+ALM_LANGUAGES = list(LANGUAGES)
+ALM_DATASETS = ['ALMBench'] + [f'ALMBench_{lang}' for lang in LANGUAGES]
diff --git a/vlmeval/dataset/marvl.py b/vlmeval/dataset/marvl.py
new file mode 100644
index 000000000..d1f97ca63
--- /dev/null
+++ b/vlmeval/dataset/marvl.py
@@ -0,0 +1,187 @@
+"""
+VLMEvalKit dataset class for MaRVL.
+"""
+
+import re
+import string
+
+import pandas as pd
+
+from ..smp import load
+from .image_base import ImageBaseDataset
+
+LANGUAGES = ['id', 'sw', 'ta', 'tr', 'zh']
+
+DATASET_URL = {
+    'MaRVL': '',
+    'MaRVL_id': '',
+    'MaRVL_sw': '',
+    'MaRVL_ta': '',
+    'MaRVL_tr': '',
+    'MaRVL_zh': '',
+}
+
+DATASET_MD5 = {
+    'MaRVL': None,
+    'MaRVL_id': None,
+    'MaRVL_sw': None,
+    'MaRVL_ta': None,
+    'MaRVL_tr': None,
+    'MaRVL_zh': None,
+}
+
+_TRUE_TOKENS = {'true', 'yes', 'correct', 'right', '1'}
+_FALSE_TOKENS = {'false', 'no', 'wrong', 'incorrect', '0'}
+# One pattern for every label token so the token that appears FIRST in the
+# prediction decides the label (see `_extract_answer`).
+_LABEL_RE = re.compile(r'\b(?:%s)\b' % '|'.join(sorted(_TRUE_TOKENS | _FALSE_TOKENS)))
+
+
+def _extract_answer(prediction: str) -> str:
+    """Parse a free-form model prediction into 'True' or 'False'.
+
+    The earliest label token wins, so 'False, there is 1 dog' is 'False' even
+    though '1' is also a true-token. When no token is present, the capitalised
+    first word of the cleaned prediction is returned unchanged.
+    """
+    clean = str(prediction).strip().strip(string.punctuation).lower()
+    match = _LABEL_RE.search(clean)
+    if match:
+        return 'True' if match.group(0) in _TRUE_TOKENS else 'False'
+    words = clean.split()
+    first = words[0] if words else clean
+    return first.capitalize()
+
+
+def _normalise_binary_label(value) -> str:
+    """Normalise saved booleans / strings to canonical labels."""
+    text = str(value).strip()
+    if text in {'True', 'False'}:
+        return text
+    return _extract_answer(text)
+
+
+def _accuracy(df: pd.DataFrame) -> float:
+    """Return the mean of the 'correct' column as a percentage (0.0 if empty)."""
+    if len(df) == 0:
+        return 0.0
+    return round(df['correct'].sum() / len(df) * 100, 2)
+
+
+def _result_path(eval_file: str, suffix: str) -> str:
+    """Return the CSV path for results, never equal to `eval_file` itself.
+
+    A '.xlsx' extension is swapped for `suffix`; for any other extension the
+    plain replace would be a no-op and the results would overwrite the
+    prediction file, so `suffix` is appended instead.
+    """
+    result_path = eval_file.replace('.xlsx', suffix)
+    if result_path == eval_file:
+        result_path = eval_file + suffix
+    return result_path
+
+
+class MaRVLDataset(ImageBaseDataset):
+    """MaRVL dataset: decide whether a hypothesis about two images is true or false."""
+
+    TYPE = 'VQA'
+    MODALITY = 'IMAGE'
+    DATASET_URL = DATASET_URL
+    DATASET_MD5 = DATASET_MD5
+
+    def build_prompt(self, line):
+        """Build the image + text message list for one sample (row or row index)."""
+        if isinstance(line, int):
+            line = self.data.iloc[line]
+
+        img_paths = self.dump_image(line)
+        if not isinstance(img_paths, list):
+            img_paths = [img_paths]
+
+        question = str(line['question'])
+        hint = str(line.get('hint', '')) if 'hint' in line.index else ''
+
+        prompt = (
+            'You are shown two images placed side by side.\n'
+            f'Hypothesis: {question}\n'
+        )
+        if hint and hint.lower() not in ('', 'nan', 'none'):
+            prompt += f'(English translation: {hint})\n'
+
+        prompt += (
+            '\nBased on the two images, is the hypothesis TRUE or FALSE?\n'
+            'Answer with a single word: True or False.'
+        )
+
+        msgs = [dict(type='image', value=p) for p in img_paths]
+        msgs.append(dict(type='text', value=prompt))
+        return msgs
+
+    def evaluate(self, eval_file, **judge_kwargs):
+        """Score `eval_file` and write per-language and overall accuracy to a CSV.
+
+        `judge_kwargs` is accepted for interface compatibility and is not used.
+        """
+        data = load(eval_file)
+        data['prediction_normalized'] = data['prediction'].apply(_normalise_binary_label)
+        data['answer_normalized'] = data['answer'].apply(_normalise_binary_label)
+        data['correct'] = data['prediction_normalized'] == data['answer_normalized']
+
+        rows = []
+        if 'category' in data.columns:
+            # dropna: NaN mixed with strings cannot be sorted.
+            for lang in sorted(data['category'].dropna().unique()):
+                sub = data[data['category'] == lang]
+                rows.append({
+                    'dataset': self.dataset_name,
+                    'lang': lang,
+                    'total': len(sub),
+                    'correct': int(sub['correct'].sum()),
+                    'accuracy (%)': _accuracy(sub),
+                })
+
+        rows.append({
+            'dataset': self.dataset_name,
+            'lang': 'overall',
+            'total': len(data),
+            'correct': int(data['correct'].sum()),
+            'accuracy (%)': _accuracy(data),
+        })
+
+        result_df = pd.DataFrame(rows)
+        result_path = _result_path(eval_file, '_MaRVL_results.csv')
+        result_df.to_csv(result_path, index=False)
+        print(f'\nMaRVL results -> {result_path}')
+        print(result_df.to_string(index=False))
+        return result_df
+
+
+def _make_lang_class(lang: str):
+    """Create the single-language subclass named 'MaRVL_<lang>'."""
+    name = f'MaRVL_{lang}'
+    return type(
+        name,
+        (MaRVLDataset,),
+        {
+            '__doc__': f'MaRVL benchmark - language: {lang}',
+            'DATASET_URL': {name: DATASET_URL.get(name, '')},
+            'DATASET_MD5': {name: DATASET_MD5.get(name)},
+        },
+    )
+
+
+MaRVL_id = _make_lang_class('id')
+MaRVL_sw = _make_lang_class('sw')
+MaRVL_ta = _make_lang_class('ta')
+MaRVL_tr = _make_lang_class('tr')
+MaRVL_zh = _make_lang_class('zh')
+
+
+class MaRVL(MaRVLDataset):
+    """MaRVL restricted to the combined 'MaRVL' dataset name."""
+
+    DATASET_URL = {'MaRVL': DATASET_URL.get('MaRVL', '')}
+    DATASET_MD5 = {'MaRVL': DATASET_MD5.get('MaRVL')}
+
+
+MARVL_DATASETS = ['MaRVL'] + [f'MaRVL_{lang}' for lang in LANGUAGES]
diff --git a/vlmeval/dataset/xgqa.py b/vlmeval/dataset/xgqa.py
new file mode 100644
index 000000000..b3b1ee1ee
--- /dev/null
+++ b/vlmeval/dataset/xgqa.py
@@ -0,0 +1,150 @@
+"""
+VLMEvalKit dataset class for xGQA.
+"""
+
+import re
+import string
+
+import pandas as pd
+
+from ..smp import load
+from .image_base import ImageBaseDataset
+
+LANGUAGES = ['bn', 'de', 'en', 'id', 'ko', 'pt', 'ru', 'zh']
+
+DATASET_URL = {k: '' for k in ['xGQA'] + [f'xGQA_{lang_code}' for lang_code in LANGUAGES]}
+DATASET_MD5 = {k: None for k in ['xGQA'] + [f'xGQA_{lang_code}' for lang_code in LANGUAGES]}
+
+
+def _normalise(text: str) -> str:
+    """Lowercase, strip punctuation and extra whitespace."""
+    text = str(text).lower().strip()
+    text = text.translate(str.maketrans('', '', string.punctuation))
+    text = re.sub(r'\s+', ' ', text).strip()
+    return text
+
+
+def _exact_match(prediction: str, answer: str) -> bool:
+    """Return True when both strings are equal after `_normalise`."""
+    return _normalise(prediction) == _normalise(answer)
+
+
+def _accuracy(df: pd.DataFrame) -> float:
+    """Return the mean of the 'correct' column as a percentage (0.0 if empty)."""
+    if len(df) == 0:
+        return 0.0
+    return round(df['correct'].sum() / len(df) * 100, 2)
+
+
+def _result_path(eval_file: str, suffix: str) -> str:
+    """Return the CSV path for results, never equal to `eval_file` itself.
+
+    A '.xlsx' extension is swapped for `suffix`; for any other extension the
+    plain replace would be a no-op and the results would overwrite the
+    prediction file, so `suffix` is appended instead.
+    """
+    result_path = eval_file.replace('.xlsx', suffix)
+    if result_path == eval_file:
+        result_path = eval_file + suffix
+    return result_path
+
+
+class xGQADataset(ImageBaseDataset):
+    """xGQA short-answer VQA dataset scored by normalised exact match."""
+
+    TYPE = 'VQA'
+    MODALITY = 'IMAGE'
+    DATASET_URL = DATASET_URL
+    DATASET_MD5 = DATASET_MD5
+
+    def build_prompt(self, line):
+        """Build the image + text message list for one sample (row or row index)."""
+        if isinstance(line, int):
+            line = self.data.iloc[line]
+
+        img_paths = self.dump_image(line)
+        if not isinstance(img_paths, list):
+            img_paths = [img_paths]
+
+        question = str(line['question'])
+        prompt = (
+            f'{question}\n'
+            'Answer the question using a single word or short phrase.'
+        )
+
+        msgs = [dict(type='image', value=p) for p in img_paths]
+        msgs.append(dict(type='text', value=prompt))
+        return msgs
+
+    def evaluate(self, eval_file, **judge_kwargs):
+        """Score `eval_file` and write per-language and overall accuracy to a CSV.
+
+        `judge_kwargs` is accepted for interface compatibility and is not used.
+        """
+        data = load(eval_file)
+        data['correct'] = data.apply(
+            lambda row: _exact_match(row['prediction'], row['answer']),
+            axis=1,
+        )
+
+        rows = []
+        if 'category' in data.columns:
+            # dropna: NaN mixed with strings cannot be sorted.
+            for lang in sorted(data['category'].dropna().unique()):
+                sub = data[data['category'] == lang]
+                rows.append({
+                    'dataset': self.dataset_name,
+                    'lang': lang,
+                    'total': len(sub),
+                    'correct': int(sub['correct'].sum()),
+                    'accuracy (%)': _accuracy(sub),
+                })
+
+        rows.append({
+            'dataset': self.dataset_name,
+            'lang': 'overall',
+            'total': len(data),
+            'correct': int(data['correct'].sum()),
+            'accuracy (%)': _accuracy(data),
+        })
+
+        result_df = pd.DataFrame(rows)
+        result_path = _result_path(eval_file, '_xGQA_results.csv')
+        result_df.to_csv(result_path, index=False)
+        print(f'\nxGQA results -> {result_path}')
+        print(result_df.to_string(index=False))
+        return result_df
+
+
+def _make_lang_class(lang: str):
+    """Create the single-language subclass named 'xGQA_<lang>'."""
+    name = f'xGQA_{lang}'
+    return type(
+        name,
+        (xGQADataset,),
+        {
+            '__doc__': f'xGQA benchmark - language: {lang}',
+            'DATASET_URL': {name: DATASET_URL.get(name, '')},
+            'DATASET_MD5': {name: DATASET_MD5.get(name)},
+        },
+    )
+
+
+xGQA_bn = _make_lang_class('bn')
+xGQA_de = _make_lang_class('de')
+xGQA_en = _make_lang_class('en')
+xGQA_id = _make_lang_class('id')
+xGQA_ko = _make_lang_class('ko')
+xGQA_pt = _make_lang_class('pt')
+xGQA_ru = _make_lang_class('ru')
+xGQA_zh = _make_lang_class('zh')
+
+
+class xGQA(xGQADataset):
+    """xGQA restricted to the combined 'xGQA' dataset name."""
+
+    DATASET_URL = {'xGQA': DATASET_URL.get('xGQA', '')}
+    DATASET_MD5 = {'xGQA': DATASET_MD5.get('xGQA')}
+
+
+XGQA_DATASETS = ['xGQA'] + [f'xGQA_{lang}' for lang in LANGUAGES]