diff --git a/bugwarrior/collect.py b/bugwarrior/collect.py index bcccb8f0..6ec2ffd8 100644 --- a/bugwarrior/collect.py +++ b/bugwarrior/collect.py @@ -1,15 +1,16 @@ -from collections.abc import Iterator +from collections.abc import Iterable, Iterator import copy -from functools import cache -from importlib.metadata import entry_points +import json import logging import multiprocessing import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple from jinja2 import Template from taskw.task import Task +from bugwarrior.config import get_service + if TYPE_CHECKING: from bugwarrior.config.validation import Config from bugwarrior.services import Issue, Service @@ -21,24 +22,15 @@ SERVICE_FINISHED_ERROR = 1 -@cache -def get_service(service_name: str) -> type["Service"]: - try: - (service,) = entry_points(group='bugwarrior.service', name=service_name) - except ValueError as e: - if service_name in [ - 'activecollab', - 'activecollab2', - 'megaplan', - 'teamlab', - 'versionone', - ]: - log.warning(f"The {service_name} service has been removed.") - raise ValueError( - f"Configured service '{service_name}' not found. " - "Is it installed? Or misspelled?" - ) from e - return service.load() +class CollectedIssue(NamedTuple): + task_data: dict[str, Any] + target: str + identifier: str + + +class CollectionErrorData(NamedTuple): + error_message: str + target: str def get_service_instances(conf: "Config") -> list["Service"]: @@ -81,7 +73,9 @@ def _aggregate_issues(service: "Service", queue: multiprocessing.Queue) -> None: log.info(f"Done with [{target}] in {duration}.") -def aggregate_issues(conf: "Config", debug: bool) -> Iterator[dict | tuple[str, str]]: +def aggregate_issues( + conf: "Config", debug: bool +) -> Iterator[CollectedIssue | CollectionErrorData]: """Return all issues from every target.""" log.info("Starting to aggregate remote issues.") @@ -111,22 +105,32 @@ def aggregate_issues(conf: "Config", debug: bool) -> Iterator[dict | tuple[str, while currently_running > 0: issue = queue.get(True) try: - record = TaskConstructor(issue).get_taskwarrior_record() - record['target'] = issue.config.target - yield record + yield TaskConstructor(issue).get_data_to_sync() except AttributeError: if isinstance(issue, tuple): currently_running -= 1 completion_type, target = issue if completion_type == SERVICE_FINISHED_ERROR: log.error(f"Aborted [{target}] due to critical error.") - yield ('SERVICE FAILED', target) + yield CollectionErrorData('SERVICE FAILED', target) continue raise log.info("Done aggregating remote issues.") +def make_unique_identifier( + unique_keys: Iterable[str], task_data: dict[str, Any] +) -> str: + """For a given issue, make an identifier from its unique keys. + + This is not the same as the taskwarrior uuid, which is assigned + only once the task is created. + """ + subset = {key: task_data[key] for key in unique_keys} + return json.dumps(subset, sort_keys=True) + + class TaskConstructor: """Construct a taskwarrior task from a foreign record.""" @@ -152,6 +156,10 @@ def get_taskwarrior_record(self, refined: bool = True) -> dict[str, Any]: record['tags'] = [] if refined: record['tags'].extend(self.get_added_tags()) + + # Blank priority should mean *no* priority + if record['priority'] == '': + record['priority'] = None return record def get_template_context(self) -> dict[str, Any]: @@ -168,3 +176,11 @@ def refine_record(self, record: dict[str, Any]) -> dict[str, Any]: elif field == 'description': record['description'] = self.issue.get_default_description() return record + + def get_data_to_sync(self) -> CollectedIssue: + task_data = self.get_taskwarrior_record() + return CollectedIssue( + task_data=task_data, + identifier=make_unique_identifier(self.issue.UNIQUE_KEY, task_data), + target=self.issue.config.target, + ) diff --git a/bugwarrior/command.py b/bugwarrior/command.py index ada44111..7eccc59d 100644 --- a/bugwarrior/command.py +++ b/bugwarrior/command.py @@ -11,8 +11,8 @@ from lockfile import LockTimeout from lockfile.pidlockfile import PIDLockFile -from bugwarrior.collect import aggregate_issues, get_service -from bugwarrior.config import get_config_path, get_keyring, load_config +from bugwarrior.collect import aggregate_issues +from bugwarrior.config import get_config_path, get_keyring, get_service, load_config from bugwarrior.db import get_defined_udas_as_strings, synchronize if TYPE_CHECKING: diff --git a/bugwarrior/config/__init__.py b/bugwarrior/config/__init__.py index 83fbdbe6..7c3fd9b0 100644 --- a/bugwarrior/config/__init__.py +++ b/bugwarrior/config/__init__.py @@ -17,6 +17,7 @@ UnsupportedOption, # noqa: F401 ) from .secrets import get_keyring # noqa: F401 +from .validation import get_service # noqa:F401 # NOTE: __all__ determines the stable, public API. __all__ = [BugwarriorData.__name__, MainSectionConfig.__name__, ServiceConfig.__name__] diff --git a/bugwarrior/config/validation.py b/bugwarrior/config/validation.py index eed13c78..35da987c 100644 --- a/bugwarrior/config/validation.py +++ b/bugwarrior/config/validation.py @@ -1,3 +1,5 @@ +from functools import cache +from importlib.metadata import entry_points import logging import sys from typing import TYPE_CHECKING, Annotated, Any, NoReturn, Union @@ -5,12 +7,31 @@ from pydantic import Field, TypeAdapter, ValidationError from pydantic_core import ErrorDetails -from bugwarrior.collect import get_service - from .schema import BaseConfig, Hooks, MainSectionConfig, Notifications, ServiceConfig if TYPE_CHECKING: ServiceConfigType = ServiceConfig + from bugwarrior.services import Service + + +@cache +def get_service(service_name: str) -> type["Service"]: + try: + (service,) = entry_points(group='bugwarrior.service', name=service_name) + except ValueError as e: + if service_name in [ + 'activecollab', + 'activecollab2', + 'megaplan', + 'teamlab', + 'versionone', + ]: + log.warning(f"The {service_name} service has been removed.") + raise ValueError( + f"Configured service '{service_name}' not found. " + "Is it installed? Or misspelled?" + ) from e + return service.load() log = logging.getLogger(__name__) diff --git a/bugwarrior/db.py b/bugwarrior/db.py index 89af8258..ab6675d3 100644 --- a/bugwarrior/db.py +++ b/bugwarrior/db.py @@ -1,6 +1,5 @@ from collections.abc import Collection, Iterable, Iterator import itertools -import json import logging import re import subprocess @@ -9,7 +8,8 @@ from taskw import TaskWarriorShellout from taskw.exceptions import TaskwarriorError -from bugwarrior.collect import get_service +from bugwarrior.collect import CollectedIssue, CollectionErrorData +from bugwarrior.config import get_service from bugwarrior.notifications import send_notification if TYPE_CHECKING: @@ -43,21 +43,6 @@ def get_managed_task_uuids( return expected_task_ids -def make_unique_identifier( - unique_key_sets: Iterable[Collection[str]], issue: dict[str, Any] -) -> str: - """For a given issue, make an identifier from its unique keys. - - This is not the same as the taskwarrior uuid, which is assigned - only once the task is created. - """ - for unique_keys in unique_key_sets: - if all(key in issue for key in unique_keys): - subset = {key: issue[key] for key in unique_keys} - return json.dumps(subset, sort_keys=True) - raise RuntimeError("Could not determine unique identifier for %s" % issue) - - def find_taskwarrior_uuid( tw: TaskWarriorShellout, unique_key_sets: Iterable[Collection[str]], @@ -175,7 +160,7 @@ def run_hooks(pre_import: list[str]) -> None: def synchronize( - issue_generator: Iterable[dict | tuple[str, str]], + issue_generator: Iterator[CollectedIssue | CollectionErrorData], conf: "Config", dry_run: bool = False, ) -> None: @@ -207,27 +192,25 @@ def synchronize( } for issue in issue_generator: - if isinstance(issue, tuple): - assert issue[0] == 'SERVICE FAILED', ( - "'issue' should only be a tuple in case of a failure" - ) - successful_config_map.pop(issue[1]) + if isinstance(issue, CollectionErrorData): + successful_config_map.pop(issue.target) continue # De-duplicate issues coming in - unique_identifier = make_unique_identifier(unique_key_sets, issue) - if unique_identifier in issue_map: - log.debug(f"Merging tags and skipping. Seen {unique_identifier} of {issue}") + if issue.identifier in issue_map: + log.debug(f"Merging tags and skipping. Seen {issue.identifier} of {issue}") # Merge and deduplicate tags. - issue_map[unique_identifier]['tags'] += issue['tags'] - issue_map[unique_identifier]['tags'] = list( - set(issue_map[unique_identifier]['tags']) + new_tags = sorted( + set(issue_map[issue.identifier].task_data['tags']) + | set(issue.task_data['tags']) ) + issue_map[issue.identifier].task_data['tags'] = new_tags + else: - issue_map[unique_identifier] = issue + issue_map[issue.identifier] = issue seen_uuids = set() - for issue in issue_map.values(): + for issue, target, _ in issue_map.values(): # We received this issue from The Internet, but we're not sure what # kind of encoding the service providers may have handed us. Let's try # and decode all byte strings from UTF8 off the bat. If we encounter @@ -240,12 +223,7 @@ def synchronize( except UnicodeDecodeError: log.warning("Failed to interpret %r as utf-8" % key) - # Blank priority should mean *no* priority - if issue['priority'] == '': - issue['priority'] = None - - # Target was only tacked on to pass configuration to this function. - service_config = successful_config_map[issue.pop('target')] + service_config = successful_config_map[target] try: existing_taskwarrior_uuid = find_taskwarrior_uuid( diff --git a/tests/test_db.py b/tests/test_db.py index a4faf954..f790f675 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -4,6 +4,7 @@ import taskw.task from bugwarrior import db +from bugwarrior.collect import CollectedIssue from .base import ConfigTest @@ -58,19 +59,8 @@ def test_handles_missing_tags(self): class TestSynchronize(ConfigTest): - def test_synchronize(self): - def remove_non_deterministic_keys(tasks): - for status in ['pending', 'completed']: - for task in tasks[status]: - del task['modified'] - del task['entry'] - del task['uuid'] - task['tags'] = sorted(task['tags']) - return tasks - - def get_tasks(tw): - return remove_non_deterministic_keys(tw.load_tasks()) - + def setUp(self): + super().setUp() self.config = { 'general': { 'targets': ['my_service'], @@ -84,10 +74,38 @@ def get_tasks(tw): 'token': 'abc123', }, } - bwconfig = self.validate() + self.bwconfig = self.validate() + self.tw = taskw.TaskWarrior(self.taskrc) + + def synchronize(self, issues_data): + + issue_generator = [ + CollectedIssue( + task_data=copy.deepcopy(issue_data), + target="my_service", + identifier="abcd", + ) + for issue_data in issues_data + ] + db.synchronize(iter(issue_generator), self.bwconfig) + + def remove_non_deterministic_keys(self, tasks): + for status in ['pending', 'completed']: + for task in tasks[status]: + del task['modified'] + del task['entry'] + del task['uuid'] + task['tags'] = sorted(task['tags']) - tw = taskw.TaskWarrior(self.taskrc) - self.assertEqual(tw.load_tasks(), {'completed': [], 'pending': []}) + return tasks + + def get_tasks(self): + + return self.remove_non_deterministic_keys(self.tw.load_tasks()) + + def test_synchronize(self): + + self.assertEqual(self.tw.load_tasks(), {'completed': [], 'pending': []}) issue = { 'description': 'Blah blah blah. ☃', @@ -96,7 +114,6 @@ def get_tasks(tw): 'githuburl': 'https://example.com', 'priority': 'M', 'tags': ['foo'], - 'target': 'my_service', } duplicate_issue = copy.deepcopy(issue) duplicate_issue['tags'] = ['bar'] @@ -107,11 +124,10 @@ def get_tasks(tw): # These should be de-duplicated in db.synchronize before # writing out to taskwarrior. # https://github.com/ralphbean/bugwarrior/issues/601 - issue_generator = iter((copy.deepcopy(issue), duplicate_issue)) - db.synchronize(issue_generator, bwconfig) + self.synchronize([issue, duplicate_issue]) self.assertEqual( - get_tasks(tw), + self.get_tasks(), { 'completed': [], 'pending': [ @@ -135,11 +151,10 @@ def get_tasks(tw): # Change static field issue['project'] = 'other_project' - - db.synchronize(iter((copy.deepcopy(issue),)), bwconfig) + self.synchronize([issue]) self.assertEqual( - get_tasks(tw), + self.get_tasks(), { 'completed': [], 'pending': [ @@ -159,11 +174,11 @@ def get_tasks(tw): ) # TEST CLOSED ISSUE. - db.synchronize(iter(()), bwconfig) + self.synchronize([]) - completed_tasks = tw.load_tasks() + completed_tasks = self.tw.load_tasks() - tasks = remove_non_deterministic_keys(copy.deepcopy(completed_tasks)) + tasks = self.remove_non_deterministic_keys(copy.deepcopy(completed_tasks)) del tasks['completed'][0]['end'] self.assertEqual( tasks, @@ -186,14 +201,14 @@ def get_tasks(tw): ) # TEST REOPENED ISSUE - db.synchronize(iter((copy.deepcopy(issue),)), bwconfig) + self.synchronize([issue]) - tasks = tw.load_tasks() + tasks = self.tw.load_tasks() self.assertEqual( completed_tasks['completed'][0]['uuid'], tasks['pending'][0]['uuid'] ) - tasks = remove_non_deterministic_keys(tasks) + tasks = self.remove_non_deterministic_keys(tasks) self.assertEqual( tasks, {