diff --git a/deeppavlov/core/data/utils.py b/deeppavlov/core/data/utils.py index 6d4eb88661..318f6e6a53 100644 --- a/deeppavlov/core/data/utils.py +++ b/deeppavlov/core/data/utils.py @@ -465,7 +465,7 @@ def flatten_str_batch(batch: Union[str, Iterable]) -> Union[list, chain]: ['a', 'b', 'c', 'd'] """ - if isinstance(batch, str): + if isinstance(batch, str) or isinstance(batch, int) or isinstance(batch, float): return [batch] else: return chain(*[flatten_str_batch(sample) for sample in batch]) diff --git a/deeppavlov/core/trainers/nn_trainer.py b/deeppavlov/core/trainers/nn_trainer.py index a749a67933..49dcc07fb9 100644 --- a/deeppavlov/core/trainers/nn_trainer.py +++ b/deeppavlov/core/trainers/nn_trainer.py @@ -20,6 +20,8 @@ from pathlib import Path from typing import List, Tuple, Union, Optional, Iterable +from tqdm import tqdm + from deeppavlov.core.common.errors import ConfigError from deeppavlov.core.common.registry import register from deeppavlov.core.data.data_learning_iterator import DataLearningIterator @@ -279,7 +281,8 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None: while True: impatient = False self._send_event(event_name='before_train') - for x, y_true in iterator.gen_batches(self.batch_size, data_type='train'): + log.info('The model training started') + for x, y_true in tqdm(iterator.gen_batches(self.batch_size, data_type='train')): self.last_result = self._chainer.train_on_batch(x, y_true) if self.last_result is None: self.last_result = {} diff --git a/deeppavlov/dataset_readers/basic_classification_reader.py b/deeppavlov/dataset_readers/basic_classification_reader.py index c354d2dc11..10e301b611 100644 --- a/deeppavlov/dataset_readers/basic_classification_reader.py +++ b/deeppavlov/dataset_readers/basic_classification_reader.py @@ -35,6 +35,7 @@ class BasicClassificationDatasetReader(DatasetReader): @overrides def read(self, data_path: str, url: str = None, format: str = "csv", class_sep: str = None, + float_labels: bool = False, *args, **kwargs) -> dict: """ Read dataset from data_path directory. @@ -48,6 +49,8 @@ def read(self, data_path: str, url: str = None, format: extension of files. Set of Values: ``"csv", "json"`` class_sep: string separator of labels in column with labels sep (str): delimeter for ``"csv"`` files. Default: None -> only one class per sample + float_labels (boolean): if True and class_sep is not None, we treat all classes as float + quotechar (str): what char we consider as quote in the dataset header (int): row number to use as the column names names (array): list of column names to use orient (str): indication of expected JSON string format @@ -80,7 +83,7 @@ def read(self, data_path: str, url: str = None, file = Path(data_path).joinpath(file_name) if file.exists(): if format == 'csv': - keys = ('sep', 'header', 'names') + keys = ('sep', 'header', 'names', 'quotechar') options = {k: kwargs[k] for k in keys if k in kwargs} df = pd.read_csv(file, **options) elif format == 'json': @@ -92,22 +95,27 @@ def read(self, data_path: str, url: str = None, x = kwargs.get("x", "text") y = kwargs.get('y', 'labels') - if isinstance(x, list): - if class_sep is None: - # each sample is a tuple ("text", "label") - data[data_type] = [([row[x_] for x_ in x], str(row[y])) - for _, row in df.iterrows()] - else: - # each sample is a tuple ("text", ["label", "label", ...]) - data[data_type] = [([row[x_] for x_ in x], str(row[y]).split(class_sep)) - for _, row in df.iterrows()] - else: - if class_sep is None: - # each sample is a tuple ("text", "label") - data[data_type] = [(row[x], str(row[y])) for _, row in df.iterrows()] - else: - # each sample is a tuple ("text", ["label", "label", ...]) - data[data_type] = [(row[x], str(row[y]).split(class_sep)) for _, row in df.iterrows()] + data[data_type] = [] + i = 0 + prev_n_classes = 0 # to capture samples with different n_classes + for _, row in df.iterrows(): + if isinstance(x, list): + sample = [row[x_] for x_ in x] + else: + sample = row[x] + label = str(row[y]) + if class_sep: + label = str(row[y]).split(class_sep) + if prev_n_classes == 0: + prev_n_classes = len(label) + assert len(label) == prev_n_classes, f"Wrong class number at {i} row" + if float_labels: + label = [float(k) for k in label] + if sample == sample and label == label: # not NAN + data[data_type].append((sample, label)) + else: + log.warning(f'Skipping NAN received in file {file} at {i} row') + i += 1 else: log.warning("Cannot find {} file".format(file))