diff --git a/electrum/__init__.py b/electrum/__init__.py index e94ac78ac9d9..48309db6617a 100644 --- a/electrum/__init__.py +++ b/electrum/__init__.py @@ -17,7 +17,7 @@ class GuiImportError(ImportError): from .version import ELECTRUM_VERSION from .util import format_satoshis from .wallet import Wallet -from .storage import WalletStorage +from .stored_dict import DictStorage from .coinchooser import COIN_CHOOSERS from .network import Network, pick_random_server from .interface import Interface diff --git a/electrum/commands.py b/electrum/commands.py index 37686f9c9a49..fadc9a98d1da 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -360,7 +360,6 @@ async def password(self, password=None, new_password=None, encrypt_file=None, wa else: encrypt_file = wallet.storage.is_encrypted() wallet.update_password(password, new_password, encrypt_storage=encrypt_file) - wallet.save_db() return {'password': wallet.has_password()} @command('w') @@ -1553,7 +1552,6 @@ async def addtransaction(self, tx, wallet: Abstract_Wallet = None): tx = Transaction(tx) if not wallet.adb.add_transaction(tx): return False - wallet.save_db() return tx.txid() @command('w') @@ -1667,7 +1665,6 @@ async def removelocaltx(self, txid, wallet: Abstract_Wallet = None): f'Only local transactions can be removed. ' f'This tx has height: {height} != {TX_HEIGHT_LOCAL}') wallet.adb.remove_transaction(txid) - wallet.save_db() @command('wn') async def get_tx_status(self, txid, wallet: Abstract_Wallet = None): diff --git a/electrum/crypto.py b/electrum/crypto.py index ad46e38b3783..1f77834edcbb 100644 --- a/electrum/crypto.py +++ b/electrum/crypto.py @@ -134,9 +134,10 @@ def strip_PKCS7_padding(data: bytes) -> bytes: return data[0:-padlen] -def aes_encrypt_with_iv(key: bytes, iv: bytes, data: bytes) -> bytes: +def aes_encrypt_with_iv(key: bytes, iv: bytes, data: bytes, append_pkcs7=True) -> bytes: assert_bytes(key, iv, data) - data = append_PKCS7_padding(data) + if append_pkcs7: + data = append_PKCS7_padding(data) if HAS_CRYPTODOME: e = CD_AES.new(key, CD_AES.MODE_CBC, iv).encrypt(data) elif HAS_CRYPTOGRAPHY: @@ -152,7 +153,7 @@ def aes_encrypt_with_iv(key: bytes, iv: bytes, data: bytes) -> bytes: return e -def aes_decrypt_with_iv(key: bytes, iv: bytes, data: bytes) -> bytes: +def aes_decrypt_with_iv(key: bytes, iv: bytes, data: bytes, strip_pkcs7=True) -> bytes: assert_bytes(key, iv, data) if HAS_CRYPTODOME: cipher = CD_AES.new(key, CD_AES.MODE_CBC, iv) @@ -168,9 +169,11 @@ def aes_decrypt_with_iv(key: bytes, iv: bytes, data: bytes) -> bytes: else: raise Exception("no AES backend found") try: - return strip_PKCS7_padding(data) + if strip_pkcs7: + data = strip_PKCS7_padding(data) except InvalidPadding: raise InvalidPassword() + return data def EncodeAES_bytes(secret: bytes, msg: bytes) -> bytes: diff --git a/electrum/daemon.py b/electrum/daemon.py index 0a5273129912..7dfb6ef0ed7c 100644 --- a/electrum/daemon.py +++ b/electrum/daemon.py @@ -47,7 +47,7 @@ log_exceptions, randrange, OldTaskGroup, UserFacingException, JsonRPCError, os_chmod ) from .wallet import Wallet, Abstract_Wallet -from .storage import WalletStorage +from .stored_dict import DictStorage from .wallet_db import WalletDB, WalletUnfinished from .commands import known_commands, Commands from .simple_config import SimpleConfig @@ -544,15 +544,14 @@ def _load_wallet( force_check_password: bool = False, # if set, always validate password ) -> Optional[Abstract_Wallet]: path = standardize_path(path) - storage = WalletStorage(path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) - if not storage.file_exists(): + if not os.path.exists(path): raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), path) + storage = DictStorage(path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) if storage.is_encrypted(): if not password: raise InvalidPassword('No password given') storage.decrypt(password) - # read data, pass it to db - db = WalletDB(storage.read(), storage=storage, upgrade=upgrade) + db = WalletDB(storage, upgrade=upgrade) if db.get_action(): raise WalletUnfinished(db) wallet = Wallet(db, config=config) diff --git a/electrum/gui/qml/qedaemon.py b/electrum/gui/qml/qedaemon.py index 3b3ea756253b..bf299ff78196 100644 --- a/electrum/gui/qml/qedaemon.py +++ b/electrum/gui/qml/qedaemon.py @@ -13,7 +13,7 @@ from electrum.lnchannel import ChannelState from electrum.bitcoin import is_address from electrum.bitcoin import verify_usermessage_with_address -from electrum.storage import StorageReadWriteError, WalletStorage +from electrum.stored_dict import StorageReadWriteError, DictStorage from .auth import AuthMixin, auth_protect from .qefx import QEFX @@ -333,7 +333,7 @@ def isValidWalletName(self, wallet_name: str) -> bool: wallet_path = self.wallet_path_from_wallet_name(wallet_name) # validate that the path looks sane to the filesystem: try: - temp_storage = WalletStorage(wallet_path) + temp_storage = DictStorage(wallet_path, init_db=False) except (StorageReadWriteError, WalletFileException): return False except Exception: diff --git a/electrum/gui/qml/qetxdetails.py b/electrum/gui/qml/qetxdetails.py index a5e4dabe03fd..7ea99f658e38 100644 --- a/electrum/gui/qml/qetxdetails.py +++ b/electrum/gui/qml/qetxdetails.py @@ -498,7 +498,6 @@ def removeLocalTx(self, confirm=False): return self._wallet.wallet.adb.remove_transaction(txid) - self._wallet.wallet.save_db() # NOTE: from here, the tx/txid is unknown and all properties are invalid. # UI should close TxDetails and avoid interacting with this qetxdetails instance. diff --git a/electrum/gui/qml/qewallet.py b/electrum/gui/qml/qewallet.py index eeabf5261b43..eba0e510fadf 100644 --- a/electrum/gui/qml/qewallet.py +++ b/electrum/gui/qml/qewallet.py @@ -637,7 +637,6 @@ def save_tx(self, tx: 'PartialTransaction') -> bool: self.saveTxError.emit(tx.txid(), 'conflict', _("Transaction could not be saved.") + "\n" + _("It conflicts with current history.")) return False - self.wallet.save_db() self.saveTxSuccess.emit(tx.txid()) self.historyModel.initModel(True) return True @@ -754,7 +753,8 @@ def setPassword(self, password): try: self._logger.info('setting new password') - self.wallet.update_password(current_password, password, encrypt_storage=True) + encrypt_storage = self.wallet.storage.supports_file_encryption() + self.wallet.update_password(current_password, password, encrypt_storage=encrypt_storage) # restore the invariant that all loaded wallets in qml must be unlocked: self.wallet.unlock(password) return True diff --git a/electrum/gui/qt/__init__.py b/electrum/gui/qt/__init__.py index 32ac61d18cbb..6308a2174636 100644 --- a/electrum/gui/qt/__init__.py +++ b/electrum/gui/qt/__init__.py @@ -507,7 +507,7 @@ def _start_wizard_to_select_or_create_wallet(self, path) -> Optional[Abstract_Wa self.logger.info('wizard dialog cancelled by user') return db.put('x3', wizard.get_wizard_data()['x3']) - db.write_and_force_consolidation() # TODO API for db is a bit weird: there should be a close method + db.storage.write() # TODO API for db is a bit weird: there should be a close method wallet = self.daemon.load_wallet(wallet_file, password, upgrade=True) return wallet diff --git a/electrum/gui/qt/history_list.py b/electrum/gui/qt/history_list.py index c1d44fff8a05..53524d4bdc4b 100644 --- a/electrum/gui/qt/history_list.py +++ b/electrum/gui/qt/history_list.py @@ -810,7 +810,6 @@ def remove_local_tx(self, tx_hash: str): if not self.main_window.question(msg=question, title=_("Please confirm")): return self.wallet.adb.remove_transaction(tx_hash) - self.wallet.save_db() # need to update at least: history_list, utxo_list, address_list self.main_window.need_update.set() diff --git a/electrum/gui/qt/invoice_list.py b/electrum/gui/qt/invoice_list.py index dd8bfb35d298..2c2ccc876ced 100644 --- a/electrum/gui/qt/invoice_list.py +++ b/electrum/gui/qt/invoice_list.py @@ -213,6 +213,5 @@ def show_log(self, key, log: Sequence[HtlcLog]): def delete_invoices(self, keys): for key in keys: - self.wallet.delete_invoice(key, write_to_disk=False) + self.wallet.delete_invoice(key) self.delete_item(key) - self.wallet.save_db() diff --git a/electrum/gui/qt/main_window.py b/electrum/gui/qt/main_window.py index 81b5ca9b3b79..03a9282b06bd 100644 --- a/electrum/gui/qt/main_window.py +++ b/electrum/gui/qt/main_window.py @@ -51,6 +51,7 @@ from electrum.gui import messages from electrum import (keystore, constants, util, bitcoin, commands, lnutil) +from electrum.stored_dict import PasswordType from electrum.bitcoin import COIN, is_address, DummyAddress from electrum.plugin import run_hook from electrum.i18n import _ @@ -698,6 +699,26 @@ def select_backup_dir(self, b): self.config.WALLET_BACKUP_DIRECTORY = dirname self.backup_dir_e.setText(dirname) + def get_storage_password(self): + if self.wallet.has_storage_encryption(): + if self.wallet.storage.is_encrypted_with_hw_device(): + password = self.wallet.keystore.get_password_for_storage_encryption() + password_type = PasswordType.XPUB + else: + password_type = PasswordType.USER + while True: + password = self.password_dialog(parent=self, msg='') + if not password: + raise UserCancelled + try: + self.wallet.storage.check_password(password) + except InvalidPassword: + continue + break + else: + password, password_type = None, None + return password, password_type + def backup_wallet(self): d = WindowModalDialog(self, _("File Backup")) vbox = QVBoxLayout(d) @@ -725,8 +746,16 @@ def backup_wallet(self): if backup_dir is None: self.show_message(_("You need to configure a backup directory in your preferences"), title=_("Backup not configured")) return + new_path = os.path.join(backup_dir, self.wallet.basename() + '.backup') + if os.path.exists(new_path): + self.show_message(f'File already exists: {new_path}') + return + try: + password, password_type = self.get_storage_password() + except UserCancelled: + return try: - new_path = self.wallet.save_backup(backup_dir) + self.wallet.save_backup(new_path, password, password_type) except BaseException as reason: self.show_critical(_("Electrum was unable to copy your wallet file to the specified location.") + "\n" + str(reason), title=_("Unable to create backup")) return @@ -1748,7 +1777,7 @@ def save_notes_text(self): def update_console(self): console = self.console - console.history = self.wallet.db.get_stored_item("qt-console-history", []) + console.history = self.wallet.db.get_list("qt-console-history") console.history_index = len(console.history) console.updateNamespace({ @@ -1922,14 +1951,13 @@ def update_buttons_on_seed(self): self.password_button.setVisible(self.wallet.may_have_password()) def change_password_dialog(self): - from electrum.stored_dict import StorageEncryptionVersion - if StorageEncryptionVersion.XPUB_PASSWORD in self.wallet.get_available_storage_encryption_versions(): + if self.wallet.is_hw_encryption_available(): from .password_dialog import ChangePasswordDialogForHW d = ChangePasswordDialogForHW(self, self.wallet) ok, old_password, new_password, encrypt_with_xpub = d.run() if not ok: return - has_xpub_encryption = self.wallet.storage.get_encryption_version() == StorageEncryptionVersion.XPUB_PASSWORD + has_xpub_encryption = self.wallet.storage.is_encrypted_with_hw_device() def on_password(hw_dev_pw): self._update_wallet_password( old_password = hw_dev_pw if has_xpub_encryption else old_password, @@ -1950,8 +1978,12 @@ def on_password(hw_dev_pw): self.update_lock_menu() def _update_wallet_password(self, *, old_password, new_password, xpub_encrypt=False): + encrypt_storage = self.wallet.storage.supports_file_encryption() try: - self.wallet.update_password(old_password, new_password, encrypt_storage=True, xpub_encrypt=xpub_encrypt) + self.wallet.update_password( + old_password, new_password, + encrypt_storage=encrypt_storage, + xpub_encrypt=xpub_encrypt) except InvalidPassword as e: self.show_error(str(e)) return @@ -2947,7 +2979,6 @@ def save_transaction_into_wallet(self, tx: Transaction): win.show_error(e) return False else: - self.wallet.save_db() # need to update at least: history_list, utxo_list, address_list self.need_update.set() msg = (_("Transaction added to wallet history.") + '\n\n' + diff --git a/electrum/gui/qt/wizard/wallet.py b/electrum/gui/qt/wizard/wallet.py index 8397b7ba31d4..85ae54a603ed 100644 --- a/electrum/gui/qt/wizard/wallet.py +++ b/electrum/gui/qt/wizard/wallet.py @@ -15,13 +15,13 @@ from electrum.i18n import _ from electrum.keystore import bip44_derivation, bip39_to_seed, purpose48_derivation, ScriptTypeNotSupported from electrum.plugin import run_hook, HardwarePluginLibraryUnavailable -from electrum.storage import StorageReadWriteError +from electrum.storage import StorageReadWriteError, StorageException from electrum.util import WalletFileException, get_new_wallet_name, UserFacingException, InvalidPassword from electrum.util import is_subpath, ChoiceItem, multisig_type, UserCancelled, standardize_path from electrum.wallet import wallet_types from .wizard import QEAbstractWizard, WizardComponent from electrum.logging import get_logger, Logger -from electrum import WalletStorage, mnemonic, keystore +from electrum import DictStorage, mnemonic, keystore from electrum.wallet_db import WalletDB from electrum.wizard import NewWalletWizard, KeystoreWizard, WizardViewState @@ -176,7 +176,7 @@ def is_finalized(self, wizard_data: dict) -> bool: wallet_file = wizard_data['wallet_name'] - storage = WalletStorage(wallet_file) + storage = DictStorage(wallet_file) assert storage.file_exists(), f"file {wallet_file!r} does not exist" if not storage.is_encrypted_with_user_pw() and not storage.is_encrypted_with_hw_device(): return True @@ -280,7 +280,7 @@ def __init__(self, parent, wizard): self.layout().addLayout(hbox2) self.layout().addStretch(1) - temp_storage = None # type: Optional[WalletStorage] + temp_storage = None # type: Optional[DictStorage] datadir_wallet_folder = self.wizard.config.get_datadir_wallet_path() def relative_path(path): @@ -313,12 +313,12 @@ def on_filename(filename_or_path): wallet_from_memory = self.wizard._daemon.get_wallet(_path) try: if wallet_from_memory: - temp_storage = wallet_from_memory.storage # type: Optional[WalletStorage] + temp_storage = wallet_from_memory.storage # type: Optional[DictStorage] self.wallet_is_open = True else: - temp_storage = WalletStorage(_path) + temp_storage = DictStorage(_path, init_db=False) self.wallet_exists = temp_storage.file_exists() - except (StorageReadWriteError, WalletFileException) as e: + except (StorageReadWriteError, StorageException) as e: msg = _('Cannot read file') + f'\n{repr(e)}' except Exception as e: self.logger.exception('') @@ -327,7 +327,7 @@ def on_filename(filename_or_path): msg = "" self.valid = temp_storage is not None user_needs_to_enter_password = False - if temp_storage: + if temp_storage is not None: if not temp_storage.file_exists(): msg = _("This file does not exist.") + '\n' \ + _("Press 'Next' to create this wallet, or choose another file.") @@ -1376,7 +1376,7 @@ def validate(self): def check_hw_decrypt(self): wallet_file = self.wizard_data['wallet_name'] - storage = WalletStorage(wallet_file) + storage = DictStorage(wallet_file) if not storage.is_encrypted_with_hw_device(): return True diff --git a/electrum/gui/stdio.py b/electrum/gui/stdio.py index 34eab29c3323..908c9bbf79e8 100644 --- a/electrum/gui/stdio.py +++ b/electrum/gui/stdio.py @@ -6,7 +6,7 @@ from electrum.gui import BaseElectrumGui from electrum import util -from electrum import WalletStorage, Wallet +from electrum import DictStorage, Wallet from electrum.wallet import Abstract_Wallet from electrum.wallet_db import WalletDB from electrum.util import format_satoshis, EventListener, event_listener @@ -26,7 +26,7 @@ class ElectrumGui(BaseElectrumGui, EventListener): def __init__(self, *, config, daemon, plugins): BaseElectrumGui.__init__(self, config=config, daemon=daemon, plugins=plugins) self.network = daemon.network - storage = WalletStorage(config.get_wallet_path()) + storage = DictStorage(config.get_wallet_path()) password = None if not storage.file_exists(): print("Wallet not found. try 'electrum create'") diff --git a/electrum/gui/text.py b/electrum/gui/text.py index 1c86193100f5..42d203b399ad 100644 --- a/electrum/gui/text.py +++ b/electrum/gui/text.py @@ -21,7 +21,7 @@ from electrum.transaction import PartialTxOutput from electrum.wallet import Wallet, Abstract_Wallet from electrum.wallet_db import WalletDB -from electrum.storage import WalletStorage +from electrum.stored_dict import DictStorage from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed, ProxySettings from electrum.interface import ServerAddr from electrum.invoices import Invoice @@ -62,7 +62,7 @@ class ElectrumGui(BaseElectrumGui, EventListener): def __init__(self, *, config: 'SimpleConfig', daemon: 'Daemon', plugins: 'Plugins'): BaseElectrumGui.__init__(self, config=config, daemon=daemon, plugins=plugins) self.network = daemon.network - storage = WalletStorage(config.get_wallet_path()) + storage = DictStorage(config.get_wallet_path()) password = None if not storage.file_exists(): print("Wallet not found. try 'electrum create'") diff --git a/electrum/invoices.py b/electrum/invoices.py index 55ad4fc5f0b9..8431713adb36 100644 --- a/electrum/invoices.py +++ b/electrum/invoices.py @@ -236,7 +236,7 @@ def get_id(self) -> str: else: # on-chain return get_id_from_onchain_outputs(outputs=self.get_outputs(), timestamp=self.time) - def as_dict(self, status): + def export(self, status): d = { 'is_lightning': self.is_lightning(), 'amount_BTC': format_satoshis(self.get_amount_sat()), @@ -298,7 +298,7 @@ def can_be_paid_onchain(self) -> bool: return True def to_debug_json(self) -> Dict[str, Any]: - d = self.to_json() + d = self.as_dict() d["lnaddr"] = self._lnaddr.to_debug_json() return d diff --git a/electrum/json_db.py b/electrum/json_db.py index dbf3c3dd861e..9c720882e698 100644 --- a/electrum/json_db.py +++ b/electrum/json_db.py @@ -26,19 +26,17 @@ import copy import json from typing import TYPE_CHECKING, Optional, Sequence, List, Union, Dict, Any +from contextlib import contextmanager import jsonpatch import jsonpointer -from . import util -from .util import WalletFileException, profiler, sticky_property +from .util import profiler, sticky_property from .logging import Logger -from .stored_dict import StoredDict, _FLEX_KEY, registered_names, registered_keys, _convert_dict_key, _convert_dict_value +from .stored_dict import _FLEX_KEY, BaseDB, StorageException +from .storage import FileStorage -if TYPE_CHECKING: - from .storage import WalletStorage - # We monkeypatch exceptions in the jsonpatch package to ensure they do not contain secrets from the DB. # We often log exceptions and offer to send them to the crash reporter, so they must not contain secrets. @@ -69,12 +67,6 @@ def to_str(x: _FLEX_KEY) -> str: return '/'.join(items) -def modifier(func): - def wrapper(self, *args, **kwargs): - with self.lock: - self._modified = True - return func(self, *args, **kwargs) - return wrapper def locked(func): def wrapper(self, *args, **kwargs): @@ -84,36 +76,132 @@ def wrapper(self, *args, **kwargs): - -class JsonDB(Logger): +class JsonDB(BaseDB): def __init__( - self, - s: str, - *, - storage: Optional['WalletStorage'] = None, - encoder=None, - upgrader=None, + self, + path: Optional[str], + *, + allow_partial_writes = True, + init_db = True, ): - Logger.__init__(self) + BaseDB.__init__(self, path) + self._is_closed = True self.lock = threading.RLock() - self.storage = storage - self.encoder = encoder self.pending_changes = [] # type: List[str] - self._modified = False - # load data - data = self.load_data(s) - if upgrader: - data, was_upgraded = upgrader(data) - self._modified |= was_upgraded - # convert json to python objects - data = self._convert_dict([], data) - # convert dict to StoredDict - self.data = StoredDict(data, self) - self.data.set_parent(key='', parent=None) + self._write_batch = False + if self.path: + self.storage = FileStorage(path, allow_partial_writes=allow_partial_writes) + if init_db and not self.is_encrypted(): + # open DB if file is not encrypted + # otherwise, this will be called in self.decrypt + self.init_db() + else: + self.storage = None + self.json_data = {} + self._is_closed = False + + def set_data(self, json_str): + self.json_data = self.load_data(json_str) + + def init_db(self): + if self.storage.is_encrypted(): + assert self.storage.is_past_initial_decryption() + json_str = self.storage.read() + self.json_data = self.load_data(json_str) + self._is_closed = False # write file in case there was a db upgrade - if self.storage and self.storage.file_exists(): - self.write_and_force_consolidation() + self.write(force_consolidation=True) + + def decrypt(self, password: str): + self.storage.decrypt(password) + json_str = self.storage.read() + self.set_data(json_str) + self._is_closed = False + + def check_password(self, password): + self.storage.check_password(password) + + def supports_file_encryption(self): + return bool(self.storage) + + def get_encryption_versions(self): + return self.storage.get_encryption_versions() + + def is_encrypted(self): + return self.storage and self.storage.is_encrypted() + + def is_encrypted_with_user_pw(self) -> bool: + return self.storage and self.storage.is_encrypted_with_user_pw() + + def is_encrypted_with_hw_device(self) -> bool: + return self.storage and self.storage.is_encrypted_with_hw_device() + + def add_password(self, password: str, password_type=None): + self.storage.add_password(password, password_type=password_type) + + def update_password(self, password: str, new_password: str, new_password_type): + self.storage.update_password(password, new_password, new_password_type) + + def remove_password(self, password: str): + self.storage.remove_password(password) + + def file_exists(self): + return self.storage and self.storage.file_exists() + + def _subdict(self, path): + d = self.json_data + for k in path[1:]: + d = d[k] + return d + + def iter_keys(self, d, path): + return d.__iter__() + + def dict_len(self, d, path): + return len(d) + + def dict_contains(self, d, path, key): + return key in d + + def replace(self, d, path, key, value): + # called by setattr + self.put(d, path, key, value) + + def put(self, d, path, key, value): + is_new = key not in d + if not is_new and d[key] == value: + return + op = 'dict_add' if is_new else 'dict_replace' + self.add_pending_change(d, op, path, key, value) + + def clear(self, d, path): + path, key = path[:-1], path[-1] + self.add_pending_change(d, 'dict_clear', path, key, None) + + def get(self, d, key): + return d[key] + + def get_hint(self, path): + return self._subdict(path) + + def remove(self, d, path, key): + self.add_pending_change(d, 'dict_remove', path, key, None) + + def list_append(self, _list, path, item): + self.add_pending_change(_list, 'list_append', path, None, item) + + def list_index(self, _list, path, item): + return _list.index(item) + + def list_len(self, _list, path): + return len(_list) + + def list_clear(self, _list, path): + self.add_pending_change(_list, 'list_clear', path[:-1], path[-1], None) + + def list_remove(self, _list, path, item): + self.add_pending_change(_list, 'list_remove', path[:-1], path[-1], item) def load_data(self, s: str) -> Dict[str, Any]: if s == '': @@ -127,15 +215,14 @@ def load_data(self, s: str) -> Dict[str, Any]: elif r := self.maybe_load_incomplete_data(s): data, patches = r, [] else: - raise WalletFileException("Cannot read wallet file. (parsing failed)") + raise StorageException("Cannot read wallet file. (parsing failed)") if not isinstance(data, dict): - raise WalletFileException("Malformed wallet file (not dict)") + raise StorageException("Malformed wallet file (not dict)") if patches: # apply patches self.logger.info('found %d patches'%len(patches)) patch = jsonpatch.JsonPatch(patches) data = patch.apply(data) - self.set_modified(True) return data def maybe_load_ast_data(self, s) ->Dict[str, Any]: @@ -173,67 +260,19 @@ def maybe_load_incomplete_data(self, s): self.logger.info('found incomplete data {s[i:]}') return self.load_data(s[0:-2]) - def set_modified(self, b): - with self.lock: - self._modified = b - - def modified(self): - return self._modified - @locked - def add_patch(self, patch): - self.pending_changes.append(json.dumps(patch, cls=self.encoder)) - self.set_modified(True) + def add_pending_change(self, hint, op, path, key, value): + self.pending_changes.append((hint, op, path, key, value)) + if not self._write_batch: + self.write() - def add(self, path, key: _FLEX_KEY, value) -> None: + def db_replace(self, hint, path, key: _FLEX_KEY, value) -> None: assert isinstance(key, _FLEX_KEY), repr(key) - self.add_patch({'op': 'add', 'path': key_path(path, key), 'value': value}) + self.add_pending_change(hint, 'replace', key_path(path, key), value) - def replace(self, path, key: _FLEX_KEY, value) -> None: + def db_remove(self, hint, path, key: _FLEX_KEY) -> None: assert isinstance(key, _FLEX_KEY), repr(key) - self.add_patch({'op': 'replace', 'path': key_path(path, key), 'value': value}) - - def remove(self, path, key: _FLEX_KEY) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - self.add_patch({'op': 'remove', 'path': key_path(path, key)}) - - @locked - def get(self, key, default=None): - v = self.data.get(key) - if v is None: - v = default - return v - - @modifier - def put(self, key, value): - try: - json.dumps(key, cls=self.encoder) - json.dumps(value, cls=self.encoder) - except Exception: - self.logger.info(f"json error: cannot save {repr(key)} ({repr(value)})") - return False - if value is not None: - if self.data.get(key) != value: - self.data[key] = copy.deepcopy(value) - return True - elif key in self.data: - self.data.pop(key) - return True - return False - - @locked - def get_dict(self, name) -> dict: - # Warning: interacts un-intuitively with 'put': certain parts - # of 'data' will have pointers saved as separate variables. - if name not in self.data: - self.data[name] = {} - return self.data[name] - - @locked - def get_stored_item(self, key, default) -> dict: - if key not in self.data: - self.data[key] = default - return self.data[key] + self.add_pending_change(hint, 'remove', key_path(path, key), None) @locked def dump(self, *, human_readable: bool = True) -> str: @@ -241,62 +280,90 @@ def dump(self, *, human_readable: bool = True) -> str: 'human_readable': makes the json indented and sorted, but this is ~2x slower """ return json.dumps( - self.data, + self.json_data, indent=4 if human_readable else None, sort_keys=bool(human_readable), - cls=self.encoder, ) - def _should_convert_to_stored_dict(self, key) -> bool: - return True - - def _convert_dict_key(self, path: List[str], key: str) -> _FLEX_KEY: - return _convert_dict_key(path, key) - - def _convert_dict_value(self, path: List[str], v) -> Any: - v = _convert_dict_value(path, v) - if isinstance(v, dict): - v = self._convert_dict(path, v) - return v - - def _convert_dict(self, path: List[str], data: dict): - # recursively convert json dict to StoredDict - assert all(isinstance(x, str) for x in path), repr(path) - d = {} - for k, v in list(data.items()): - child_path = path + [k] - k = self._convert_dict_key(path, k) - v = self._convert_dict_value(child_path, v) - d[k] = v - return d + @contextmanager + def write_batch(self): + assert self._write_batch is False + self._write_batch = True + try: + yield + finally: + self._write_batch = False + self.write() + + def close(self): + # do not call write, because we may need to close the DB after an exception was raised during a batch write + self._is_closed = True + + def is_closed(self): + return self._is_closed + + def _commit_pending_changes(self): + patches = [] + for hint, op, _path, key, value in self.pending_changes: + path = key_path(_path, key) + if op == 'dict_add': + hint[key] = value + patch = {'op': 'add', 'path': path, 'value': value} + elif op == 'dict_remove': + hint.pop(key, None) + patch = {'op': 'remove', 'path': path} + elif op == 'dict_replace': + hint[key] = value + patch = {'op': 'replace', 'path': path, 'value': value} + elif op == 'dict_clear': + hint.clear() + patch = {'op': 'replace', 'path': path, 'value': {}} + elif op == 'list_append': + n = len(hint) + hint.append(value) + path = key_path(_path, str(n)) + patch = {'op': 'add', 'path': path, 'value': value} + elif op == 'list_remove': + hint.remove(value) + # we replace the whole list because indexes are deprecated + patch = {'op': 'replace', 'path': path, 'value': hint} + elif op == 'list_clear': + hint.clear() + patch = {'op': 'replace', 'path': path, 'value': []} + else: + raise Exception('unknown operation') + patches.append(patch) + self.pending_changes = [] + return patches @locked - def write(self): - if self.storage.should_do_full_write_next(): - self.write_and_force_consolidation() + def write(self, force_consolidation=False): + if self._is_closed: + raise StorageException('DB is closed') + assert self._write_batch is False + patches = self._commit_pending_changes() + if not self.storage: + return + if force_consolidation or self.storage.should_do_full_write_next(): + self._write_and_force_consolidation() else: - self._append_pending_changes() + self._append_pending_changes(patches) @locked - def _append_pending_changes(self): + def _append_pending_changes(self, patches): if threading.current_thread().daemon: raise Exception('daemon thread cannot write db') - if not self.pending_changes: + if not patches: self.logger.info('no pending changes') return - self.logger.info(f'appending {len(self.pending_changes)} pending changes') - s = ''.join([',\n' + x for x in self.pending_changes]) + self.logger.info(f'appending {len(patches)} pending changes') + s = ''.join([',\n' + json.dumps(x) for x in patches]) self.storage.append(s) - self.pending_changes = [] @locked @profiler - def write_and_force_consolidation(self): + def _write_and_force_consolidation(self): if threading.current_thread().daemon: raise Exception('daemon thread cannot write db') - if not self.modified(): - return json_str = self.dump(human_readable=not self.storage.is_encrypted()) self.storage.write(json_str) - self.pending_changes = [] - self.set_modified(False) diff --git a/electrum/keystore.py b/electrum/keystore.py index 06d0e01086bc..76ff00e09a75 100644 --- a/electrum/keystore.py +++ b/electrum/keystore.py @@ -1123,14 +1123,12 @@ def hardware_keystore(d) -> Hardware_KeyStore: f'hw_keystores: {list(hw_keystores)}') def load_keystore(db: 'WalletDB', name: str) -> KeyStore: - # deepcopy object to avoid keeping a pointer to db.data - # note: this is needed as type(wallet.db.get("keystore")) != StoredDict - d = copy.deepcopy(db.get(name, {})) + d = db.get(name) + if d is None: + raise WalletFileException('Cannot find keystore for name {}'.format(name)) t = d.get('type') if not t: - raise WalletFileException( - 'Wallet format requires update.\n' - 'Cannot find keystore for name {}'.format(name)) + raise WalletFileException('Cannot find keystore for name {}'.format(name)) keystore_constructors = {ks.type: ks for ks in [Old_KeyStore, Imported_KeyStore, BIP32_KeyStore]} keystore_constructors['hardware'] = hardware_keystore try: diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 56595eee61c8..f29d4dc5d175 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -62,10 +62,11 @@ from .lnutil import ChannelBackupStorage, ImportedChannelBackupStorage, OnchainChannelBackupStorage from .lnutil import format_short_channel_id from .fee_policy import FEERATE_PER_KW_MIN_RELAY_LIGHTNING +from .stored_dict import stored_at if TYPE_CHECKING: from .lnworker import LNWallet - from .json_db import StoredDict + from .stored_dict import StoredDict # channel flags diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index 1d36442d4108..2bdad634760f 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -6,7 +6,7 @@ from .util import bfh, with_lock if TYPE_CHECKING: - from .json_db import StoredDict + from .stored_dict import StoredDict LOG_TEMPLATE = { 'adds': {}, # "side who offered htlc" -> htlc_id -> htlc diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 4b88f87d6e7d..d149cbda41c5 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1191,7 +1191,7 @@ async def channel_establishment_flow( lnworker=self.lnworker, initial_feerate=feerate ) - temp_chan.storage['funding_inputs'] = [txin.prevout.to_json() for txin in funding_tx.inputs()] + temp_chan.storage['funding_inputs'] = [txin.prevout for txin in funding_tx.inputs()] temp_chan.storage['has_onchain_backup'] = has_onchain_backup temp_chan.storage['init_height'] = self.lnworker.network.get_local_height() temp_chan.storage['init_timestamp'] = int(time.time()) diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 8baff6182aed..8c3559380021 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1942,7 +1942,7 @@ def from_tuple(amount_msat, rhash, cltv_abs, htlc_id, timestamp) -> 'UpdateAddHt htlc_id=htlc_id, timestamp=timestamp) - def to_json(self): + def as_tuple(self): self._validate() return dataclasses.astuple(self) diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index f6e4c7beb83f..8c3de4dd55b9 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -140,6 +140,9 @@ async def check_onchain_situation(self, address: str, funding_outpoint: str) -> # early return if address has not been added yet if not self.adb.is_mine(address): return + # early return if storage has been closed + if self.adb.db.storage.is_closed(): + return # inspect_tx_candidate might have added new addresses, in which case we return early # note: maybe we should wait until adb.is_up_to_date... (?) funding_txid = funding_outpoint.split(':')[0] diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 6e65cf626854..b30f3b6afed6 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1426,7 +1426,6 @@ def save_channel(self, chan: Channel): assert type(chan) is Channel if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point: raise Exception("Tried to save channel with next_point == current_point, this should not happen") - self.wallet.save_db() util.trigger_callback('channel', self.wallet, chan) def channel_by_txo(self, txo: str) -> Optional[AbstractChannel]: @@ -2746,8 +2745,6 @@ def create_payment_info( ) self.save_preimage(payment_hash, payment_preimage, write_to_disk=False) self.save_payment_info(info, write_to_disk=False) - if write_to_disk: - self.wallet.save_db() return payment_hash def bundle_payments(self, hash_list: Sequence[bytes]) -> None: @@ -2839,8 +2836,6 @@ def save_preimage( return self.logger.debug(f"saving preimage for {payment_hash.hex()} (public={mark_as_public})") self._preimages[payment_hash.hex()] = new_tuple - if write_to_disk: - self.wallet.save_db() def get_preimage(self, payment_hash: bytes) -> Optional[bytes]: assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}" @@ -2939,8 +2934,6 @@ def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> raise Exception(f"payment_hash already in use: {info=} != {old_info=}") v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts, int(info.invoice_features) self.payment_info[info.db_key] = v - if write_to_disk: - self.wallet.save_db() def update_or_create_mpp_with_received_htlc( self, @@ -3000,7 +2993,6 @@ def set_mpp_resolution(self, payment_key: str, new_resolution: RecvMPPResolution raise ValueError(f'forbidden mpp set transition: {mpp_status.resolution} -> {new_resolution}') self.logger.info(f'set_mpp_resolution {new_resolution.name} {len(mpp_status.htlcs)=}: {payment_key=}') self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=new_resolution) - self.wallet.save_db() return self.received_mpp_htlcs[payment_key] def set_htlc_set_error( @@ -3654,7 +3646,7 @@ def remove_channel(self, chan_id): assert chan.can_be_deleted() with self.lock: self._channels.pop(chan_id) - self.db.get('channels').pop(chan_id.hex()) + self.db.get_dict('channels').pop(chan_id.hex()) self.wallet.set_reserved_addresses_for_chan(chan, reserved=False) util.trigger_callback('channels_updated', self.wallet) @@ -3783,7 +3775,6 @@ def import_channel_backup(self, data): cb = ChannelBackup(cb_storage, lnworker=self) self._channel_backups[channel_id] = cb self.wallet.set_reserved_addresses_for_chan(cb, reserved=True) - self.wallet.save_db() util.trigger_callback('channels_updated', self.wallet) self.lnwatcher.add_channel(cb) @@ -3811,7 +3802,6 @@ def remove_channel_backup(self, channel_id): with self.lock: self._channel_backups.pop(channel_id) self.wallet.set_reserved_addresses_for_chan(chan, reserved=False) - self.wallet.save_db() util.trigger_callback('channels_updated', self.wallet) @log_exceptions @@ -3901,7 +3891,6 @@ def maybe_add_backup_from_tx(self, tx): d[channel_id] = cb_storage cb = ChannelBackup(cb_storage, lnworker=self) self.wallet.set_reserved_addresses_for_chan(cb, reserved=True) - self.wallet.save_db() with self.lock: self._channel_backups[bfh(channel_id)] = cb util.trigger_callback('channels_updated', self.wallet) diff --git a/electrum/plugins/psbt_nostr/psbt_nostr.py b/electrum/plugins/psbt_nostr/psbt_nostr.py index 9299844578e9..29fee637d58e 100644 --- a/electrum/plugins/psbt_nostr/psbt_nostr.py +++ b/electrum/plugins/psbt_nostr/psbt_nostr.py @@ -297,6 +297,5 @@ def add_transaction_to_wallet( if on_failure: on_failure(str(e)) else: - self.wallet.save_db() if on_success: on_success() diff --git a/electrum/plugins/watchtower/watchtower.py b/electrum/plugins/watchtower/watchtower.py index 4b161304e017..dd25e9418e4a 100644 --- a/electrum/plugins/watchtower/watchtower.py +++ b/electrum/plugins/watchtower/watchtower.py @@ -36,6 +36,7 @@ from electrum.network import Network from electrum.address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL from electrum.wallet_db import WalletDB +from electrum.stored_dict import DictStorage from electrum.lnutil import WITNESS_TEMPLATE_RECEIVED_HTLC, WITNESS_TEMPLATE_OFFERED_HTLC from electrum.logging import Logger from electrum.util import EventListener, event_listener @@ -67,7 +68,8 @@ class WatchTower(Logger, EventListener): def __init__(self, network: 'Network'): Logger.__init__(self) self.config = network.config - wallet_db = WalletDB('', storage=None, upgrade=True) + storage = DictStorage(None) + wallet_db = WalletDB(storage) self.adb = AddressSynchronizer(wallet_db, self.config, name=self.diagnostic_name()) self.adb.start_network(network) self.callbacks = {} # address -> lambda function diff --git a/electrum/scripts/bruteforce_pw.py b/electrum/scripts/bruteforce_pw.py index 621e0e8f63bb..44cb320a54af 100755 --- a/electrum/scripts/bruteforce_pw.py +++ b/electrum/scripts/bruteforce_pw.py @@ -29,7 +29,7 @@ from functools import partial from electrum.wallet import Wallet, Abstract_Wallet -from electrum.storage import WalletStorage +from electrum.stored_dict import DictStorage from electrum.wallet_db import WalletDB from electrum.simple_config import SimpleConfig from electrum.util import InvalidPassword @@ -39,7 +39,7 @@ MAX_PASSWORD_LEN = 12 -def test_password_for_storage_encryption(storage: WalletStorage, password: str) -> bool: +def test_password_for_storage_encryption(storage: DictStorage, password: str) -> bool: try: storage.decrypt(password) except InvalidPassword: @@ -76,7 +76,7 @@ def bruteforce_loop(test_password: Callable[[str], bool]) -> str: path = sys.argv[1] config = SimpleConfig() - storage = WalletStorage(path) + storage = DictStorage(path) if not storage.file_exists(): print(f"ERROR. wallet file not found at path: {path}") sys.exit(1) @@ -84,7 +84,7 @@ def bruteforce_loop(test_password: Callable[[str], bool]) -> str: test_password = partial(test_password_for_storage_encryption, storage) print(f"wallet found: with storage encryption.") else: - db = WalletDB(storage.read(), storage=storage, upgrade=False) + db = WalletDB(storage, upgrade=False) wallet = Wallet(db, config=config) if not wallet.has_password(): print("wallet found but it is not encrypted.") diff --git a/electrum/scripts/quick_start.py b/electrum/scripts/quick_start.py index 47115efc26de..29a03807002c 100755 --- a/electrum/scripts/quick_start.py +++ b/electrum/scripts/quick_start.py @@ -6,7 +6,6 @@ from electrum.simple_config import SimpleConfig from electrum import constants from electrum.daemon import Daemon -from electrum.storage import WalletStorage from electrum.wallet import Wallet, create_new_wallet from electrum.wallet_db import WalletDB from electrum.commands import Commands diff --git a/electrum/simple_config.py b/electrum/simple_config.py index cf76ffa3f409..4b8210f9d006 100644 --- a/electrum/simple_config.py +++ b/electrum/simple_config.py @@ -705,7 +705,7 @@ def __setattr__(self, name, value): to a previously-paid address of yours that would then be included with unrelated inputs in your future payments."""), ) WALLET_PARTIAL_WRITES = ConfigVar( - 'wallet_partial_writes', default=False, type_=bool, + 'wallet_partial_writes', default=True, type_=bool, long_desc=lambda: _("""Allows partial updates to be written to disk for the wallet DB. If disabled, the full wallet file is written to disk for every change. Experimental."""), ) diff --git a/electrum/storage.py b/electrum/storage.py index 6accb9ca3167..7686eac752b8 100644 --- a/electrum/storage.py +++ b/electrum/storage.py @@ -23,20 +23,36 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import os -import threading +import io import stat import hashlib import base64 import zlib +import hmac +import struct + from typing import Optional +from secrets import token_bytes import electrum_ecc as ecc from . import crypto -from .util import (profiler, InvalidPassword, WalletFileException, bfh, standardize_path, - test_read_write_permissions, os_chmod) - +from .util import InvalidPassword, standardize_path, test_read_write_permissions, os_chmod from .logging import Logger +from .crypto import aes_encrypt_with_iv, aes_decrypt_with_iv, strip_PKCS7_padding +from .stored_dict import StorageReadWriteError, StorageException, PasswordType + + +STORAGE_VERSION = 0 +STORAGE_MAGIC_BYTES = b'Electrum' # pass this as parameter + +STORAGE_FLAG_ZIP_FIRST_BLOB = 0x01 +STORAGE_FLAGS = STORAGE_FLAG_ZIP_FIRST_BLOB + +KDF_FLAGS = 0 # update when we change the kdf +KDF_POWER = 16 # rounds = pow(2, kdf_power) +MAX_KDF_POWER = 22 +MAX_PASSWORDS = 5 def get_derivation_used_for_hw_device_encryption(): @@ -45,16 +61,44 @@ def get_derivation_used_for_hw_device_encryption(): "/1112098098'") # ascii 'BIE2' as decimal -from .stored_dict import StorageEncryptionVersion, StorageReadWriteError - +def var_int(i: int) -> bytes: + # https://en.bitcoin.it/wiki/Protocol_specification#Variable_length_integer + # https://github.com/bitcoin/bitcoin/blob/efe1ee0d8d7f82150789f1f6840f139289628a2b/src/serialize.h#L247 + # "CompactSize" + assert i >= 0, i + if i < 0xfd: + return int.to_bytes(i, length=1, byteorder="little", signed=False) + elif i <= 0xffff: + return b"\xfd" + int.to_bytes(i, length=2, byteorder="little", signed=False) + elif i <= 0xffffffff: + return b"\xfe" + int.to_bytes(i, length=4, byteorder="little", signed=False) + else: + return b"\xff" + int.to_bytes(i, length=8, byteorder="little", signed=False) + + +def read_var_int(stream): + # leaves cursor unchanged + pos = stream.tell() + x = ord(stream.read(1)) + if x == 253: + format = ' None: try: @@ -105,7 +155,7 @@ def write(self, data: str) -> None: os_chmod(temp_path, mode) # set restrictive perms *before* we write data except PermissionError as e: # tolerate NFS or similar weirdness? self.logger.warning(f"cannot chmod temp wallet file: {e!r}") - f.write(s.encode("utf-8")) + f.write(s) self.pos = f.seek(0, os.SEEK_END) f.flush() os.fsync(f.fileno()) @@ -117,14 +167,19 @@ def write(self, data: str) -> None: self.logger.info(f"saved {self.path}") def append(self, data: str) -> None: - """ append data to file. for the moment, only non-encrypted file""" + """ append data to encrypted file""" assert self._allow_partial_writes - assert not self.is_encrypted() + s, mac = self.maybe_encrypt_for_append(data) with open(self.path, "rb+") as f: pos = f.seek(0, os.SEEK_END) if pos != self.pos: raise StorageOnDiskUnexpectedlyChanged(f"expected size {self.pos}, found {pos}") - f.write(data.encode("utf-8")) + f.write(s) + f.flush() + os.fsync(f.fileno()) # this must be written before the hmac + if mac is not None: + f.seek(self.mac_offset, 0) + f.write(mac) self.pos = f.seek(0, os.SEEK_END) f.flush() os.fsync(f.fileno()) @@ -136,7 +191,6 @@ def should_do_full_write_next(self) -> bool: """If false, next action can be a partial-write ('append').""" return ( not self.file_exists() - or self.is_encrypted() or self._needs_consolidation() or not self._allow_partial_writes ) @@ -151,110 +205,292 @@ def is_past_initial_decryption(self) -> bool: if encryption is disabled completely (self.is_encrypted() == False), or if encryption is enabled but the contents have already been decrypted. """ - return not self.is_encrypted() or bool(self.pubkey) + return not self.is_encrypted() or bool(self.master_key) def is_encrypted(self) -> bool: """Return if storage encryption is currently enabled.""" - return self.get_encryption_version() != StorageEncryptionVersion.PLAINTEXT + return self._is_old_base64 or len(self.encrypted_keys) > 0 def is_encrypted_with_user_pw(self) -> bool: - return self.get_encryption_version() == StorageEncryptionVersion.USER_PASSWORD + return PasswordType.USER in self.get_encryption_versions() def is_encrypted_with_hw_device(self) -> bool: - return self.get_encryption_version() == StorageEncryptionVersion.XPUB_PASSWORD - - def get_encryption_version(self): - """Return the version of encryption used for this storage. + return PasswordType.XPUB in self.get_encryption_versions() - 0: plaintext / no encryption - - ECIES, private key derived from a password, - 1: password is provided by user - 2: password is derived from an xpub; used with hw wallets + def get_encryption_versions(self) -> list[PasswordType]: """ - return self._encryption_version - - def _init_encryption_version(self): - try: - magic = base64.b64decode(self.raw, validate=True)[0:4] - if magic == b'BIE1': - return StorageEncryptionVersion.USER_PASSWORD - elif magic == b'BIE2': - return StorageEncryptionVersion.XPUB_PASSWORD + Returns a list of encryption versions (password types) used for this storage. + Empty list if unencrypted. + """ + if self._is_old_base64: + return [self._encryption_version] + return [x[0] for x in self.encrypted_keys] + + def read_header(self): + f = open(self.path, "rb") + first_bytes = f.read(8) + if first_bytes.startswith(base64.b64encode(b'BIE')): + self._is_old_base64 = True + data = first_bytes + f.read() + self.raw = base64.b64decode(data, validate=True) + self._magic = self.raw[0:4] + if self._magic not in [b'BIE1', b'BIE2']: + raise StorageException('unknown file format') + self._encryption_version = PasswordType.USER if self._magic == b'BIE1' else PasswordType.XPUB + else: + self._is_old_base64 = False + if first_bytes != STORAGE_MAGIC_BYTES: + self.raw = first_bytes + f.read() else: - return StorageEncryptionVersion.PLAINTEXT - except Exception: - return StorageEncryptionVersion.PLAINTEXT + # magic_bytes + version + flags + salt + num_password + n*[pw_type, kdf_flags, kdf_power, encrypted_master_key] + mac + version = ord(f.read(1)) + if version != STORAGE_VERSION: + raise StorageException(f'Unsupported storage version {version}') + self._storage_flags = ord(f.read(1)) + self.salt = f.read(16) + num_passwords = ord(f.read(1)) + if num_passwords > MAX_PASSWORDS: + raise StorageException(f'Too many passwords in header: {num_passwords}') + self.encrypted_keys = [] + for i in range(num_passwords): + password_type = PasswordType(ord(f.read(1))) + kdf_flags = ord(f.read(1)) + kdf_power = ord(f.read(1)) + if kdf_power > MAX_KDF_POWER: + raise StorageException(f'KDF power too high: {kdf_power}') + encrypted_master_key = f.read(32) + self.encrypted_keys.append((password_type, kdf_flags, kdf_power, encrypted_master_key)) + self.master_key_mac = f.read(32) + header_size = f.tell() + f.seek(0) + self.header = f.read(header_size) + f.close() + + def update_header(self, is_zipped=False) -> bytes: + N = len(self.encrypted_keys) + assert N <= MAX_PASSWORDS + self._storage_flags = STORAGE_FLAGS + header = STORAGE_MAGIC_BYTES + bytes([STORAGE_VERSION, self._storage_flags]) + self.salt + bytes([N]) + for item in self.encrypted_keys: + pw_type, kdf_flags, kdf_power, encrypted_master_key = item + header += bytes([pw_type, kdf_flags, kdf_power]) + encrypted_master_key + mac = hmac.new(self.master_key, None, hashlib.sha256).digest() + assert len(mac) == 32 + header += mac + self.header = header + self.master_key_mac = mac @staticmethod - def get_eckey_from_password(password): + def get_old_eckey_from_password(password): if password is None: password = "" secret = hashlib.pbkdf2_hmac('sha512', password.encode('utf-8'), b'', iterations=1024) ec_key = ecc.ECPrivkey.from_arbitrary_size_secret(secret) return ec_key - def _get_encryption_magic(self): - v = self._encryption_version - if v == StorageEncryptionVersion.USER_PASSWORD: - return b'BIE1' - elif v == StorageEncryptionVersion.XPUB_PASSWORD: - return b'BIE2' - else: - raise WalletFileException('no encryption magic for version: %s' % v) + def get_secret_from_password(self, password: str, kdf_flags:int, rounds: int): + if password is None: + password = "" + # kdf flags are not used for the moment + return hashlib.pbkdf2_hmac('sha512', password.encode('utf-8'), self.salt, iterations=rounds) + + def read_all(self): + with open(self.path, "rb") as f: + self.raw = f.read() + + def decrypt_old(self, password) -> None: + self.read_all() + ec_key = self.get_old_eckey_from_password(password) + s = crypto.ecies_decrypt_message(ec_key, self.raw, magic=self._magic) + s = zlib.decompress(s) + # convert to new scheme + self.init_master_key() + self._add_password_to_header(password, self._encryption_version) + self.update_header() + self.decrypted = s.decode('utf8') + self.write(self.decrypted) + self._is_old_base64 = False def decrypt(self, password) -> None: - """Raises an InvalidPassword exception on invalid password""" + """May raise InvalidPassword or StorageException""" if self.is_past_initial_decryption(): return - ec_key = self.get_eckey_from_password(password) - if self.raw: - enc_magic = self._get_encryption_magic() - s = zlib.decompress(crypto.ecies_decrypt_message(ec_key, self.raw, magic=enc_magic)) - s = s.decode('utf8') - else: - s = '' - self.pubkey = ec_key.get_public_key_hex() + if self._is_old_base64: + self.decrypt_old(password) + return + # check_password may raise InvalidPassword + self.check_password(password) + self.read_all() + mac = self.raw[self.mac_offset:self.mac_offset + 32] + iv = self.raw[self.mac_offset + 32:self.mac_offset + 32 + 16] + key_e, key_m = self.master_key[0:16], self.master_key[16:32] + ciphertext = self.raw[self.mac_offset + 32 + 16:] + # truncate ciphertext if it exceeds 16-byte block boundary + remainder = len(ciphertext) % 16 + if remainder > 0: + self.truncate_file(remainder) + ciphertext = ciphertext[0:-remainder] + # decrypt. this too may raise InvalidPassword, although that would rather result from corrupted file + decrypted = aes_decrypt_with_iv(key_e, iv, ciphertext, strip_pkcs7=False) + stream = io.BytesIO(decrypted) + s = b'' + self.mac = hmac.new(key_m, b'', hashlib.sha256) + # we break the loop if the remaining bytes have not been commited + while self.mac.digest() != mac: + try: + n = read_var_int(stream) + n_size = len(var_int(n)) + blob = stream.read(n*16) + blob = strip_PKCS7_padding(blob) + blob = blob[n_size:] + self.mac.update(blob) + if len(s) == 0: + # the first blob may be zipped + if self._storage_flags & STORAGE_FLAG_ZIP_FIRST_BLOB: + blob = zlib.decompress(blob) + except Exception as e: + # the file has been corrupted + raise StorageException(str(e)) + s += blob + # truncate the file if there are remaining bytes not covered by hmac + cursor = stream.tell() + if cursor < len(decrypted): + self.truncate_file(len(decrypted) - cursor) + self.next_iv = ciphertext[cursor-16:cursor] + s = s.decode('utf8') self.decrypted = s - def encrypt_before_writing(self, plaintext: str) -> str: - s = plaintext - if self.pubkey: - self.decrypted = plaintext - s = bytes(s, 'utf8') - c = zlib.compress(s, level=zlib.Z_BEST_SPEED) - enc_magic = self._get_encryption_magic() - public_key = ecc.ECPubkey(bfh(self.pubkey)) - s = crypto.ecies_encrypt_message(public_key, c, magic=enc_magic) - s = s.decode('utf8') + def truncate_file(self, delta: int): + self.logger.info(f"truncating file {delta}") + with open(self.path, "rb+") as f: + self.pos -= delta + f.truncate(self.pos) + self.init_pos = self.pos + + def get_prefixed_blob(self, s: bytes) -> bytes: + """return data prefixed by its size (number of 16 bytes blocks required, including bytes used for size and the padding) """ + for x in [1,3,5,9]: + size = len(s) + x + n = size // 16 + 1 # add one for pkcs7 padding + header = var_int(n) + if len(header) == x: + return header + s + else: + raise Exception('blob too large for var_int') + + def init_master_key(self): + self.salt = token_bytes(16) + self.master_key = token_bytes(32) + + def encrypt_before_writing(self, plaintext: str) -> bytes: + s = bytes(plaintext, 'utf8') + if self.master_key: + if self._storage_flags & STORAGE_FLAG_ZIP_FIRST_BLOB: + s = zlib.compress(s, level=zlib.Z_BEST_SPEED) + blob = self.get_prefixed_blob(s) + key_e, key_m = self.master_key[0:16], self.master_key[16:32] + iv = token_bytes(16) + ciphertext = aes_encrypt_with_iv(key_e, iv, blob) + # save mac, key_e, key_m, and iv, for subsequent writes + self.next_iv = ciphertext[-16:] + self.mac = hmac.new(key_m, s, hashlib.sha256) + mac = self.mac.digest() + s = self.header + mac + iv + ciphertext return s + def maybe_encrypt_for_append(self, plaintext: str) -> str: + s = bytes(plaintext, 'utf8') + if self.is_encrypted(): + assert self.master_key + self.mac.update(s) + mac = self.mac.digest() + blob = self.get_prefixed_blob(s) + key_e = self.master_key[0:16] + ciphertext = aes_encrypt_with_iv(key_e, self.next_iv, blob) + self.next_iv = ciphertext[-16:] + return ciphertext, mac + else: + return s, None + + def _check_update_password(self, password: Optional[str], new_password: Optional[str], new_password_type: Optional[PasswordType]) -> None: + """ + if old_password == new_password, only check password + otherwise, check and update password + """ + assert self.is_encrypted() + # decrypt master_key and compare mac + for i, item in enumerate(self.encrypted_keys): + password_type, kdf_flags, kdf_power, encrypted_master_key = item + decrypted_master_key = self._get_decrypted_master_key(encrypted_master_key, password, kdf_flags, kdf_power) + if hmac.new(decrypted_master_key, None, hashlib.sha256).digest() == self.master_key_mac: + break + else: + raise InvalidPassword() + self.master_key = decrypted_master_key + if new_password: + if new_password != password: + assert new_password_type is not None + kdf_flags, kdf_power, encrypted_master_key = self._get_encrypted_master_key(new_password, new_password_type) + self.encrypted_keys[i] = new_password_type, kdf_flags, kdf_power, encrypted_master_key + else: + assert new_password_type is None + del self.encrypted_keys[i] + + def _get_encrypted_master_key(self, password, password_type): + # password_type not used currently. + # we could use it to make KDF dependent on it + kdf_flags, kdf_power = KDF_FLAGS, KDF_POWER + password_key = self.get_secret_from_password(password, kdf_flags, rounds=pow(2, kdf_power)) + key_e, key_m = password_key[0:16], password_key[16:32] + encrypted_master_key = aes_encrypt_with_iv(key_e, key_m, self.master_key, append_pkcs7=False) + assert len(encrypted_master_key) == 32 + return kdf_flags, kdf_power, encrypted_master_key + + def _get_decrypted_master_key(self, encrypted_master_key, password, kdf_flags, kdf_power): + password_key = self.get_secret_from_password(password, kdf_flags, rounds=pow(2, kdf_power)) + key_e, key_m = password_key[0:16], password_key[16:32] + decrypted_master_key = aes_decrypt_with_iv(key_e, key_m, encrypted_master_key, strip_pkcs7=False) + assert len(encrypted_master_key) == 32 + return decrypted_master_key + + def _add_password_to_header(self, password, password_type): + kdf_flags, kdf_power, encrypted_master_key = self._get_encrypted_master_key(password, password_type) + self.encrypted_keys.append((password_type, kdf_flags, kdf_power, encrypted_master_key)) + def check_password(self, password: Optional[str]) -> None: - """Raises an InvalidPassword exception on invalid password""" + """Raises an InvalidPassword exception on invalid password + """ if not self.is_encrypted(): if password is not None: raise InvalidPassword("password given but wallet has no password") return - if not self.is_past_initial_decryption(): - self.decrypt(password) # this sets self.pubkey - assert self.pubkey is not None - if self.pubkey != self.get_eckey_from_password(password).get_public_key_hex(): - raise InvalidPassword() + if self._is_old_base64: + if not self.is_past_initial_decryption(): + self.decrypt_old(password) # this sets self.master_key + return + self._check_update_password(password, password, None) - def set_password(self, password, enc_version=None): - """Set a password to be used for encrypting this storage.""" + def update_password(self, password, new_password, new_password_type): + self._check_update_password(password, new_password, new_password_type) + self.update_header() + + def remove_password(self, password): + """ remove password from list. disable encryption if list is empty.""" if not self.is_past_initial_decryption(): raise Exception("storage needs to be decrypted before changing password") - if enc_version is None: - enc_version = self._encryption_version - if password and enc_version != StorageEncryptionVersion.PLAINTEXT: - ec_key = self.get_eckey_from_password(password) - self.pubkey = ec_key.get_public_key_hex() - self._encryption_version = enc_version + self._check_update_password(password, None, None) + if len(self.encrypted_keys) == 0: + self.master_key = None else: - self.pubkey = None - self._encryption_version = StorageEncryptionVersion.PLAINTEXT + self.update_header() - def basename(self) -> str: - return os.path.basename(self.path) + def add_password(self, password, password_type): + """Set a password to be used for encrypting this storage.""" + assert password + if not self.is_past_initial_decryption(): + raise Exception("storage needs to be decrypted before changing password") + if len(self.encrypted_keys) == 0: + self.init_master_key() + self._add_password_to_header(password, password_type) + self.update_header() diff --git a/electrum/stored_dict.py b/electrum/stored_dict.py index 387912632007..648e321750e7 100644 --- a/electrum/stored_dict.py +++ b/electrum/stored_dict.py @@ -24,30 +24,40 @@ # SOFTWARE. import threading -import json +import os from enum import IntEnum from collections import defaultdict -from typing import TYPE_CHECKING, Optional, Sequence, List, Union, Any +from typing import Any, Optional, Tuple, Union, Iterator, Iterable, List, Sequence +from .logging import Logger -if TYPE_CHECKING: - from .json_db import JsonDB - from .storage import WalletStorage +_FLEX_KEY = str | int | None +_RaiseKeyError = object() # singleton for no-default behavior class StorageReadWriteError(Exception): pass +class StorageException(Exception): pass + +class PasswordType(IntEnum): + USER = 1 + XPUB = 2 -class StorageEncryptionVersion(IntEnum): - PLAINTEXT = 0 - USER_PASSWORD = 1 - XPUB_PASSWORD = 2 +def normalize_key(x: _FLEX_KEY) -> str: + if isinstance(x, int): + return int(x) + elif isinstance(x, str): + return x + else: + raise Exception(f"key {x=}") -def locked(func): - def wrapper(self, *args, **kwargs): - with self.lock: - return func(self, *args, **kwargs) - return wrapper +def key_to_str(x: _FLEX_KEY) -> str: + if isinstance(x, int): + return str(int(x)) + elif isinstance(x, str): + return x + else: + raise Exception(f"key {x=}") registered_names = {} @@ -78,49 +88,61 @@ def decorator(func): return func return decorator -_FLEX_KEY = str | int | None -def _walk_path(d, path): - for k in path: - if k in d: - d = d[k] - elif '*' in d: - d = d['*'] - else: - return None - return d - -def _convert_dict_key(path: List[str], key: str) -> _FLEX_KEY: - """Maybe convert key from str to python type (typically int or IntEnum)""" - assert all(isinstance(x, str) for x in path), repr(path) - r = _walk_path(registered_keys, path) - if r: - if func := r.get('self'): - key = func(key) - assert isinstance(key, _FLEX_KEY), f"unexpected type for {key=!r} at {path=}" - return key - -def _convert_dict_value(path: List[str], v) -> Any: - assert all(isinstance(x, str) for x in path), repr(path) - r = _walk_path(registered_names, path) - if r and type(r) is tuple: - _type, constructor = r - if _type == dict: - v = constructor(**v) - elif _type == tuple: - v = constructor(*v) - else: - v = constructor(v) - return v +def to_default(obj): + """Convert user-defined classes to python built-in types. + Also convert bytes to hex, so that the result is json serializable. + """ + if obj is None or isinstance(obj, (str, int, float)): + return obj + if isinstance(obj, bytes): + return obj.hex() + if hasattr(obj, 'as_str') and callable(obj.as_str): + return obj.as_str() + if hasattr(obj, 'as_dict') and callable(obj.as_dict): + obj = obj.as_dict() + if hasattr(obj, 'as_tuple') and callable(obj.as_tuple): + obj = obj.as_tuple() + if isinstance(obj, (set, frozenset)): + return [to_default(x) for x in list(obj)] + if isinstance(obj, dict): + return dict([(key_to_str(k), to_default(v)) for k, v in obj.items()]) + if isinstance(obj, list): + return [to_default(x) for x in obj] + if isinstance(obj, tuple): + return tuple([to_default(x) for x in list(obj)]) + raise Exception('unsupported type', type(obj)) + + + +class BaseDB(Logger): + + def __init__(self, path): + Logger.__init__(self) + self._write_batch = None + self.path = path + self._should_convert = True + + def file_exists(self): + raise NotImplementedError() + + def get_path(self): + return self.path + + def set_password(self, password:str): + raise NotImplementedError() + class BaseStoredObject: - _db: 'JsonDB' = None + _db: BaseDB = None _key: _FLEX_KEY = None _parent: Optional['BaseStoredObject'] = None _lock: threading.RLock = None + _path = None + _hint = None def set_db(self, db): self._db = db @@ -131,6 +153,7 @@ def set_parent(self, *, key: _FLEX_KEY, parent: Optional['BaseStoredObject']) -> assert isinstance(key, _FLEX_KEY), repr(key) self._key = key self._parent = parent + self._path = self._parent._path + [key] if parent else [''] @property def lock(self): @@ -138,31 +161,97 @@ def lock(self): @property def path(self) -> Sequence[_FLEX_KEY] | None: - # return None iff we are pruned from root - x = self - s = [x._key] - while x._parent is not None: - x = x._parent - s = [x._key] + s - if x._key != '': - return None - assert self._db is not None - return s - - def db_add(self, key: _FLEX_KEY, value) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - if self.path: - self._db.add(self.path, key, value) + return self._path + + def _to_stored_dict_or_list(self, key, value): + """convert list to StoredList, dict to StoredDict""" + if isinstance(value, list): + value = StoredList(self._db, key=key, parent=self) + elif isinstance(value, dict): + value = StoredDict(self._db, key=key, parent=self) + elif isinstance(value, tuple): + value = StoredList(self._db, key=key, parent=self) + value = tuple(value[:]) # do not expose StoredTuple to callers + return value - def db_replace(self, key: _FLEX_KEY, value) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - if self.path: - self._db.replace(self.path, key, value) + @property + def hint(self): + # cached object returned by the db (performance optimization) + if self._hint is None: + self._hint = self._db.get_hint(self._path) + return self._hint + + def db_get(self, key): + value = self._db.get(self.hint, key) + value = self._to_stored_dict_or_list(key, value) + if not self.should_convert(): + return value + value = self._convert_value(key, value) + # set db for StoredObject, because it is not set in the constructor + if isinstance(value, StoredObject): + value.set_db(self._db) + value.set_parent(key=key, parent=self) + return value + + def _convert_key(self, key: str) -> _FLEX_KEY: + """Maybe convert key from str to python type (typically int or IntEnum)""" + if self._key_converters: + if func := self._key_converters.get('self'): + key = func(key) + assert isinstance(key, _FLEX_KEY), f"unexpected type for {key=!r} at {self._path}" + return key + + def _convert_value(self, key, v) -> Any: + reg = self.get_constructor(key) + if reg: + if isinstance(v, (StoredDict, StoredList)): + v = v.dump() + _type, constructor = reg + if _type == dict: + v = constructor(**v) + elif _type == tuple: + v = constructor(*v) + else: + v = constructor(v) + return v + + def get_constructor(self, key): + if self._constructor: + r = self._constructor.get(key, self._constructor.get('*', None)) + if type(r) is tuple: + return r + + def init_constructor(self): + if self._parent is None: + self._constructor = registered_names + else: + d = self._parent._constructor + if d is None: + return + if self._key in d: + d = d[self._key] + elif '*' in d: + d = d['*'] + else: + d = None + if d and type(d) is dict: + self._constructor = d - def db_remove(self, key: _FLEX_KEY) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - if self.path: - self._db.remove(self.path, key) + def init_key_converters(self): + if self._parent is None: + self._key_converters = registered_keys + else: + d = self._parent._key_converters + if d is None: + return + if self._key in d: + d = d[self._key] + elif '*' in d: + d = d['*'] + else: + d = None + if d and type(d) is dict: + self._key_converters = d class StoredObject(BaseStoredObject): @@ -170,12 +259,12 @@ class StoredObject(BaseStoredObject): def __setattr__(self, key: str, value): assert isinstance(key, str), repr(key) - if self.path and not key.startswith('_'): + if not key.startswith('_') and self._path: if value != getattr(self, key): - self.db_replace(key, value) + self._db.replace(self.hint, self._path, key, to_default(value)) object.__setattr__(self, key, value) - def to_json(self): + def as_dict(self): d = dict(vars(self)) # don't expose/store private stuff d = {k: v for k, v in d.items() @@ -183,66 +272,124 @@ def to_json(self): return d +class StoredDict(BaseStoredObject): + """ + dict-like object that queries the DB + type conversions are performed here -_RaiseKeyError = object() # singleton for no-default behavior - + the DB object returns simple python objects: list or dict + this class converts them + """ -class StoredDict(dict, BaseStoredObject): - - def __init__(self, data: dict, db: 'JsonDB'): - self.set_db(db) - # recursively convert dicts to StoredDict - for k, v in list(data.items()): - self.__setitem__(k, v) + def __init__(self, db: BaseDB, key: _FLEX_KEY, parent): + BaseStoredObject.__init__(self) + self._db = db + self._lock = db.lock + self._parent = parent + self._key = normalize_key(key) + self._path = self._parent._path + [self._key] if parent else [''] + self._constructor = None # func or Dict[str, func] + self._key_converters = None + self.init_constructor() + self.init_key_converters() + + def should_convert(self): + return self._db._should_convert + + def write_batch(self): + return self._db.write_batch() + + def dump(self) -> dict: + data = {} + for k, v in self.items(): + if isinstance(v, (StoredDict, StoredList)): + v = v.dump() + data[k] = v + return data + + def __getitem__(self, key: _FLEX_KEY) -> Any: + key = key_to_str(key) + return self.db_get(key) + + def __setitem__(self, key: _FLEX_KEY, value: Any) -> None: + key = key_to_str(key) + if isinstance(value, StoredObject): + # side effect + value.set_db(self._db) + value.set_parent(key=key, parent=self) + if isinstance(value, (StoredList, StoredDict)): + value = value.dump() + # convert to python + value = to_default(value) + self._db.put(self.hint, self._path, key, value) - @locked - def __setitem__(self, key: _FLEX_KEY, v) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - is_new = key not in self - # early return to prevent unnecessary disk writes - if not is_new and self._db and json.dumps(v, cls=self._db.encoder) == json.dumps(self[key], cls=self._db.encoder): - return - # convert dict to StoredDict. - if type(v) == dict and (self._db is None or self._db._should_convert_to_stored_dict(key)): - v = StoredDict(v, self._db) - # convert list to StoredList - elif type(v) == list: - v = StoredList(v, self._db) - # reject sets. they do not work well with jsonpatch - elif isinstance(v, set): - raise Exception(f"Do not store sets inside jsondb. path={self.path!r}") - # set db for StoredObject, because it is not set in the constructor - if isinstance(v, StoredObject): - v.set_db(self._db) - # set parent - if isinstance(v, BaseStoredObject): - v.set_parent(key=key, parent=self) - # set item - dict.__setitem__(self, key, v) - self.db_add(key, v) if is_new else self.db_replace(key, v) - - @locked def __delitem__(self, key: _FLEX_KEY) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - r = self.get(key, None) - dict.__delitem__(self, key) - self.db_remove(key) - if isinstance(r, BaseStoredObject): - r._parent = None - - @locked - def pop(self, key: _FLEX_KEY, v=_RaiseKeyError) -> Any: - assert isinstance(key, _FLEX_KEY), repr(key) - if key not in self: - if v is _RaiseKeyError: - raise KeyError(key) - else: - return v - r = dict.pop(self, key) - self.db_remove(key) - if isinstance(r, BaseStoredObject): - r._parent = None - return r + key = key_to_str(key) + self._db.remove(self.hint, self._path, key) + + def __iter__(self) -> Iterator[str]: + return self._db.iter_keys(self.hint, self._path) + + def __len__(self) -> int: + return self._db.dict_len(self.hint, self._path) + + # ---- Dict-like extras ---- + + def __contains__(self, key: object) -> bool: + key = key_to_str(key) + return self._db.dict_contains(self.hint, self._path, key) + + def keys(self) -> Iterable[str]: + for k in self._db.iter_keys(self.hint, self._path): + yield self._convert_key(k) + + def values(self) -> Iterator[Any]: + for k in self._db.iter_keys(self.hint, self._path): + yield self[k] + + def items(self) -> Iterator[Tuple[str, Any]]: + for k in self._db.iter_keys(self.hint, self._path): + yield (self._convert_key(k), self[k]) + + def get(self, key: _FLEX_KEY, default: Any = None, add_if_missing=False) -> Any: + # If add_if_missing is True, create DB entry if it does not exist. + # This will return StoredDict/StoredList if default is dict/list + try: + return self[key] + except KeyError: + if add_if_missing: + self[key] = default + return self[key] + return default + + def clear(self) -> None: + self._db.clear(self.hint, self._path) + + def pop(self, key: _FLEX_KEY, default: Any = _RaiseKeyError) -> Any: + # This will return dict/list + try: + v = self[key] + except KeyError: + if default is _RaiseKeyError: + raise + return default + if isinstance(v, (StoredList, StoredDict)): + v = v.dump() + del self[key] + return v + + def update(self, other=(), /, **kwargs) -> None: + if isinstance(other, dict): + pairs = list(other.items()) + else: + pairs = list(other) + pairs.extend(kwargs.items()) + for k, v in pairs: + self[k] = v + + def as_dict(self) -> dict: + """used by keystore""" + return self.dump() def setdefault(self, key: _FLEX_KEY, default = None, /): assert isinstance(key, _FLEX_KEY), repr(key) @@ -251,28 +398,130 @@ def setdefault(self, key: _FLEX_KEY, default = None, /): return self[key] -class StoredList(list, BaseStoredObject): +class StoredList(BaseStoredObject): + + def __init__(self, db: BaseDB, key: _FLEX_KEY, parent): + self._db = db + self._lock = db.lock + self._parent = parent + self._key = normalize_key(key) + self._path = self._parent._path + [self._key] + self._constructor = None + self._key_converters = None + self.init_constructor() + self.init_key_converters() + + def should_convert(self): + return self._db._should_convert + + def _get_list_item(self, key: int): + key = int(key) + return self.db_get(key) + + def __getitem__(self, s: slice) -> Any: + n = self._db.list_len(self.hint, self._path) + if type(s) is int: + s = n + s if s < 0 else s + return self._get_list_item(s) + elif type(s) is slice: + start = 0 if s.start is None else s.start if s.start >= 0 else n + s.start + stop = n if s.stop is None else s.stop if s.stop >= 0 else n + s.stop + step = 1 if s.step is None else s.step + return [self._get_list_item(i) for i in range(start, stop, step)] + else: + raise Exception() + + def __len__(self): + return self._db.list_len(self.hint, self._path) + + def __iter__(self) -> Iterator[str]: + for i in range(self._db.list_len(self.hint, self._path)): + yield self._get_list_item(i) + + def append(self, value): + value = to_default(value) + self._db.list_append(self.hint, self._path, value) - def __init__(self, data, db: 'JsonDB'): - list.__init__(self, data) - self.set_db(db) + def clear(self): + self._db.list_clear(self.hint, self._path) + assert len(self) == 0 - @locked - def append(self, item): - n = len(self) - list.append(self, item) - self.db_add('%d'%n, item) + def index(self, item) -> int: + item = to_default(item) + return self._db.list_index(self.hint, self._path, item) - @locked def remove(self, item): - n = self.index(item) - list.remove(self, item) - self.db_remove('%d'%n) + item = to_default(item) + self._db.list_remove(self.hint, self._path, item) - @locked - def clear(self): - list.clear(self) - self.db_replace(None, []) + def dump(self) -> list: + data = [] + for v in self: + if isinstance(v, (dict, list)): + raise Exception() + if isinstance(v, (StoredDict, StoredList)): + v = v.dump() + data.append(v) + return data + + + +class DictStorage(StoredDict): + """ stored dict at the root of the file """ + + def __init__(self, path: str, init_db: bool = True, allow_partial_writes: bool = True): + from .json_db import JsonDB + db = JsonDB(path=path, init_db=init_db, allow_partial_writes=allow_partial_writes) + StoredDict.__init__(self, db, key='', parent=None) + + def file_exists(self): + return self._db.file_exists() + + def is_encrypted(self): + return self._db.is_encrypted() + + def decrypt(self, pw:str): + return self._db.decrypt(pw) + + def get_path(self): + return self._db.get_path() + + def add_password(self, password: str, password_type=None): + return self._db.add_password(password, password_type) + + def remove_password(self, password: str): + return self._db.remove_password(password) + + def update_password(self, password: str, new_password: str, new_password_type: PasswordType): + return self._db.update_password(password, new_password, new_password_type) + + def set_data(self, data:str): + return self._db.set_data(data) + + def get_encryption_versions(self) -> PasswordType: + return self._db.get_encryption_versions() + + def check_password(self, password): + self._db.check_password(password) + + def supports_file_encryption(self): + return self._db.supports_file_encryption() + + def is_encrypted_with_hw_device(self): + return self._db.is_encrypted_with_hw_device() + + def is_encrypted_with_user_pw(self): + return self._db.is_encrypted_with_user_pw() + + def write(self, **kwargs): + return self._db.write(**kwargs) + def close(self): + return self._db.close() + def is_closed(self): + return self._db.is_closed() + def basename(self) -> str: + path = self.get_path() + return os.path.basename(path) if path else 'no name' diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index c36abb02412e..a9e962e4c1ad 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -40,6 +40,7 @@ ) from . import lnutil from .lnutil import hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair + from .bolt11 import decode_bolt11_invoice from .stored_dict import StoredObject, stored_at from . import constants @@ -290,6 +291,7 @@ def start_network(self, network: 'Network'): for k, swap in swaps_items: if swap.is_redeemed: continue + swap._payment_hash = bytes.fromhex(k) self.add_lnwatcher_callback(swap) asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop) diff --git a/electrum/transaction.py b/electrum/transaction.py index 46b9351c765d..6225d1cbc984 100644 --- a/electrum/transaction.py +++ b/electrum/transaction.py @@ -164,6 +164,9 @@ def to_legacy_tuple(self) -> Tuple[int, str, Union[int, str]]: return TYPE_ADDRESS, self.address, self.value return TYPE_SCRIPT, self.scriptpubkey.hex(), self.value + def as_tuple(self): + return self.to_legacy_tuple() + @classmethod def from_legacy_tuple(cls, _type: int, addr: str, val: Union[int, str]) -> Union['TxOutput', 'PartialTxOutput']: if _type == TYPE_ADDRESS: @@ -305,8 +308,8 @@ def __repr__(self): def to_str(self) -> str: return f"{self.txid.hex()}:{self.out_idx}" - def to_json(self): - return [self.txid.hex(), self.out_idx] + def as_tuple(self): + return (self.txid.hex(), self.out_idx) def serialize_to_network(self) -> bytes: return self.txid[::-1] + int.to_bytes(self.out_idx, length=4, byteorder="little", signed=False) @@ -907,6 +910,9 @@ class Transaction: def __str__(self): return self.serialize() + def as_str(self): + return str(self) + def __init__(self, raw): if raw is None: self._cached_network_ser = None diff --git a/electrum/txbatcher.py b/electrum/txbatcher.py index 48279b40e0e0..f3f779b73082 100644 --- a/electrum/txbatcher.py +++ b/electrum/txbatcher.py @@ -76,11 +76,18 @@ from .transaction import PartialTransaction, PartialTxOutput, Transaction, TxOutpoint, PartialTxInput from .address_synchronizer import TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE from .lnsweep import SweepInfo -from .json_db import locked, StoredDict from .fee_policy import FeePolicy if TYPE_CHECKING: from .wallet import Abstract_Wallet + from .stored_dict import StoredDict + + +def locked(func): + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + return wrapper class TxBatcher(Logger): @@ -90,7 +97,7 @@ class TxBatcher(Logger): def __init__(self, wallet: 'Abstract_Wallet'): Logger.__init__(self) self.lock = threading.RLock() - self.storage = wallet.db.get_stored_item("tx_batches", {}) + self.storage = wallet.db.get_dict("tx_batches") self.tx_batches = {} # type: Dict[str, TxBatch] self.wallet = wallet for key, item_storage in self.storage.items(): @@ -228,7 +235,7 @@ def get_password_future(self, txid: str): class TxBatch(Logger): - def __init__(self, wallet: 'Abstract_Wallet', storage: StoredDict): + def __init__(self, wallet: 'Abstract_Wallet', storage: 'StoredDict'): Logger.__init__(self) self.wallet = wallet self.storage = storage diff --git a/electrum/util.py b/electrum/util.py index bc870ea6d971..fe92ee706858 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -68,6 +68,7 @@ from .i18n import _ from .logging import get_logger, Logger +from .stored_dict import stored_at if TYPE_CHECKING: from .network import Network, ProxySettings @@ -340,6 +341,12 @@ def default(self, obj): return obj.hex() if hasattr(obj, 'to_json') and callable(obj.to_json): return obj.to_json() + if hasattr(obj, 'to_str') and callable(obj.to_str): + return obj.to_str() + if hasattr(obj, 'as_tuple') and callable(obj.as_tuple): + return obj.as_tuple() + if hasattr(obj, 'as_dict') and callable(obj.as_dict): + return obj.as_dict() return super(MyEncoder, self).default(obj) @@ -1257,6 +1264,19 @@ class TxMinedInfo: header_hash: Optional[str] = None # hash of block that mined tx wanted_height: Optional[int] = None # in case of timelock, min abs block height + def as_tuple(self): + return (self._height, self.timestamp, self.txpos, self.header_hash) + + @staticmethod + @stored_at('/verified_tx3/*', tuple) + def from_tuple(height, timestamp, txpos, header_hash): + return TxMinedInfo( + _height=height, + timestamp=timestamp, + txpos=txpos, + header_hash=header_hash, + ) + def height(self) -> int: """Treat unverified heights as unconfirmed.""" h = self._height @@ -2286,16 +2306,16 @@ def test_read_write_permissions(path) -> None: # note: There might already be a file at 'path'. # Make sure we do NOT overwrite/corrupt that! temp_path = "%s.tmptest.%s" % (path, os.getpid()) - echo = "fs r/w test" + echo = b"fs r/w test" try: # test READ permissions for actual path if os.path.exists(path): with open(path, "rb") as f: f.read(1) # read 1 byte # test R/W sanity for "similar" path - with open(temp_path, "w", encoding='utf-8') as f: + with open(temp_path, "wb") as f: f.write(echo) - with open(temp_path, "r", encoding='utf-8') as f: + with open(temp_path, "rb") as f: echo2 = f.read() os.remove(temp_path) except Exception as e: diff --git a/electrum/wallet.py b/electrum/wallet.py index 70573de8c238..77cd03d87d5d 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -55,7 +55,7 @@ WalletFileException, BitcoinException, InvalidPassword, format_time, timestamp_to_datetime, Satoshis, Fiat, TxMinedInfo, quantize_feerate, OrderedDictWithIndex, multisig_type, parse_max_spend, OnchainHistoryItem, read_json_file, write_json_file, UserFacingException, FileImportFailed, EventListener, - event_listener + event_listener, standardize_path ) from .bitcoin import COIN, is_address, is_minikey, relayfee, dust_threshold, DummyAddress, DummyAddressUsedInTxException from .keystore import ( @@ -63,8 +63,6 @@ ) from .simple_config import SimpleConfig from .fee_policy import FeePolicy, FixedFeePolicy, FEE_RATIO_HIGH_WARNING, FEERATE_WARNING_HIGH_FEE -from .stored_dict import StorageEncryptionVersion -from .storage import WalletStorage from .wallet_db import WalletDB from .transaction import ( Transaction, TxInput, TxOutput, PartialTransaction, PartialTxInput, PartialTxOutput, TxOutpoint, Sighash @@ -83,6 +81,7 @@ from .descriptor import Descriptor from .txbatcher import TxBatcher from .submarine_swaps import MIN_SWAP_AMOUNT_SAT +from .stored_dict import DictStorage, PasswordType if TYPE_CHECKING: from .network import Network @@ -408,7 +407,7 @@ def __init__(self, db: WalletDB, *, config: SimpleConfig): self.config = config assert self.config is not None, "config must not be None" self.db = db - self.storage = db.storage # type: Optional[WalletStorage] + self.storage = db.storage # type: StoredDict # load addresses needs to be called before constructor for sanity checks db.load_addresses(self.wallet_type) self.keystore = None # type: Optional[KeyStore] # will be set by load_keystore @@ -459,9 +458,11 @@ def __init__(self, db: WalletDB, *, config: SimpleConfig): self.up_to_date_changed_event = asyncio.Event() assert self.db.get('genesis_blockhash') == constants.net.GENESIS, self.db.get('genesis_blockhash') - if self.storage and self.has_storage_encryption(): - if (se := self.storage.get_encryption_version()) not in (ae := self.get_available_storage_encryption_versions()): - raise WalletFileException(f"unexpected storage encryption type. found: {se!r}. allowed: {ae!r}") + if self.has_storage_encryption(): + ae = self.get_available_storage_encryption_versions() + for se in self.storage.get_encryption_versions(): + if se not in ae: + raise WalletFileException(f"unexpected storage encryption type. found: {se!r}. allowed: {ae!r}") self.register_callbacks() @@ -494,26 +495,29 @@ async def do_synchronize_loop(self): # have history that are mined and SPV-verified. await run_in_thread(self.synchronize) - def save_db(self): - if self.db.storage: - self.db.write() - - def save_backup(self, backup_dir): - new_path = os.path.join(backup_dir, self.basename() + '.backup') - new_storage = WalletStorage(new_path) - new_storage._encryption_version = self.storage._encryption_version - new_storage.pubkey = self.storage.pubkey - - new_db = WalletDB(self.db.dump(), storage=new_storage, upgrade=True) + def save_backup(self, new_path, password, password_type: PasswordType): + import json + from .stored_dict import to_default + # create data + data = self.storage.dump() if self.lnworker: - channel_backups = new_db.get_dict('imported_channel_backups') + channel_backups = data.get('imported_channel_backups', {}) for chan_id, chan in self.lnworker.channels.items(): channel_backups[chan_id.hex()] = self.lnworker.create_channel_backup(chan_id) - new_db.put('channels', None) - new_db.put('lightning_privkey2', None) - new_db.set_modified(True) - new_db.write() - return new_path + data['imported_channel_backups'] = channel_backups + data.pop('channels', None) + data.pop('lightning_privkey2', None) + json_str = json.dumps( + to_default(data), + indent=4, + sort_keys=True, + ) + assert not os.path.exists(new_path) + new_storage = DictStorage(path=new_path) + if password: + new_storage.add_password(password, password_type) + new_storage.set_data(json_str) + new_storage.write(force_consolidation=True) def has_lightning(self) -> bool: return bool(self.lnworker) @@ -558,7 +562,6 @@ def init_lightning(self, *, password) -> None: ln_xprv = node.to_xprv() self.db.put('lightning_privkey2', ln_xprv) self.lnworker = LNWallet(self, ln_xprv) - self.save_db() if self.network: self._start_network_lightning() @@ -580,7 +583,8 @@ async def stop(self): if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]): self.save_keystore() self.db.prune_uninstalled_plugin_data(self.config.get_installed_plugins()) - self.save_db() + if self.storage: + self.storage.close() def is_up_to_date(self) -> bool: if self.taskgroup and self.taskgroup.joined: # either stop() was called, or the taskgroup died @@ -609,7 +613,6 @@ async def on_event_adb_set_up_to_date(self, adb): self._up_to_date = up_to_date if up_to_date: self.adb.reset_netrequest_counters() # sync progress indicator - self.save_db() # fire triggers if status_changed or up_to_date: # suppress False->False transition, as it is spammy if self.lnworker: @@ -658,7 +661,6 @@ def on_event_adb_removed_verified_tx(self, adb, tx_hash): def clear_history(self): self.adb.clear_history() - self.save_db() def start_network(self, network: 'Network'): assert self.network is None, "already started" @@ -675,7 +677,7 @@ def _start_network_lightning(self): assert self.lnworker.network is None, 'lnworker network already initialized' self.lnworker.start_network(self.network) # only start gossiping when we already have channels - if self.db.get('channels'): + if len(self.db.get_dict('channels')) > 0: self.network.start_gossip() @abstractmethod @@ -695,7 +697,7 @@ def get_master_public_keys(self): return [] def basename(self) -> str: - return self.storage.basename() if self.storage else 'no_name' + return self.storage.basename() def check_returned_address_for_corruption(func): def wrapper(self, *args, **kwargs): @@ -1285,17 +1287,13 @@ def save_invoice(self, invoice: Invoice, *, write_to_disk: bool = True) -> None: for txout in invoice.get_outputs(): self._invoices_from_scriptpubkey_map[txout.scriptpubkey].add(key) self._invoices[key] = invoice - if write_to_disk: - self.save_db() def clear_invoices(self): self._invoices.clear() - self.save_db() def clear_requests(self): self._receive_requests.clear() self._requests_addr_to_key.clear() - self.save_db() def get_invoices(self) -> List[Invoice]: out = list(self._invoices.values()) @@ -1317,7 +1315,6 @@ def import_requests(self, path): except Exception: raise FileImportFailed(_("Invalid invoice format")) self.add_payment_request(req, write_to_disk=False) - self.save_db() def export_requests(self, path): # note: this does not export preimages for LN bolt11 invoices @@ -1331,7 +1328,6 @@ def import_invoices(self, path): except Exception: raise FileImportFailed(_("Invalid invoice format")) self.save_invoice(invoice, write_to_disk=False) - self.save_db() def export_invoices(self, path): write_json_file(path, list(self._invoices.values())) @@ -1738,7 +1734,7 @@ def get_label_for_rhash(self, rhash: str) -> str: def get_all_labels(self) -> Dict[str, str]: with self.lock: - return copy.copy(self._labels) + return self._labels.dump() def get_tx_status(self, tx_hash: str, tx_mined_info: TxMinedInfo): extra = [] @@ -2211,8 +2207,6 @@ def set_frozen_state_of_addresses( self._frozen_addresses -= set(addrs) self.db.put('frozen_addresses', list(self._frozen_addresses)) util.trigger_callback('status') - if write_to_disk: - self.save_db() return True return False @@ -2238,8 +2232,6 @@ def set_frozen_state_of_coins( else: self._frozen_coins[utxo] = bool(freeze) util.trigger_callback('status') - if write_to_disk: - self.save_db() def is_address_reserved(self, addr: str) -> bool: # note: atm 'reserved' status is only taken into consideration for 'change addresses' @@ -2926,7 +2918,7 @@ def get_formatted_request(self, request_id): def export_request(self, x: Request) -> Dict[str, Any]: key = x.get_id() status = self.get_invoice_status(x) - d = x.as_dict(status) + d = x.export(status) d['request_id'] = d.pop('id') if x.is_lightning(): d['rhash'] = x.rhash @@ -2952,7 +2944,7 @@ def export_request(self, x: Request) -> Dict[str, Any]: def export_invoice(self, x: Invoice) -> Dict[str, Any]: key = x.get_id() status = self.get_invoice_status(x) - d = x.as_dict(status) + d = x.export(status) d['invoice_id'] = d.pop('id') if x.is_lightning(): d['lightning_invoice'] = x.lightning_invoice @@ -3059,8 +3051,6 @@ def add_payment_request(self, req: Request, *, write_to_disk: bool = True): self._receive_requests[request_id] = req if addr := req.get_address(): self._requests_addr_to_key[addr].add(request_id) - if write_to_disk: - self.save_db() return request_id def delete_request(self, request_id, *, write_to_disk: bool = True): @@ -3073,8 +3063,6 @@ def delete_request(self, request_id, *, write_to_disk: bool = True): self._requests_addr_to_key[addr].discard(request_id) if req.is_lightning() and self.lnworker: self.lnworker.delete_payment_info(req.rhash, direction=RECEIVED) - if write_to_disk: - self.save_db() def delete_invoice(self, invoice_id, *, write_to_disk: bool = True): """ lightning or on-chain """ @@ -3083,8 +3071,6 @@ def delete_invoice(self, invoice_id, *, write_to_disk: bool = True): return if inv.is_lightning() and self.lnworker: self.lnworker.delete_payment_info(inv.rhash, direction=SENT) - if write_to_disk: - self.save_db() def get_requests(self) -> List[Request]: out = [self.get_request(x) for x in self._receive_requests.keys()] @@ -3110,8 +3096,6 @@ def delete_expired_requests(self): def delete_requests(self, keys): for key in keys: self.delete_request(key, write_to_disk=False) - if keys: - self.save_db() @abstractmethod def get_fingerprint(self) -> str: @@ -3136,17 +3120,20 @@ def has_password(self) -> bool: def can_have_keystore_encryption(self): return self.keystore and self.keystore.may_have_password() - def get_available_storage_encryption_versions(self) -> Sequence[StorageEncryptionVersion]: + def get_available_storage_encryption_versions(self) -> Sequence[PasswordType]: """Returns the type of storage encryption offered to the user. A wallet file (storage) is either encrypted with this version or is stored in plaintext. """ - out = [StorageEncryptionVersion.USER_PASSWORD] + out = [PasswordType.USER] if isinstance(self.keystore, Hardware_KeyStore): - out.append(StorageEncryptionVersion.XPUB_PASSWORD) + out.append(PasswordType.XPUB) return out + def is_hw_encryption_available(self): + return PasswordType.XPUB in self.get_available_storage_encryption_versions() + def has_keystore_encryption(self) -> bool: """Returns whether encryption is enabled for the keystore. @@ -3158,7 +3145,7 @@ def has_keystore_encryption(self) -> bool: def has_storage_encryption(self) -> bool: """Returns whether encryption is enabled for the wallet file on disk.""" - return bool(self.storage) and self.storage.is_encrypted() + return self.storage.is_encrypted() @classmethod def may_have_password(cls): @@ -3179,16 +3166,24 @@ def update_password(self, old_pw, new_pw, *, encrypt_storage: bool = True, xpub_ if old_pw is None and self.has_password(): raise InvalidPassword() self.check_password(old_pw) - if self.storage: - if encrypt_storage: - enc_version = StorageEncryptionVersion.XPUB_PASSWORD if xpub_encrypt else StorageEncryptionVersion.USER_PASSWORD - assert enc_version in self.get_available_storage_encryption_versions() - else: - enc_version = StorageEncryptionVersion.PLAINTEXT - self.storage.set_password(new_pw, enc_version) - # make sure next storage.write() saves changes - self.db.set_modified(True) + if encrypt_storage: + assert self.storage.supports_file_encryption() + if self.storage.supports_file_encryption(): + if encrypt_storage and new_pw: + password_type = PasswordType.XPUB if xpub_encrypt else PasswordType.USER + assert password_type in self.get_available_storage_encryption_versions() + if self.storage.is_encrypted(): + self.storage.update_password(old_pw, new_pw, password_type) + self.storage.write(force_consolidation=True) + else: + # we never add more than one password + self.storage.add_password(new_pw, password_type) + self.storage.write(force_consolidation=True) + else: + if self.storage.is_encrypted(): + self.storage.remove_password(old_pw) + self.storage.write(force_consolidation=True) # note: Encrypting storage with a hw device is currently only # allowed for non-multisig wallets. Further, # Hardware_KeyStore.may_have_password() == False. @@ -3198,8 +3193,8 @@ def update_password(self, old_pw, new_pw, *, encrypt_storage: bool = True, xpub_ encrypt_keystore = self.can_have_keystore_encryption() self.db.set_keystore_encryption(bool(new_pw) and encrypt_keystore) # save changes. force full rewrite to rm remnants of old password - if self.storage and self.storage.file_exists(): - self.db.write_and_force_consolidation() + if self.storage.file_exists(): + self.storage.write(force_consolidation=True) # if wallet was previously unlocked, reset password_in_memory self.lock_wallet() @@ -3813,8 +3808,6 @@ def import_addresses(self, addresses: List[str], *, good_addr.append(address) self.db.add_imported_address(address, {}) self.adb.add_address(address) - if write_to_disk: - self.save_db() return good_addr, bad_addr def import_address(self, address: str) -> str: @@ -3868,7 +3861,6 @@ def delete_address(self, address: str) -> None: else: self.keystore.delete_imported_key(pubkey) self.save_keystore() - self.save_db() def get_change_addresses_for_new_transaction(self, *args, **kwargs) -> List[str]: # for an imported wallet, if all "change addresses" are already used, @@ -3916,8 +3908,6 @@ def import_private_keys(self, keys: Sequence[str], password: Optional[str], *, good_inputs, bad_keys = self.keystore.import_private_keys(keys, password) self.save_keystore() self._add_imported_addresses(good_inputs) - if write_to_disk: - self.save_db() good_addr = [bitcoin.pubkey_to_address(txin_type, pubkey) for txin_type, pubkey in good_inputs] return good_addr, bad_keys @@ -4041,7 +4031,6 @@ def change_gap_limit(self, value): if value >= self.min_acceptable_gap(): self.gap_limit = value self.db.put('gap_limit', self.gap_limit) - self.save_db() return True else: return False @@ -4335,9 +4324,9 @@ def check_password(self, password): if self.has_storage_encryption(): self.storage.check_password(password) - def get_available_storage_encryption_versions(self) -> Sequence[StorageEncryptionVersion]: + def get_available_storage_encryption_versions(self) -> Sequence[PasswordType]: # multisig wallets are not offered hw device encryption - return [StorageEncryptionVersion.USER_PASSWORD] + return [PasswordType.USER] def has_seed(self): return self.keystore.has_seed() @@ -4378,7 +4367,7 @@ def register_constructor(wallet_type, constructor): class Wallet(object): """The main wallet "entry point". This class is actually a factory that will return a wallet of the correct - type when passed a WalletStorage instance.""" + type when passed a WalletDB instance.""" def __new__(cls, db: 'WalletDB', *, config: SimpleConfig) -> Abstract_Wallet: wallet_type = db.get('wallet_type') @@ -4407,12 +4396,12 @@ def create_new_wallet( gap_limit_for_change: Optional[int] = None, ) -> dict: """Create a new wallet""" - storage = WalletStorage(path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) - if storage.file_exists(): + if os.path.exists(standardize_path(path)): raise UserFacingException("Remove the existing wallet first!") - if encrypt_file: - storage.set_password(password, StorageEncryptionVersion.USER_PASSWORD) - db = WalletDB('', storage=storage, upgrade=True) + storage = DictStorage(path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) + if encrypt_file and password: + storage.add_password(password, PasswordType.USER) + db = WalletDB(storage) seed = Mnemonic('en').make_seed(seed_type=seed_type) k = keystore.from_seed(seed, passphrase=passphrase) k.update_password(None, password) @@ -4428,7 +4417,6 @@ def create_new_wallet( wallet = Wallet(db, config=config) wallet.synchronize() msg = "Please keep your seed in a safe place; if you lose it, you will not be able to restore your wallet." - wallet.save_db() return {'seed': seed, 'wallet': wallet, 'msg': msg} @@ -4450,14 +4438,15 @@ def restore_wallet_from_text( if encrypt_file is None: encrypt_file = True if path is None: # create wallet in-memory - storage = None + storage = DictStorage(None) else: - storage = WalletStorage(path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) - if storage.file_exists(): + if os.path.exists(standardize_path(path)): raise UserFacingException("Remove the existing wallet first!") - if encrypt_file: - storage.set_password(password, StorageEncryptionVersion.USER_PASSWORD) - db = WalletDB('', storage=storage, upgrade=True) + storage = DictStorage(path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) + if encrypt_file and password: + storage.add_password(password, PasswordType.USER) + + db = WalletDB(storage) db.set_keystore_encryption(bool(password)) text = text.strip() if keystore.is_address_list(text): @@ -4496,10 +4485,7 @@ def restore_wallet_from_text( if gap_limit_for_change is not None: db.put('gap_limit_for_change', gap_limit_for_change) wallet = wallet_factory(db, config=config) - if db.storage: - assert not db.storage.file_exists(), "file was created too soon! plaintext keys might have been written to disk" wallet.synchronize() msg = ("This wallet was restored offline. It may contain more addresses than displayed. " "Start a daemon and use load_wallet to sync its history.") - wallet.save_db() return {'wallet': wallet, 'msg': msg} diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index c20180a6d595..ebd7dac70e35 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -27,7 +27,7 @@ import copy from collections import defaultdict from typing import (Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, - Union, AbstractSet) + Union, AbstractSet, Any) import time from functools import partial @@ -35,22 +35,19 @@ from . import bitcoin from . import constants -from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, MyEncoder +from .util import with_lock as locked +from .util import profiler, WalletFileException, multisig_type, TxMinedInfo from .keystore import bip44_derivation from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput, BadHeaderMagic from .logging import Logger from .lnutil import HTLCOwner, ChannelType, RecvMPPResolution -from .json_db import JsonDB, locked, modifier -from . import stored_dict -from .stored_dict import StoredObject, stored_at, register_key, register_name +from .stored_dict import register_name, register_key +from .stored_dict import StoredObject, StoredDict, StoredList, stored_at from .plugin import run_hook, plugin_loaders from .version import ELECTRUM_VERSION from .i18n import _ -if TYPE_CHECKING: - from .storage import WalletStorage - class WalletRequiresUpgrade(WalletFileException): pass @@ -107,8 +104,14 @@ class WalletFileExceptionVersion51(WalletFileException): pass register_name('/transactions/*', None, lambda x: tx_from_any(x, deserialize=False, sanitize=False)) register_name('/channels/*/data_loss_protect_remote_pcp/*', None, lambda x: bytes.fromhex(x)) # register tuples, otherwise they will default to StoredList +register_name('/channels/*/closing_height', None, tuple) +register_name('/channels/*/funding_height', None, tuple) +register_name('/forwarding_failures/*', None, tuple) +register_name('/lightning_payments/*', None, tuple) register_name('/contacts/*', None, tuple) register_name('/lightning_preimages/*', None, tuple) +register_name('/addr_history/*/*', None, tuple) +register_name('/channels/*/funding_inputs/*', None, tuple) # register dicts that require key conversion for key in [ '/channels/*/log/*/adds', @@ -130,12 +133,9 @@ class WalletFileExceptionVersion51(WalletFileException): pass class WalletDBUpgrader(Logger): - def __init__(self, data: dict): + def __init__(self, data: StoredDict): Logger.__init__(self) self.data = data - # self.data must be in-memory dict (not a StoredDict or similar), - # so a failed, partial upgrade won't get commited to disk - assert type(self.data) == dict, type(self.data) def get(self, key, default=None): return self.data.get(key, default) @@ -159,11 +159,11 @@ def get_split_accounts(self): wallet_type = self.get('wallet_type') if wallet_type == 'old': assert len(d) == 2 - data1 = copy.deepcopy(self.data) - data1['accounts'] = {'0': d['0']} + data1 = copy.deepcopy(self.data.as_dict()) + data1['accounts'] = {'0': d['0'].as_dict()} data1['suffix'] = 'deterministic' - data2 = copy.deepcopy(self.data) - data2['accounts'] = {'/x': d['/x']} + data2 = copy.deepcopy(self.data.as_dict()) + data2['accounts'] = {'/x': d['/x'].as_dict()} data2['seed'] = None data2['seed_version'] = None data2['master_public_key'] = None @@ -176,11 +176,11 @@ def get_split_accounts(self): mpk = self.get('master_public_keys') for k in d.keys(): i = int(k) - x = d[k] + x = d[k].as_dict() if x.get("pending"): continue xpub = mpk["x/%d'"%i] - new_data = copy.deepcopy(self.data) + new_data = copy.deepcopy(self.data.as_dict()) # save account, derivation and xpub at index 0 new_data['accounts'] = {'0': x} new_data['master_public_keys'] = {"x/0'": xpub} @@ -196,70 +196,19 @@ def requires_upgrade(self): @profiler def upgrade(self): + assert self.data.should_convert() is False self.logger.info('upgrading wallet format') - self._convert_imported() - self._convert_wallet_type() - self._convert_account() - self._convert_version_13_b() - self._convert_version_14() - self._convert_version_15() - self._convert_version_16() - self._convert_version_17() - self._convert_version_18() - self._convert_version_19() - self._convert_version_20() - self._convert_version_21() - self._convert_version_22() - self._convert_version_23() - self._convert_version_24() - self._convert_version_25() - self._convert_version_26() - self._convert_version_27() - self._convert_version_28() - self._convert_version_29() - self._convert_version_30() - self._convert_version_31() - self._convert_version_32() - self._convert_version_33() - self._convert_version_34() - self._convert_version_35() - self._convert_version_36() - self._convert_version_37() - self._convert_version_38() - self._convert_version_39() - self._convert_version_40() - self._convert_version_41() - self._convert_version_42() - self._convert_version_43() - self._convert_version_44() - self._convert_version_45() - self._convert_version_46() - self._convert_version_47() - self._convert_version_48() - self._convert_version_49() - self._convert_version_50() - self._convert_version_51() - self._convert_version_52() - self._convert_version_53() - self._convert_version_54() - self._convert_version_55() - self._convert_version_56() - self._convert_version_57() - self._convert_version_58() - self._convert_version_59() - self._convert_version_60() - self._convert_version_61() - self._convert_version_62() - self._convert_version_63() - self._convert_version_64() - self._convert_version_65() - self._convert_version_66() - self._convert_version_67() - self._convert_version_68() - self._convert_version_69() - self._convert_version_70() - self._convert_version_71() - self.put('seed_version', FINAL_SEED_VERSION) # just to be sure + with self.data.write_batch(): + self._convert_imported() + with self.data.write_batch(): + self._convert_wallet_type() + with self.data.write_batch(): + self._convert_account() + for i in range(13, FINAL_SEED_VERSION + 1): + f = getattr(self, '_convert_version_%d'%i) + with self.data.write_batch(): + f() + self._put_seed_version(FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): if not self._is_upgrade_method_needed(0, 13): @@ -276,6 +225,8 @@ def _convert_wallet_type(self): xprvs = self.get('master_private_keys', {}) mpk = self.get('master_public_key') keypairs = self.get('keypairs') + if keypairs: + keypairs = keypairs.as_dict() key_type = self.get('key_type') if seed_version == OLD_SEED_VERSION or wallet_type == 'old': d = { @@ -351,7 +302,7 @@ def _convert_wallet_type(self): self.put('keypairs', None) self.put('key_type', None) - def _convert_version_13_b(self): + def _convert_version_13(self): # version 13 is ambiguous, and has an earlier and a later structure if not self._is_upgrade_method_needed(0, 13): return @@ -368,7 +319,7 @@ def _convert_version_13_b(self): self.put('addresses', d) self.put('pubkeys', None) - self.put('seed_version', 13) + self._put_seed_version(13) def _convert_version_14(self): # convert imported wallets for 3.0 @@ -377,14 +328,14 @@ def _convert_version_14(self): if self.get('wallet_type') =='imported': addresses = self.get('addresses') - if type(addresses) is list: + if type(addresses) is StoredList: addresses = dict([(x, None) for x in addresses]) self.put('addresses', addresses) elif self.get('wallet_type') == 'standard': if self.get('keystore').get('type')=='imported': addresses = set(self.get('addresses').get('receiving')) pubkeys = self.get('keystore').get('keypairs').keys() - assert len(addresses) == len(pubkeys) + assert len(addresses) == len(list(pubkeys)) d = {} for pubkey in pubkeys: addr = bitcoin.pubkey_to_address('p2pkh', pubkey) @@ -397,7 +348,7 @@ def _convert_version_14(self): self.put('addresses', d) self.put('pubkeys', None) self.put('wallet_type', 'imported') - self.put('seed_version', 14) + self._put_seed_version(14) def _convert_version_15(self): if not self._is_upgrade_method_needed(14, 14): @@ -405,7 +356,7 @@ def _convert_version_15(self): if self.get('seed_type') == 'segwit': # should not get here; get_seed_version should have caught this raise Exception('unsupported derivation (development segwit, v14)') - self.put('seed_version', 15) + self._put_seed_version(15) def _convert_version_16(self): # fixes issue #3193 for Imported_Wallets with addresses @@ -436,7 +387,7 @@ def remove_from_list(list_name): if self.get('wallet_type') == 'imported': addresses = self.get('addresses') - assert isinstance(addresses, dict) + assert isinstance(addresses, StoredDict) addresses_new = dict() for address, details in addresses.items(): if not bitcoin.is_address(address): @@ -448,7 +399,7 @@ def remove_from_list(list_name): addresses_new[address] = details self.put('addresses', addresses_new) - self.put('seed_version', 16) + self._put_seed_version(16) def _convert_version_17(self): # delete pruned_txo; construct spent_outpoints @@ -469,21 +420,21 @@ def _convert_version_17(self): spent_outpoints[prevout_hash][str(prevout_n)] = txid self.put('spent_outpoints', spent_outpoints) - self.put('seed_version', 17) + self._put_seed_version(17) def _convert_version_18(self): # delete verified_tx3 as its structure changed if not self._is_upgrade_method_needed(17, 17): return self.put('verified_tx3', None) - self.put('seed_version', 18) + self._put_seed_version(18) def _convert_version_19(self): # delete tx_fees as its structure changed if not self._is_upgrade_method_needed(18, 18): return self.put('tx_fees', None) - self.put('seed_version', 19) + self._put_seed_version(19) def _convert_version_20(self): # store 'derivation' (prefix) and 'root_fingerprint' in all xpub-based keystores. @@ -497,6 +448,7 @@ def _convert_version_20(self): for ks_name in ('keystore', *['x{}/'.format(i) for i in range(1, 16)]): ks = self.get(ks_name, None) if ks is None: continue + assert isinstance(ks, StoredDict) xpub = ks.get('xpub', None) if xpub is None: continue bip32node = BIP32Node.from_xkey(xpub) @@ -521,9 +473,8 @@ def _convert_version_20(self): root_fingerprint = bip32node.fingerprint.hex() ks['root_fingerprint'] = root_fingerprint ks.pop('ckcc_xfp', None) - self.put(ks_name, ks) - self.put('seed_version', 20) + self._put_seed_version(20) def _convert_version_21(self): if not self._is_upgrade_method_needed(20, 20): @@ -533,7 +484,7 @@ def _convert_version_21(self): for channel in channels: channel['state'] = 'OPENING' self.put('channels', channels) - self.put('seed_version', 21) + self._put_seed_version(21) def _convert_version_22(self): # construct prevouts_by_scripthash @@ -551,7 +502,7 @@ def _convert_version_22(self): prevouts_by_scripthash[scripthash].append((outpoint, txout.value)) self.put('prevouts_by_scripthash', prevouts_by_scripthash) - self.put('seed_version', 22) + self._put_seed_version(22) def _convert_version_23(self): if not self._is_upgrade_method_needed(22, 22): @@ -577,7 +528,7 @@ def _convert_version_23(self): log[str(int(sub))]['fee_updates'] = d self.data['channels'] = channels - self.data['seed_version'] = 23 + self._put_seed_version(23) def _convert_version_24(self): if not self._is_upgrade_method_needed(23, 23): @@ -596,7 +547,7 @@ def _convert_version_24(self): # convert channels to dict self.data['channels'] = {x['channel_id']: x for x in channels} # convert txi & txo - txi = self.get('txi', {}) + txi = self.data.get('txi', {}) for tx_hash, d in list(txi.items()): d2 = {} for addr, l in d.items(): @@ -604,8 +555,7 @@ def _convert_version_24(self): for ser, v in l: d2[addr][ser] = v txi[tx_hash] = d2 - self.data['txi'] = txi - txo = self.get('txo', {}) + txo = self.data.get('txo', {}) for tx_hash, d in list(txo.items()): d2 = {} for addr, l in d.items(): @@ -613,9 +563,8 @@ def _convert_version_24(self): for n, v, cb in l: d2[addr][str(n)] = (v, cb) txo[tx_hash] = d2 - self.data['txo'] = txo - self.data['seed_version'] = 24 + self._put_seed_version(24) def _convert_version_25(self): from .crypto import sha256 @@ -644,7 +593,7 @@ def _convert_version_25(self): if pr_id != k: continue del invoices[k] - self.data['seed_version'] = 25 + self._put_seed_version(25) def _convert_version_26(self): if not self._is_upgrade_method_needed(25, 25): @@ -659,7 +608,7 @@ def _convert_version_26(self): c['funding_height'] = funding_txid, funding_height, funding_timestamp if closing_txid: c['closing_height'] = closing_txid, closing_height, closing_timestamp - self.data['seed_version'] = 26 + self._put_seed_version(26) def _convert_version_27(self): if not self._is_upgrade_method_needed(26, 26): @@ -667,7 +616,7 @@ def _convert_version_27(self): channels = self.data.get('channels', {}) for channel_id, c in channels.items(): c['local_config']['htlc_minimum_msat'] = 1 - self.data['seed_version'] = 27 + self._put_seed_version(27) def _convert_version_28(self): if not self._is_upgrade_method_needed(27, 27): @@ -675,7 +624,7 @@ def _convert_version_28(self): channels = self.data.get('channels', {}) for channel_id, c in channels.items(): c['local_config']['channel_seed'] = None - self.data['seed_version'] = 28 + self._put_seed_version(28) def _convert_version_29(self): if not self._is_upgrade_method_needed(28, 28): @@ -685,6 +634,7 @@ def _convert_version_29(self): invoices = self.data.get('invoices', {}) for d in [invoices, requests]: for key, r in list(d.items()): + r = r.dump() # convert StoredDict/List to dict/list _type = r.get('type', 0) item = { 'type': _type, @@ -711,7 +661,7 @@ def _convert_version_29(self): 'invoice': r['invoice'], }) d[key] = item - self.data['seed_version'] = 29 + self._put_seed_version(29) def _convert_version_30(self): if not self._is_upgrade_method_needed(29, 29): @@ -734,7 +684,7 @@ def _convert_version_30(self): item.pop('time') else: raise Exception(f"unknown invoice type: {_type}") - self.data['seed_version'] = 30 + self._put_seed_version(30) def _convert_version_31(self): if not self._is_upgrade_method_needed(30, 30): @@ -748,7 +698,7 @@ def _convert_version_31(self): item['amount_sat'] = item['amount_sat'] or 0 item['exp'] = item['exp'] or 0 item['time'] = item['time'] or 0 - self.data['seed_version'] = 31 + self._put_seed_version(31) def _convert_version_32(self): if not self._is_upgrade_method_needed(31, 31): @@ -758,7 +708,7 @@ def _convert_version_32(self): invoices_new = {k: item for k, item in invoices_old.items() if not (item['type'] == PR_TYPE_ONCHAIN and item['outputs'] is None)} self.data['invoices'] = invoices_new - self.data['seed_version'] = 32 + self._put_seed_version(32) def _convert_version_33(self): if not self._is_upgrade_method_needed(32, 32): @@ -770,7 +720,7 @@ def _convert_version_33(self): for key, item in list(d.items()): if item['type'] == PR_TYPE_ONCHAIN: item['height'] = item.get('height') or 0 - self.data['seed_version'] = 33 + self._put_seed_version(33) def _convert_version_34(self): if not self._is_upgrade_method_needed(33, 33): @@ -781,18 +731,20 @@ def _convert_version_34(self): item['local_config'].get('upfront_shutdown_script') or "" item['remote_config']['upfront_shutdown_script'] = \ item['remote_config'].get('upfront_shutdown_script') or "" - self.data['seed_version'] = 34 + self._put_seed_version(34) def _convert_version_35(self): # same as 32, but for payment_requests if not self._is_upgrade_method_needed(34, 34): return PR_TYPE_ONCHAIN = 0 - requests_old = self.data.get('payment_requests', {}) - requests_new = {k: item for k, item in requests_old.items() - if not (item['type'] == PR_TYPE_ONCHAIN and item['outputs'] is None)} - self.data['payment_requests'] = requests_new - self.data['seed_version'] = 35 + payment_requests = self.data.get('payment_requests', {}) + for k in list(payment_requests.keys()): + item = payment_requests[k] + if (item['type'] == PR_TYPE_ONCHAIN and item['outputs'] is None): + payment_requests.pop(k) + + self._put_seed_version(35) def _convert_version_36(self): if not self._is_upgrade_method_needed(35, 35): @@ -800,7 +752,7 @@ def _convert_version_36(self): old_frozen_coins = self.data.get('frozen_coins', []) new_frozen_coins = {coin: True for coin in old_frozen_coins} self.data['frozen_coins'] = new_frozen_coins - self.data['seed_version'] = 36 + self._put_seed_version(36) def _convert_version_37(self): if not self._is_upgrade_method_needed(36, 36): @@ -811,7 +763,7 @@ def _convert_version_37(self): amount_msat = amount_sat * 1000 if amount_sat is not None else None payments[k] = amount_msat, direction, status self.data['lightning_payments'] = payments - self.data['seed_version'] = 37 + self._put_seed_version(37) def _convert_version_38(self): if not self._is_upgrade_method_needed(37, 37): @@ -836,14 +788,14 @@ def _convert_version_38(self): continue if not (isinstance(amount_msat, int) and 0 <= amount_msat <= max_sats * 1000): del d[key] - self.data['seed_version'] = 38 + self._put_seed_version(38) def _convert_version_39(self): # this upgrade prevents initialization of lightning_privkey2 after lightning_xprv has been set if not self._is_upgrade_method_needed(38, 38): return self.data['imported_channel_backups'] = self.data.pop('channel_backups', {}) - self.data['seed_version'] = 39 + self._put_seed_version(39) def _convert_version_40(self): # put 'seed_type' into keystores @@ -866,7 +818,7 @@ def _convert_version_40(self): seed_type = 'old' if seed_type is not None: ks['seed_type'] = seed_type - self.data['seed_version'] = 40 + self._put_seed_version(40) def _convert_version_41(self): # this is a repeat of upgrade 39, to fix wallet backup files (see #7339) @@ -875,7 +827,7 @@ def _convert_version_41(self): imported_channel_backups = self.data.pop('channel_backups', {}) imported_channel_backups.update(self.data.get('imported_channel_backups', {})) self.data['imported_channel_backups'] = imported_channel_backups - self.data['seed_version'] = 41 + self._put_seed_version(41) def _convert_version_42(self): # in OnchainInvoice['outputs'], convert values from None to 0 @@ -889,19 +841,18 @@ def _convert_version_42(self): if item['type'] == PR_TYPE_ONCHAIN: item['outputs'] = [(_type, addr, (val or 0)) for _type, addr, val in item['outputs']] - self.data['seed_version'] = 42 + self._put_seed_version(42) def _convert_version_43(self): if not self._is_upgrade_method_needed(42, 42): return - channels = self.data.pop('channels', {}) + channels = self.data.get('channels', {}) for k, c in channels.items(): log = c['log'] c['fail_htlc_reasons'] = log.pop('fail_htlc_reasons', {}) c['unfulfilled_htlcs'] = log.pop('unfulfilled_htlcs', {}) log["1"]['unacked_updates'] = log.pop('unacked_local_updates2', {}) - self.data['channels'] = channels - self.data['seed_version'] = 43 + self._put_seed_version(43) def _convert_version_44(self): if not self._is_upgrade_method_needed(43, 43): @@ -914,7 +865,7 @@ def _convert_version_44(self): channel_type = ChannelType(0) item.pop('static_remotekey_enabled', None) item['channel_type'] = channel_type - self.data['seed_version'] = 44 + self._put_seed_version(44) def _convert_version_45(self): from .bolt11 import decode_bolt11_invoice @@ -928,6 +879,7 @@ def _convert_version_45(self): for name in ['invoices', 'payment_requests']: invoices = self.data.get(name, {}) for key, item in invoices.items(): + item = item.dump() # convert StoredDict/List to dict/list is_lightning = item['type'] == 2 lightning_invoice = item['invoice'] if is_lightning else None outputs = item['outputs'] if not is_lightning else None @@ -957,7 +909,7 @@ def _convert_version_45(self): 'bip70':bip70, 'lightning_invoice':lightning_invoice, } - self.data['seed_version'] = 45 + self._put_seed_version(45) def _convert_invoices_keys(self, invoices): # recalc keys of outgoing on-chain invoices @@ -966,7 +918,9 @@ def get_id_from_onchain_outputs(raw_outputs, timestamp): outputs = [PartialTxOutput.from_legacy_tuple(*output) for output in raw_outputs] outputs_str = "\n".join(f"{txout.scriptpubkey.hex()}, {txout.value}" for txout in outputs) return sha256d(outputs_str + "%d" % timestamp).hex()[0:10] + assert isinstance(invoices, StoredDict) for key, item in list(invoices.items()): + item = item.dump() # convert StoredDict/List to dict/list is_lightning = item['lightning_invoice'] is not None if is_lightning: continue @@ -981,9 +935,10 @@ def get_id_from_onchain_outputs(raw_outputs, timestamp): def _convert_version_46(self): if not self._is_upgrade_method_needed(45, 45): return - invoices = self.data.get('invoices', {}) - self._convert_invoices_keys(invoices) - self.data['seed_version'] = 46 + invoices = self.data.get('invoices') + if invoices: + self._convert_invoices_keys(invoices) + self._put_seed_version(46) def _convert_version_47(self): from .bolt11 import decode_bolt11_invoice @@ -999,7 +954,7 @@ def _convert_version_47(self): if key != rhash: requests[rhash] = item del requests[key] - self.data['seed_version'] = 47 + self._put_seed_version(47) def _convert_version_48(self): # fix possible corruption of invoice amounts, see #7774 @@ -1009,7 +964,7 @@ def _convert_version_48(self): for key, item in list(invoices.items()): if item['amount_msat'] == 1000 * "!": item['amount_msat'] = "!" - self.data['seed_version'] = 48 + self._put_seed_version(48) def _convert_version_49(self): if not self._is_upgrade_method_needed(48, 48): @@ -1025,14 +980,15 @@ def _convert_version_49(self): f"Please use Electrum 4.3.0 to open this wallet, close the channels, " f"and delete them from the wallet." ) - self.data['seed_version'] = 49 + self._put_seed_version(49) def _convert_version_50(self): if not self._is_upgrade_method_needed(49, 49): return - requests = self.data.get('payment_requests', {}) - self._convert_invoices_keys(requests) - self.data['seed_version'] = 50 + requests = self.data.get('payment_requests') + if requests: + self._convert_invoices_keys(requests) + self._put_seed_version(50) def _convert_version_51(self): from .bolt11 import decode_bolt11_invoice @@ -1047,7 +1003,7 @@ def _convert_version_51(self): lnaddr = decode_bolt11_invoice(lightning_invoice) payment_hash = lnaddr.paymenthash.hex() item['payment_hash'] = payment_hash - self.data['seed_version'] = 51 + self._put_seed_version(51) def _detect_insane_version_51(self) -> int: """Returns 0 if file okay, @@ -1079,7 +1035,7 @@ def _convert_version_52(self): if (error_code := self._detect_insane_version_51()) != 0: # should not get here; get_seed_version should have caught this raise Exception(f'unsupported wallet file: version_51 with error {error_code}') - self.data['seed_version'] = 52 + self._put_seed_version(52) def _convert_version_53(self): if not self._is_upgrade_method_needed(52, 52): @@ -1088,7 +1044,7 @@ def _convert_version_53(self): for channel_id, cb in list(cbs.items()): if 'local_payment_pubkey' not in cb: cb['local_payment_pubkey'] = None - self.data['seed_version'] = 53 + self._put_seed_version(53) def _convert_version_54(self): # note: similar to convert_version_38 @@ -1105,7 +1061,7 @@ def _convert_version_54(self): continue if not (isinstance(amount_msat, int) and 0 <= amount_msat <= max_sats * 1000): del d[key] - self.data['seed_version'] = 54 + self._put_seed_version(54) def _convert_version_55(self): if not self._is_upgrade_method_needed(54, 54): @@ -1113,8 +1069,10 @@ def _convert_version_55(self): # do not use '/' in dict keys for key in list(self.data.keys()): if key.endswith('/'): - self.data[key[:-1]] = self.data.pop(key) - self.data['seed_version'] = 55 + item = self.data.get(key) + self.data[key[:-1]] = item.as_dict() + self.data.pop(key) + self._put_seed_version(55) def _convert_version_56(self): if not self._is_upgrade_method_needed(55, 55): @@ -1126,7 +1084,7 @@ def _convert_version_56(self): item[c]['announcement_node_sig'] = '' item[c]['announcement_bitcoin_sig'] = '' item['local_config'].pop('was_announced') - self.data['seed_version'] = 56 + self._put_seed_version(56) def _convert_version_57(self): if not self._is_upgrade_method_needed(56, 56): @@ -1134,7 +1092,7 @@ def _convert_version_57(self): # The 'seed_type' field could be present both at the top-level and inside keystores. # We delete the one that is top-level. self.data.pop('seed_type', None) - self.data['seed_version'] = 57 + self._put_seed_version(57) def _convert_version_58(self): # re-construct prevouts_by_scripthash @@ -1156,7 +1114,7 @@ def _convert_version_58(self): prevouts_by_scripthash[scripthash] = {} prevouts_by_scripthash[scripthash][outpoint] = txout.value self.put('prevouts_by_scripthash', prevouts_by_scripthash) - self.data['seed_version'] = 58 + self._put_seed_version(58) def _convert_version_59(self): if not self._is_upgrade_method_needed(58, 58): @@ -1168,8 +1126,7 @@ def _convert_version_59(self): for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_key) in chan['unfulfilled_htlcs'].items(): unfulfilled_htlcs[htlc_id] = (onion_packet_hex, forwarding_key or None) chan['unfulfilled_htlcs'] = unfulfilled_htlcs - self.data['channels'] = channels - self.data['seed_version'] = 59 + self._put_seed_version(59) def _convert_version_60(self): if not self._is_upgrade_method_needed(59, 59): @@ -1178,7 +1135,7 @@ def _convert_version_60(self): for channel_id, cb in list(cbs.items()): if 'multisig_funding_privkey' not in cb: cb['multisig_funding_privkey'] = None - self.data['seed_version'] = 60 + self._put_seed_version(60) def _convert_version_61(self): if not self._is_upgrade_method_needed(60, 60): @@ -1190,7 +1147,7 @@ def _convert_version_61(self): for rhash, (amount_msat, direction, is_paid) in list(lightning_payments.items()): new = (amount_msat, direction, is_paid, 147, expiry_never, migration_time) lightning_payments[rhash] = new - self.data['seed_version'] = 61 + self._put_seed_version(61) def _convert_version_62(self): if not self._is_upgrade_method_needed(61, 61): @@ -1201,7 +1158,7 @@ def _convert_version_62(self): for swap in swaps.values(): del swap['receive_address'] swap['claim_to_output'] = None - self.data['seed_version'] = 62 + self._put_seed_version(62) def _convert_version_63(self): if not self._is_upgrade_method_needed(62, 62): @@ -1238,14 +1195,17 @@ def _move_unprocessed_onion(short_channel_id: str, htlc_id: Optional[int]) -> Op htlc_data = unfulfilled_htlcs_.get(str(htlc_id)) if htlc_data is None: return None + htlc_data = htlc_data.dump() # StoredList -> list stored_onion_packet, htlc_forwarding_key = htlc_data if stored_onion_packet is not None: htlc_data[0] = None # overwrite the onion so it is not processed again in htlc_switch + unfulfilled_htlcs_[str(htlc_id)] = htlc_data return stored_onion_packet, htlc_forwarding_key return None mpp_sets = self.data.get('received_mpp_htlcs', {}) for payment_key, recv_mpp_status in list(mpp_sets.items()): + recv_mpp_status = recv_mpp_status.dump() # StoredList -> list assert isinstance(recv_mpp_status, list), f"{recv_mpp_status=}" del recv_mpp_status[1] # remove expected_msat @@ -1280,6 +1240,7 @@ def _move_unprocessed_onion(short_channel_id: str, htlc_id: Optional[int]) -> Op # as the htlcs won't get failed due to the new SETTLING state # unless a forwarding error is set. recv_mpp_status[0] = 4 # RecvMPPResolution.SETTLING + mpp_sets[payment_key] = recv_mpp_status # replace Tuple[onion, forwarding_key] with just the onion in chan['unfulfilled_htlcs'] for chan in channels.values(): @@ -1291,7 +1252,7 @@ def _move_unprocessed_onion(short_channel_id: str, htlc_id: Optional[int]) -> Op else: unfulfilled_htlcs[htlc_id] = unprocessed_onion - self.data['seed_version'] = 63 + self._put_seed_version(63) def _convert_version_64(self): """Key payment_info by "rhash:direction" instead of just rhash to allow storing a PaymentInfo @@ -1309,7 +1270,7 @@ def _convert_version_64(self): new_payment_infos[new_key] = new_values # save new entry self.data['lightning_payments'] = new_payment_infos - self.data['seed_version'] = 64 + self._put_seed_version(64) def _convert_version_65(self): """Store channel_id instead of short_channel_id in ReceivedMPPHtlc""" @@ -1326,6 +1287,7 @@ def scid_to_channel_id(scid): mpp_sets = self.data.get('received_mpp_htlcs', {}) new_mpp_sets = {} for payment_key, mpp_set in mpp_sets.items(): + mpp_set = mpp_set.dump() # StoredList -> list if len(mpp_set) == 2: # if the db has received_mpp_htlcs pre version 65 we cannot assume they have parent_set_key # as _convert_version_63 doesn't set it @@ -1341,7 +1303,7 @@ def scid_to_channel_id(scid): new_mpp_sets[payment_key] = (resolution, new_htlc_list, parent_set_key) self.data['received_mpp_htlcs'] = new_mpp_sets - self.data['seed_version'] = 65 + self._put_seed_version(65) def _convert_version_66(self): """Add invoice features to PaymentInfo""" @@ -1357,7 +1319,7 @@ def _convert_version_66(self): new_payment_infos[key] = new_v self.data['lightning_payments'] = new_payment_infos - self.data['seed_version'] = 66 + self._put_seed_version(66) def _convert_version_67(self): if not self._is_upgrade_method_needed(66, 66): @@ -1368,8 +1330,8 @@ def _convert_version_67(self): key = '-1' if is_initiator else '1' assert len(chan['log'][key]['fee_updates']) == 1, chan['log'][key]['fee_updates'] chan['log'][key]['fee_updates'] = {} - self.data['channels'] = channels - self.data['seed_version'] = 67 + #self.data['channels'] = channels + self._put_seed_version(67) def _convert_version_68(self): if not self._is_upgrade_method_needed(67, 67): @@ -1379,7 +1341,7 @@ def _convert_version_68(self): for _hash, preimage in old_preimages.items(): new_preimages[_hash] = (preimage, False) self.data['lightning_preimages'] = new_preimages - self.data['seed_version'] = 68 + self._put_seed_version(68) def _convert_version_69(self): """Convert PaymentInfo amounts from 0 to None""" @@ -1398,7 +1360,7 @@ def _convert_version_69(self): new_v = (amount_msat, *old_v[1:]) new_payment_infos[key] = new_v self.data['lightning_payments'] = new_payment_infos - self.data['seed_version'] = 69 + self._put_seed_version(69) def _convert_version_70(self): """ @@ -1412,7 +1374,7 @@ def _convert_version_70(self): for amount_sat, timestamp in connection.get('budget_spends', []): new_budget_spends.append([amount_sat * 1000, timestamp]) connection['budget_spends'] = new_budget_spends - self.data['seed_version'] = 70 + self._put_seed_version(70) def _convert_version_71(self): """Save 'genesis_blockhash' in DB.""" @@ -1433,7 +1395,7 @@ def _convert_version_71(self): f"e.g. {neutered_addr} (len={len(first_address)})") # if so, save genesis hash self.data['genesis_blockhash'] = constants.net.GENESIS - self.data['seed_version'] = 71 + self._put_seed_version(71) def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): @@ -1481,6 +1443,9 @@ def _is_upgrade_method_needed(self, min_version, max_version): else: return True + def _put_seed_version(self, n): + self.data['seed_version'] = n + def get_seed_version(self): seed_version = self.get('seed_version') if not seed_version: @@ -1532,7 +1497,7 @@ def _raise_unsupported_version(self, seed_version): raise WalletFileException(msg) -def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]: +def upgrade_wallet_db(data: 'StoredDict', do_upgrade: bool) -> Tuple[dict, bool]: was_upgraded = False if len(data) == 0: @@ -1545,7 +1510,7 @@ def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]: first_electrum_version_used=ELECTRUM_VERSION, ) assert data.get("db_metadata", None) is None - data["db_metadata"] = v.to_json() + data["db_metadata"] = v.as_dict() was_upgraded = True # Test mainnet/testnet mixup. Do this before DB upgrades, as those might assume # network magic bytes (e.g. if they parse an address or an xpub). @@ -1555,6 +1520,7 @@ def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]: "Current chain: {}").format(constants.net.NET_NAME) ) + data._db._should_convert = False dbu = WalletDBUpgrader(data) if dbu.requires_split(): raise WalletRequiresSplit(dbu.get_split_accounts()) @@ -1563,30 +1529,62 @@ def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]: was_upgraded = True if dbu.requires_upgrade(): raise WalletRequiresUpgrade() - return dbu.data, was_upgraded - - -class WalletDB(JsonDB): - - def __init__( - self, - s: str, - *, - storage: Optional['WalletStorage'] = None, - upgrade: bool = False, - ): - JsonDB.__init__( - self, - s, - storage=storage, - encoder=MyEncoder, - upgrader=partial(upgrade_wallet_db, do_upgrade=upgrade), - ) + data._db._should_convert = True + return was_upgraded + + + +@stored_at('/txo/*/*/*', tuple) +class TxoValue(NamedTuple): + value: int + is_coinbase: bool + + +class WalletDB(Logger): + + def __init__(self, data: 'StoredDict', upgrade: bool = True): + Logger.__init__(self) + self.storage = data + self.lock = self.storage.lock + # we must perform db upgrades on the storeddict + was_upgraded = upgrade_wallet_db(self.storage, upgrade) + #self._modified |= was_upgraded + # create pointers self.load_transactions() # load plugins that are conditional on wallet type self.load_plugins() + @locked + def put(self, key, value): + # raises if value cannot be serialized by db + if value is not None: + if self.storage.get(key) != value: + self.storage[key] = copy.deepcopy(value) + return True + elif key in self.storage: + self.storage.pop(key) + return True + return False + + @locked + def get(self, key, default=None) -> Any: + # returns dict or list in place of StoredDict/StoredList + v = self.storage.get(key, default) + if isinstance(v, (StoredDict, StoredList)): + v = v.dump() + return v + + @locked + def get_dict(self, name) -> StoredDict: + # side effect: creates DB entry + return self.storage.get(name, {}, add_if_missing=True) + + @locked + def get_list(self, name) -> StoredList: + # side effect: creates DB entry + return self.storage.get(name, [], add_if_missing=True) + @locked def get_seed_version(self): return self.get('seed_version') @@ -1621,9 +1619,9 @@ def get_txo_addr(self, tx_hash: str, address: str) -> Dict[int, Tuple[int, bool] assert isinstance(tx_hash, str) assert isinstance(address, str) d = self.txo.get(tx_hash, {}).get(address, {}) - return {int(n): (v, cb) for (n, (v, cb)) in d.items()} + return {int(n): (item.value, item.is_coinbase) for (n, item) in d.items()} - @modifier + @locked def add_txi_addr(self, tx_hash: str, addr: str, ser: str, v: int) -> None: assert isinstance(tx_hash, str) assert isinstance(addr, str) @@ -1636,7 +1634,7 @@ def add_txi_addr(self, tx_hash: str, addr: str, ser: str, v: int) -> None: d[addr] = {} d[addr][ser] = v - @modifier + @locked def add_txo_addr(self, tx_hash: str, addr: str, n: Union[int, str], v: int, is_coinbase: bool) -> None: n = str(n) assert isinstance(tx_hash, str) @@ -1649,7 +1647,7 @@ def add_txo_addr(self, tx_hash: str, addr: str, n: Union[int, str], v: int, is_c d = self.txo[tx_hash] if addr not in d: d[addr] = {} - d[addr][n] = (v, is_coinbase) + d[addr][n] = TxoValue(v, is_coinbase) @locked def list_txi(self) -> Sequence[str]: @@ -1659,12 +1657,12 @@ def list_txi(self) -> Sequence[str]: def list_txo(self) -> Sequence[str]: return list(self.txo.keys()) - @modifier + @locked def remove_txi(self, tx_hash: str) -> None: assert isinstance(tx_hash, str) self.txi.pop(tx_hash, None) - @modifier + @locked def remove_txo(self, tx_hash: str) -> None: assert isinstance(tx_hash, str) self.txo.pop(tx_hash, None) @@ -1687,7 +1685,7 @@ def get_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) -> O prevout_n = str(prevout_n) return self.spent_outpoints.get(prevout_hash, {}).get(prevout_n) - @modifier + @locked def remove_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) -> None: assert isinstance(prevout_hash, str) prevout_n = str(prevout_n) @@ -1695,7 +1693,7 @@ def remove_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) - if not self.spent_outpoints[prevout_hash]: self.spent_outpoints.pop(prevout_hash) - @modifier + @locked def set_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str], tx_hash: str) -> None: assert isinstance(prevout_hash, str) assert isinstance(tx_hash, str) @@ -1704,7 +1702,7 @@ def set_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str], tx_h self.spent_outpoints[prevout_hash] = {} self.spent_outpoints[prevout_hash][prevout_n] = tx_hash - @modifier + @locked def add_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None: assert isinstance(scripthash, str) assert isinstance(prevout, TxOutpoint) @@ -1713,7 +1711,7 @@ def add_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, val self._prevouts_by_scripthash[scripthash] = dict() self._prevouts_by_scripthash[scripthash][prevout.to_str()] = value - @modifier + @locked def remove_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None: assert isinstance(scripthash, str) assert isinstance(prevout, TxOutpoint) @@ -1728,7 +1726,7 @@ def get_prevouts_by_scripthash(self, scripthash: str) -> Set[Tuple[TxOutpoint, i prevouts_and_values = self._prevouts_by_scripthash.get(scripthash, {}) return {(TxOutpoint.from_str(prevout), value) for prevout, value in prevouts_and_values.items()} - @modifier + @locked def add_transaction(self, tx_hash: str, tx: Transaction) -> None: assert isinstance(tx_hash, str) assert isinstance(tx, Transaction), tx @@ -1744,7 +1742,7 @@ def add_transaction(self, tx_hash: str, tx: Transaction) -> None: if tx_we_already_have is None or isinstance(tx_we_already_have, PartialTransaction): self.transactions[tx_hash] = tx - @modifier + @locked def remove_transaction(self, tx_hash: str) -> Optional[Transaction]: assert isinstance(tx_hash, str) return self.transactions.pop(tx_hash, None) @@ -1774,12 +1772,12 @@ def get_addr_history(self, addr: str) -> Sequence[Tuple[str, int]]: assert isinstance(addr, str) return self.history.get(addr, []) - @modifier + @locked def set_addr_history(self, addr: str, hist) -> None: assert isinstance(addr, str) self.history[addr] = hist - @modifier + @locked def remove_addr_history(self, addr: str) -> None: assert isinstance(addr, str) self.history.pop(addr, None) @@ -1793,22 +1791,17 @@ def get_verified_tx(self, txid: str) -> Optional[TxMinedInfo]: assert isinstance(txid, str) if txid not in self.verified_tx: return None - height, timestamp, txpos, header_hash = self.verified_tx[txid] - return TxMinedInfo(_height=height, - conf=None, - timestamp=timestamp, - txpos=txpos, - header_hash=header_hash) - - @modifier + return self.verified_tx[txid] + + @locked def add_verified_tx(self, txid: str, info: TxMinedInfo): assert isinstance(txid, str) assert isinstance(info, TxMinedInfo) height = info._height # number of conf is dynamic and might not be set here assert height > 0, height - self.verified_tx[txid] = (height, info.timestamp, info.txpos, info.header_hash) + self.verified_tx[txid] = info - @modifier + @locked def remove_verified_tx(self, txid: str): assert isinstance(txid, str) self.verified_tx.pop(txid, None) @@ -1817,7 +1810,7 @@ def is_in_verified_tx(self, txid: str) -> bool: assert isinstance(txid, str) return txid in self.verified_tx - @modifier + @locked def add_tx_fee_from_server(self, txid: str, fee_sat: Optional[int]) -> None: assert isinstance(txid, str) assert fee_sat is None or isinstance(fee_sat, int) @@ -1829,7 +1822,7 @@ def add_tx_fee_from_server(self, txid: str, fee_sat: Optional[int]) -> None: return self.tx_fees[txid] = tx_fees_value._replace(fee=fee_sat, is_calculated_by_us=False) - @modifier + @locked def add_tx_fee_we_calculated(self, txid: str, fee_sat: Optional[int]) -> None: assert isinstance(txid, str) if fee_sat is None: @@ -1850,7 +1843,7 @@ def get_tx_fee(self, txid: str, *, trust_server: bool = False) -> Optional[int]: return None return tx_fees_value.fee - @modifier + @locked def add_num_inputs_to_tx(self, txid: str, num_inputs: int) -> None: assert isinstance(txid, str) assert isinstance(num_inputs, int) @@ -1872,7 +1865,7 @@ def get_num_ismine_inputs_of_tx(self, txid: str) -> int: txins = self.txi.get(txid, {}) return sum([len(tupls) for addr, tupls in txins.items()]) - @modifier + @locked def remove_tx_fee(self, txid: str) -> None: assert isinstance(txid, str) self.tx_fees.pop(txid, None) @@ -1895,13 +1888,13 @@ def get_receiving_addresses(self, *, slice_start=None, slice_stop=None) -> List[ # note: slicing makes a shallow copy return self.receiving_addresses[slice_start:slice_stop] - @modifier + @locked def add_change_address(self, addr: str) -> None: assert isinstance(addr, str) self._addr_to_addr_index[addr] = (1, len(self.change_addresses)) self.change_addresses.append(addr) - @modifier + @locked def add_receiving_address(self, addr: str) -> None: assert isinstance(addr, str) self._addr_to_addr_index[addr] = (0, len(self.receiving_addresses)) @@ -1912,12 +1905,12 @@ def get_address_index(self, address: str) -> Optional[Sequence[int]]: assert isinstance(address, str) return self._addr_to_addr_index.get(address) - @modifier + @locked def add_imported_address(self, addr: str, d: dict) -> None: assert isinstance(addr, str) self.imported_addresses[addr] = d - @modifier + @locked def remove_imported_address(self, addr: str) -> None: assert isinstance(addr, str) self.imported_addresses.pop(addr) @@ -1941,12 +1934,12 @@ def load_addresses(self, wallet_type): if wallet_type == 'imported': self.imported_addresses = self.get_dict('addresses') # type: Dict[str, dict] else: - self.get_dict('addresses') + addresses = self.get_dict('addresses') for name in ['receiving', 'change']: - if name not in self.data['addresses']: - self.data['addresses'][name] = [] - self.change_addresses = self.data['addresses']['change'] - self.receiving_addresses = self.data['addresses']['receiving'] + if name not in addresses: + addresses[name] = [] + self.change_addresses = self.storage['addresses']['change'] + self.receiving_addresses = self.storage['addresses']['receiving'] self._addr_to_addr_index = {} # type: Dict[str, Sequence[int]] # key: address, value: (is_change, index) for i, addr in enumerate(self.receiving_addresses): self._addr_to_addr_index[addr] = (0, i) @@ -1955,7 +1948,7 @@ def load_addresses(self, wallet_type): @profiler def load_transactions(self): - # references in self.data + # references in self.storage # TODO make all these private # txid -> address -> prev_outpoint -> value self.txi = self.get_dict('txi') # type: Dict[str, Dict[str, Dict[str, int]]] @@ -1981,7 +1974,7 @@ def load_transactions(self): self.logger.info("removing unreferenced spent outpoint") d.pop(prevout_n) - @modifier + @locked def clear_history(self): self.txi.clear() self.txo.clear() @@ -1992,23 +1985,17 @@ def clear_history(self): self.tx_fees.clear() self._prevouts_by_scripthash.clear() - def _should_convert_to_stored_dict(self, key) -> bool: - if key == 'keystore': - return False - multisig_keystore_names = [('x%d' % i) for i in range(1, 16)] - if key in multisig_keystore_names: - return False - return True - @classmethod def split_accounts(klass, root_path, split_data): - from .storage import WalletStorage + # not covered by tests + from .stored_dict import DictStorage file_list = [] for data in split_data: path = root_path + '.' + data['suffix'] - item_storage = WalletStorage(path) - db = WalletDB(json.dumps(data), storage=item_storage, upgrade=True) - db.write() + storage = DictStorage(path) + storage.set_data(json.dumps(data)) + db = WalletDB(storage, upgrade=True) + storage.write() file_list.append(path) return file_list diff --git a/electrum/wizard.py b/electrum/wizard.py index 6b2e7b65d2cd..cefe93866040 100644 --- a/electrum/wizard.py +++ b/electrum/wizard.py @@ -12,7 +12,8 @@ from electrum.network import ProxySettings from electrum.plugin import run_hook from electrum.slip39 import EncryptedSeed -from electrum.storage import WalletStorage, StorageEncryptionVersion, StorageReadWriteError +from electrum.stored_dict import StorageReadWriteError +from electrum.stored_dict import DictStorage, PasswordType from electrum.util import UserFacingException from electrum.wallet_db import WalletDB from electrum.bip32 import normalize_bip32_derivation, xpub_type @@ -687,7 +688,7 @@ def create_storage(self, path: str, data: dict): if os.path.exists(path): raise UserFacingException(_('File already exists at path: {}').format(path)) try: - storage = WalletStorage(path) + storage = DictStorage(path) except StorageReadWriteError as e: raise UserFacingException(e) @@ -772,15 +773,15 @@ def create_storage(self, path: str, data: dict): if k and k.may_have_password(): k.update_password(None, data['password']) - if data['encrypt']: + if data['password'] and data['encrypt']: if data.get('xpub_encrypt'): assert data.get('keystore_type') == 'hardware' and data['wallet_type'] == 'standard' - enc_version = StorageEncryptionVersion.XPUB_PASSWORD + password_type = PasswordType.XPUB else: - enc_version = StorageEncryptionVersion.USER_PASSWORD - storage.set_password(data['password'], enc_version=enc_version) + password_type = PasswordType.USER + storage.add_password(data['password'], password_type) - db = WalletDB('', storage=storage, upgrade=True) + db = WalletDB(storage) db.set_keystore_encryption(bool(data['password'])) db.put('wallet_type', data['wallet_type']) @@ -823,7 +824,8 @@ def create_storage(self, path: str, data: dict): db.put('lightning_xprv', k.get_lightning_xprv(data['password'])) db.load_plugins() - db.write() + storage.write() + storage.close() class ServerConnectWizard(AbstractWizard): diff --git a/run_electrum b/run_electrum index 8ee74de9eb4b..5d9820ecfd8a 100755 --- a/run_electrum +++ b/run_electrum @@ -124,7 +124,7 @@ from electrum.payment_identifier import PaymentIdentifier from electrum import SimpleConfig from electrum.wallet_db import WalletDB from electrum.wallet import Wallet -from electrum.storage import WalletStorage +from electrum.stored_dict import DictStorage from electrum.util import print_msg, print_stderr, json_encode, json_decode, UserCancelled from electrum.util import InvalidPassword from electrum.plugin import Plugins @@ -167,8 +167,8 @@ def init_cmdline(config_options, wallet_path, *, rpcserver: bool, config: 'Simpl print_msg("wallet path not provided.") sys_exit(1) - # instantiate wallet for command-line - storage = WalletStorage(wallet_path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) if wallet_path else None + # instantiate storage without opening the DB, so that we can check if it is encrypted + storage = DictStorage(wallet_path, init_db=False) if wallet_path else None if cmd.requires_wallet and not storage.file_exists(): print_msg("Error: Wallet file not found.") @@ -250,13 +250,13 @@ async def run_offline_command(config: 'SimpleConfig', config_options: dict, wall if 'wallet_path' in cmd.options and config_options.get('wallet_path') is None: config_options['wallet_path'] = wallet_path if cmd.requires_wallet: - storage = WalletStorage(wallet_path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) + storage = DictStorage(wallet_path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) if storage.is_encrypted(): if storage.is_encrypted_with_hw_device(): password = get_password_for_hw_device_encrypted_storage(plugins) config_options['password'] = password storage.decrypt(password) - db = WalletDB(storage.read(), storage=storage, upgrade=True) + db = WalletDB(storage) wallet = Wallet(db, config=config) config_options['wallet'] = wallet else: @@ -282,9 +282,6 @@ async def run_offline_command(config: 'SimpleConfig', config_options: dict, wall cmd_runner = Commands(config=config) func = getattr(cmd_runner, cmd.name) result = await func(*args, **kwargs) - # save wallet - if wallet: - wallet.save_db() return result diff --git a/tests/plugins/test_timelock_recovery.py b/tests/plugins/test_timelock_recovery.py index 8c57b19c6d28..b57156888834 100644 --- a/tests/plugins/test_timelock_recovery.py +++ b/tests/plugins/test_timelock_recovery.py @@ -5,8 +5,8 @@ from electrum.bitcoin import address_to_script from electrum.fee_policy import FixedFeePolicy from electrum.simple_config import SimpleConfig -from electrum.storage import WalletStorage from electrum.transaction import PartialTxOutput +from electrum.stored_dict import DictStorage from electrum.wallet import Wallet from electrum.wallet_db import WalletDB @@ -36,8 +36,9 @@ def tearDown(self): def _create_default_wallet(self): with open(os.path.join(os.path.dirname(__file__), "test_timelock_recovery", "default_wallet"), "r") as f: wallet_str = f.read() - storage = WalletStorage(self.wallet_path) - db = WalletDB(wallet_str, storage=storage, upgrade=True) + storage = DictStorage(self.wallet_path) + storage.set_data(wallet_str) + db = WalletDB(storage, upgrade=True) wallet = Wallet(db, config=self.config) return wallet diff --git a/tests/test_bitcoin.py b/tests/test_bitcoin.py index 78a6b523d96a..d649593ef0ee 100644 --- a/tests/test_bitcoin.py +++ b/tests/test_bitcoin.py @@ -28,7 +28,7 @@ from electrum.crypto import sha256d, SUPPORTED_PW_HASH_VERSIONS from electrum import crypto, constants from electrum.util import bfh, InvalidPassword, randrange -from electrum.storage import WalletStorage +from electrum.storage import FileStorage from electrum.keystore import xtype_from_derivation from . import ElectrumTestCase @@ -270,7 +270,7 @@ def test_signmessage_segwit_witness_v0_address_test_we_also_accept_sigs_from_tre @needs_test_with_all_aes_implementations def test_decrypt_message(self): - key = WalletStorage.get_eckey_from_password('pw123') + key = FileStorage.get_old_eckey_from_password('pw123') self.assertEqual(b'me<(s_s)>age', crypto.ecies_decrypt_message( key, b'QklFMQMDFtgT3zWSQsa+Uie8H/WvfUjlu9UN9OJtTt3KlgKeSTi6SQfuhcg1uIz9hp3WIUOFGTLr4RNQBdjPNqzXwhkcPi2Xsbiw6UCNJncVPJ6QBg==')) self.assertEqual(b'me<(s_s)>age', crypto.ecies_decrypt_message( @@ -280,7 +280,7 @@ def test_decrypt_message(self): @needs_test_with_all_aes_implementations def test_encrypt_message(self): - key = WalletStorage.get_eckey_from_password('secret_password77') + key = FileStorage.get_old_eckey_from_password('secret_password77') msgs = [ bytes([0] * 555), b'cannot think of anything funny' diff --git a/tests/test_commands.py b/tests/test_commands.py index 81801464e107..152850841dac 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -10,8 +10,9 @@ import electrum from electrum.commands import Commands, eval_bool -from electrum import storage, wallet +from electrum import storage from electrum.lnutil import RECEIVED, channel_id_from_funding_tx +from electrum.lnutil import ReceivedMPPStatus, UpdateAddHtlc, ReceivedMPPHtlc from electrum.lnworker import RecvMPPResolution from electrum.wallet import Abstract_Wallet from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED @@ -209,8 +210,9 @@ async def asyncSetUp(self): assert self.daemon.network is None async def asyncTearDown(self): - with mock.patch.object(wallet.Abstract_Wallet, 'save_db'): - await self.daemon.stop() + with mock.patch.object(storage.FileStorage, 'write'): + with mock.patch.object(storage.FileStorage, 'append'): + await self.daemon.stop() await super().asyncTearDown() async def test_convert_xkey(self): @@ -536,16 +538,28 @@ async def test_hold_invoice_commands(self): wallet=wallet, ) - mock_htlc1 = mock.Mock() - mock_htlc1.htlc.cltv_abs = 800_000 - mock_htlc1.htlc.amount_msat = 4_500_000 - mock_htlc2 = mock.Mock() - mock_htlc2.htlc.cltv_abs = 800_144 - mock_htlc2.htlc.amount_msat = 5_500_000 - mock_htlc_status = mock.Mock() - mock_htlc_status.htlcs = [mock_htlc1, mock_htlc2] - mock_htlc_status.resolution = RecvMPPResolution.COMPLETE - + mock_htlc1 = ReceivedMPPHtlc( + channel_id='', + htlc = UpdateAddHtlc( + cltv_abs = 800_000, + amount_msat = 4_500_000, + payment_hash=bytes(32), + ), + unprocessed_onion='', + ) + mock_htlc2 = ReceivedMPPHtlc( + channel_id = '', + htlc = UpdateAddHtlc( + cltv_abs = 800_144, + amount_msat = 5_500_000, + payment_hash=bytes(32), + ), + unprocessed_onion = '', + ) + mock_htlc_status = ReceivedMPPStatus( + htlcs = [mock_htlc1, mock_htlc2], + resolution = RecvMPPResolution.COMPLETE, + ) payment_key = wallet.lnworker._get_payment_key(bytes.fromhex(payment_hash)).hex() with mock.patch.dict(wallet.lnworker.received_mpp_htlcs, {payment_key: mock_htlc_status}): status: dict = await cmds.check_hold_invoice(payment_hash=payment_hash, wallet=wallet) @@ -574,8 +588,8 @@ async def test_hold_invoice_commands(self): # cancelling a settled invoice should raise await cmds.cancel_hold_invoice(payment_hash=payment_hash, wallet=wallet) - @mock.patch.object(storage.WalletStorage, 'write') - @mock.patch.object(storage.WalletStorage, 'append') + @mock.patch.object(storage.FileStorage, 'write') + @mock.patch.object(storage.FileStorage, 'append') async def test_onchain_history(self, *mock_args): cmds = Commands(config=self.config, daemon=self.daemon) wallet_path = self.get_wallet_file_path("client_3_3_8_xpub_with_realistic_history") diff --git a/tests/test_jsondb.py b/tests/test_jsondb.py index bb07eaf1a286..69feae5bf37c 100644 --- a/tests/test_jsondb.py +++ b/tests/test_jsondb.py @@ -2,6 +2,7 @@ import copy import traceback import json +import os from typing import Any import jsonpatch @@ -10,7 +11,7 @@ from . import ElectrumTestCase -from electrum.json_db import JsonDB +from electrum.stored_dict import DictStorage class TestJsonpatch(ElectrumTestCase): @@ -89,16 +90,6 @@ def fail_if_leaking_secret(ctx) -> None: fail_if_leaking_secret(ctx) -def pop1_from_dict(d: dict, key: str) -> Any: - return d.pop(key) - - -def pop2_from_dict(d: dict, key: str) -> Any: - val = d[key] - del d[key] - return val - - class TestJsonDB(ElectrumTestCase): async def test_jsonpatch_replace_after_remove(self): @@ -119,36 +110,22 @@ async def test_jsonpatch_replace_after_remove(self): with self.assertRaises(JsonPatchException): data = jpatch.apply(data) - async def test_jsondb_replace_after_remove(self): - for pop_from_dict in [pop1_from_dict, pop2_from_dict]: - with self.subTest(pop_from_dict): - data = { 'a': {'b': {'c': 0}}, 'd': 3} - db = JsonDB(repr(data)) - a = db.get_dict('a') - # remove - b = pop_from_dict(a, 'b') - self.assertEqual(len(db.pending_changes), 1) - # replace item. this must not been written to db - b['c'] = 42 - self.assertEqual(len(db.pending_changes), 1) - patches = json.loads('[' + ','.join(db.pending_changes) + ']') - jpatch = jsonpatch.JsonPatch(patches) - data = jpatch.apply(data) - self.assertEqual(data, {'a': {}, 'd': 3}) - - async def test_jsondb_replace_after_remove_nested(self): - for pop_from_dict in [pop1_from_dict, pop2_from_dict]: - with self.subTest(pop_from_dict): - data = { 'a': {'b': {'c': 0}}, 'd': 3} - db = JsonDB(repr(data)) - # remove - a = pop_from_dict(db.data, "a") - self.assertEqual(len(db.pending_changes), 1) - b = a['b'] - # replace item. this must not be written to db - b['c'] = 42 - self.assertEqual(len(db.pending_changes), 1) - patches = json.loads('[' + ','.join(db.pending_changes) + ']') - jpatch = jsonpatch.JsonPatch(patches) - data = jpatch.apply(data) - self.assertEqual(data, {'d': 3}) + async def test_jsondb_partial_write_round_test(self): + wallet_path = os.path.join(self.electrum_path, "somewallet") + storage = DictStorage(wallet_path, allow_partial_writes=True) + storage['a'] = [1, 2, 3] + storage._db.write() + storage['a'].append(4) + storage._db.write() + storage = DictStorage(wallet_path, allow_partial_writes=True) + self.assertEqual(len(storage['a']), 4) + + async def test_jsondb_list_clear(self): + wallet_path = os.path.join(self.electrum_path, "somewallet") + storage = DictStorage(wallet_path, allow_partial_writes=True) + storage['a'] = [1, 2, 3] + storage._db.write() + storage['a'].clear() + storage._db.write() + storage = DictStorage(wallet_path, allow_partial_writes=True) + self.assertEqual(len(storage['a']), 0) diff --git a/tests/test_lnchannel.py b/tests/test_lnchannel.py index 9eedcd54f8da..e8cf8cb88cbc 100644 --- a/tests/test_lnchannel.py +++ b/tests/test_lnchannel.py @@ -40,6 +40,7 @@ ) from electrum.logging import console_stderr_handler from electrum.lnchannel import ChannelState, Channel +from electrum.coinchooser import PRNG from . import ElectrumTestCase from .lnhelpers import create_test_channels @@ -589,7 +590,7 @@ async def test_update_unfunded_zeroconf_channel(self): self.assertTrue(chan.is_zeroconf()) # add channel to lnwallet/db bob._channels[chan.channel_id] = chan - bob.db.get('channels')[chan.channel_id.hex()] = "something" + bob.db.get_dict('channels')[chan.channel_id.hex()] = "something" self.assertIsNotNone(bob.get_channel_by_id(chan.channel_id)) chan.storage['init_height'] = 0 # checked by has_funding_timed_out chan.storage['init_timestamp'] = int(time.time()) @@ -601,7 +602,7 @@ async def test_update_unfunded_zeroconf_channel(self): # assert nothing happened self.assertIsNotNone(bob.get_channel_by_id(chan.channel_id)) - self.assertIsNotNone(bob.db.get('channels').get(chan.channel_id.hex())) + self.assertIsNotNone(bob.db.get_dict('channels').get(chan.channel_id.hex())) self.assertEqual(chan.get_state(), ChannelState.OPEN) self.assertEqual(bob.config.ZEROCONF_TRUSTED_NODE, trusted_node) @@ -613,7 +614,7 @@ async def test_update_unfunded_zeroconf_channel(self): # assert nothing happened again self.assertIsNotNone(bob.get_channel_by_id(chan.channel_id)) - self.assertIsNotNone(bob.db.get('channels').get(chan.channel_id.hex())) + self.assertIsNotNone(bob.db.get_dict('channels').get(chan.channel_id.hex())) self.assertEqual(chan.get_state(), ChannelState.OPEN) self.assertEqual(bob.config.ZEROCONF_TRUSTED_NODE, trusted_node) self.assertFalse(chan.is_frozen_for_receiving()) @@ -636,7 +637,7 @@ async def test_update_unfunded_zeroconf_channel(self): # check that channel got removed, now that funding has timed out self.assertIsNone(self.alice_lnwallet.get_channel_by_id(chan.channel_id)) - self.assertIsNone(self.alice_lnwallet.db.get('channels').get(chan.channel_id.hex())) + self.assertIsNone(self.alice_lnwallet.db.get_dict('channels').get(chan.channel_id.hex())) async def test_should_be_closed_due_to_expiring_htlcs_offered_htlcs(self): alice_lnwallet = self.create_mock_lnwallet(name="alice") diff --git a/tests/test_lnhtlc.py b/tests/test_lnhtlc.py index c8ac53db39b8..5cac3dd6bbbc 100644 --- a/tests/test_lnhtlc.py +++ b/tests/test_lnhtlc.py @@ -4,7 +4,6 @@ from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner, Direction from electrum.lnhtlc import HTLCManager -from electrum.json_db import StoredDict from . import ElectrumTestCase @@ -14,8 +13,8 @@ class H(NamedTuple): class TestHTLCManager(ElectrumTestCase): def test_adding_htlcs_race(self): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() ah0, bh0 = H('A', 0), H('B', 0) @@ -61,8 +60,8 @@ def test_adding_htlcs_race(self): def test_single_htlc_full_lifecycle(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() B.recv_htlc(A.send_htlc(H('A', 0))) @@ -134,8 +133,8 @@ def htlc_lifecycle(htlc_success: bool): def test_remove_htlc_while_owing_commitment(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() ah0 = H('A', 0) @@ -171,8 +170,8 @@ def htlc_lifecycle(htlc_success: bool): htlc_lifecycle(htlc_success=False) def test_adding_htlc_between_send_ctx_and_recv_rev(self): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() A.send_ctx() @@ -217,8 +216,8 @@ def test_adding_htlc_between_send_ctx_and_recv_rev(self): self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) def test_unacked_local_updates(self): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() self.assertEqual({}, A.get_unacked_local_updates()) diff --git a/tests/test_lnutil.py b/tests/test_lnutil.py index 5df8755ce524..f32df8e5b1c0 100644 --- a/tests/test_lnutil.py +++ b/tests/test_lnutil.py @@ -3,7 +3,6 @@ from typing import Dict, List from electrum import bitcoin -from electrum.json_db import StoredDict from electrum.lnutil import ( RevocationStore, get_per_commitment_secret_from_seed, make_offered_htlc, make_received_htlc, make_commitment, make_htlc_tx_witness, make_htlc_tx_output, make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, @@ -12,7 +11,8 @@ IncompatibleLightningFeatures, ChannelType, offered_htlc_trim_threshold_sat, received_htlc_trim_threshold_sat, ImportedChannelBackupStorage, list_enabled_ln_feature_bits, PaymentFeeBudget, LnFeatureContexts ) -from electrum.util import bfh, MyEncoder +from electrum.util import bfh +from electrum.stored_dict import to_default, DictStorage from electrum.transaction import Transaction, PartialTransaction, Sighash from electrum.lnworker import LNWallet from electrum.wallet import Standard_Wallet @@ -474,10 +474,11 @@ def test_shachain_store(self): ] for test in tests: - receiver = RevocationStore(StoredDict({}, None)) + storage = DictStorage(None) + storage.set_data(json.dumps({"channels": {"0": { "revocation_store": {}}}})) + receiver = RevocationStore(storage["channels"]["0"]["revocation_store"]) for insert in test["inserts"]: secret = bytes.fromhex(insert["secret"]) - try: receiver.add_next_entry(secret) except Exception as e: @@ -497,7 +498,9 @@ def test_shachain_store(self): def test_shachain_produce_consume(self): seed = bitcoin.sha256(b"shachaintest") - consumer = RevocationStore(StoredDict({}, None)) + storage = DictStorage(None) + storage.set_data(json.dumps({"channels": {"0": { "revocation_store": {}}}})) + consumer = RevocationStore(storage["channels"]["0"]["revocation_store"]) for i in range(10000): secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i) try: @@ -506,9 +509,11 @@ def test_shachain_produce_consume(self): raise Exception("iteration " + str(i) + ": " + str(e)) if i % 1000 == 0: c1 = consumer - s1 = json.dumps(c1.storage, cls=MyEncoder) - c2 = RevocationStore(StoredDict(json.loads(s1), None)) - s2 = json.dumps(c2.storage, cls=MyEncoder) + s1 = json.dumps(storage._db.json_data, default=to_default) + storage2 = DictStorage(None) + storage2.set_data(s1) + c2 = RevocationStore(storage2["channels"]["0"]["revocation_store"]) + s2 = json.dumps(storage2._db.json_data, default=to_default) self.assertEqual(s1, s2) def test_commitment_tx_with_all_five_HTLCs_untrimmed_minimum_feerate(self): diff --git a/tests/test_storage_upgrade.py b/tests/test_storage_upgrade.py index 5e36a7a3b5d9..01bd3eae7fad 100644 --- a/tests/test_storage_upgrade.py +++ b/tests/test_storage_upgrade.py @@ -7,6 +7,8 @@ import inspect import electrum +from electrum.stored_dict import DictStorage +from electrum.stored_dict import StoredDict from electrum.wallet_db import WalletDBUpgrader, WalletDB, WalletRequiresUpgrade, WalletRequiresSplit from electrum.wallet import Wallet from electrum import constants @@ -289,6 +291,7 @@ async def test_upgrade_from_client_3_2_3_ledger_standard_keystore_changes(self): # see #6066 wallet_str = self._get_wallet_str() db = await self._upgrade_storage(wallet_str) + assert not db.storage.is_closed() wallet = Wallet(db, config=self.config) ks = wallet.keystore # to simulate ks.opportunistically_fill_in_missing_info_from_device(): @@ -300,6 +303,7 @@ async def test_upgrade_from_client_2_9_3_importedkeys_keystore_changes(self): # see #6401 wallet_str = self._get_wallet_str() db = await self._upgrade_storage(wallet_str) + assert not db.storage.is_closed() wallet = Wallet(db, config=self.config) wallet.import_private_keys( ["p2wpkh:L1cgMEnShp73r9iCukoPE3MogLeueNYRD9JVsfT1zVHyPBR3KqBY"], @@ -357,7 +361,6 @@ async def _upgrade_storage(self, wallet_json, accounts=1) -> Optional[WalletDB]: db = self._load_db_from_json_string( wallet_json=wallet_json, upgrade=True) - await self._sanity_check_upgraded_db(db) return db else: try: @@ -369,14 +372,13 @@ async def _upgrade_storage(self, wallet_json, accounts=1) -> Optional[WalletDB]: self.assertEqual(accounts, len(split_data)) for item in split_data: data = json.dumps(item) - new_db = WalletDB(data, storage=None, upgrade=True) - await self._sanity_check_upgraded_db(new_db) - - async def _sanity_check_upgraded_db(self, db): - wallet = Wallet(db, config=self.config) - await wallet.stop() + storage = DictStorage(None) + storage.set_data(data) + new_db = WalletDB(storage, upgrade=True) @staticmethod def _load_db_from_json_string(*, wallet_json, upgrade): - db = WalletDB(wallet_json, storage=None, upgrade=upgrade) + storage = DictStorage(None) + storage.set_data(wallet_json) + db = WalletDB(storage, upgrade=upgrade) return db diff --git a/tests/test_stored_dict.py b/tests/test_stored_dict.py new file mode 100644 index 000000000000..97c3ba5ab035 --- /dev/null +++ b/tests/test_stored_dict.py @@ -0,0 +1,100 @@ +import tempfile +import sys +import os +import json +import time +from io import StringIO +import asyncio +from pathlib import Path + +from electrum.stored_dict import DictStorage, StoredDict + + + +from . import ElectrumTestCase + + +class TestStorage(ElectrumTestCase): + + def setUp(self): + super(TestStorage, self).setUp() + self.path = os.path.join(self.electrum_path, "somewallet") + + self._saved_stdout = sys.stdout + self._stdout_buffer = StringIO() + sys.stdout = self._stdout_buffer + + def tearDown(self): + super(TestStorage, self).tearDown() + # Restore the "real" stdout + sys.stdout = self._saved_stdout + + def test_db_roundtrip(self): + sd = DictStorage(self.path) + # list containing list and dict + some_list = [[1, 2], {"c": "d"} ] + sd['1'] = some_list + self.assertEqual(sd['1'].dump(), some_list) + # dict containing list and dict + some_dict = {"a": [1, 2], "b": {"c":"d"} } + sd['2'] = some_dict + self.assertEqual(sd['2'].dump(), some_dict) + # simple tuple. + some_tuple = (1, 2, 3) + sd['3'] = some_tuple + self.assertEqual(sd['3'], some_tuple) + # complex tuple: the third element is a StoredDict + complex_tuple = (1, 2, [3, 4]) + sd['4'] = complex_tuple + with self.assertRaises(AssertionError): + self.assertEqual(sd['4'], complex_tuple) + self.assertEqual(sd['4'][2].dump(), complex_tuple[2]) + + def test_db_iterators(self): + sd = DictStorage(self.path) + sd['a'] = [0, 1, 2, 3, 4] + sl = sd.get('a') + self.assertEqual(len(sl), 5) + for i, v in enumerate(sl): + self.assertEqual(i, v) + + def test_write_batch(self): + # test that batches are written atomically + sd = DictStorage(self.path) + with sd.write_batch(): + sd['a'] = 0 + self.assertEqual(len(sd), 1) + with sd.write_batch(): + sd['a'] = 1 + self.assertEqual(len(sd), 1) + try: + with sd.write_batch(): + sd['b'] = 1 + raise Exception('blah') + except Exception as e: + pass + self.assertEqual(sd._db._write_batch, False) + # at this point, the StoredDict length is 1 + self.assertEqual(len(sd), 1) + sd.close() + # check that changes have not been written to disk + sd = DictStorage(self.path) + self.assertEqual(len(sd), 1) + + async def test_dangling_dict(self): + storage = DictStorage(self.path) + storage['a'] = {'b': {'c': 0}} + storage.write() + a = storage.get('a') + b = a['b'] + self.assertEqual(type(b), StoredDict) + b2 = a.pop('b') + self.assertEqual(type(b2), dict) + # replace item. this must not been written to db + with self.assertRaises(KeyError): + b['c'] = 42 + storage.write() + storage.close() + storage = DictStorage(self.path) + self.assertEqual(storage.dump(), {'a':{}}) + diff --git a/tests/test_txbatcher.py b/tests/test_txbatcher.py index 692cc2c210e5..2f1208d454e3 100644 --- a/tests/test_txbatcher.py +++ b/tests/test_txbatcher.py @@ -6,9 +6,9 @@ from aiorpcx import timeout_after import electrum.fee_policy -from electrum import keystore, wallet, lnutil +from electrum import keystore, lnutil from electrum import SimpleConfig -from electrum import util +from electrum import util, storage from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED from electrum.transaction import Transaction, PartialTxInput, PartialTxOutput, TxOutpoint from electrum.logging import console_stderr_handler, Logger @@ -144,8 +144,8 @@ def _create_wallet(self): self.network.wallets.append(wallet) return wallet - @mock.patch.object(wallet.Abstract_Wallet, 'save_db') - async def test_batch_payments(self, mock_save_db): + @mock.patch.object(storage.FileStorage, 'append') + async def test_batch_payments(self, mock_append): # output 1: tx1(o1) --------------- # \ # output 2: tx1'(o1,o2) ----> tx2(tx1|o2) @@ -190,8 +190,8 @@ async def test_batch_payments(self, mock_save_db): assert tx2.inputs()[0].prevout.txid.hex() == tx1.txid() - @mock.patch.object(wallet.Abstract_Wallet, 'save_db') - async def test_rbf_batching__cannot_batch_as_would_need_to_use_ismine_outputs_of_basetx(self, mock_save_db): + @mock.patch.object(storage.FileStorage, 'append') + async def test_rbf_batching__cannot_batch_as_would_need_to_use_ismine_outputs_of_basetx(self, mock_append): """Wallet history contains unconf tx1 that spends all its coins to two ismine outputs, one 'recv' address (20k sats) and one 'change' (80k sats). The user tries to create tx2, that pays an invoice for 90k sats. @@ -228,8 +228,8 @@ async def test_rbf_batching__cannot_batch_as_would_need_to_use_ismine_outputs_of assert output2 in tx2.outputs() - @mock.patch.object(wallet.Abstract_Wallet, 'save_db') - async def test_sweep_from_submarine_swap(self, mock_save_db): + @mock.patch.object(storage.FileStorage, 'append') + async def test_sweep_from_submarine_swap(self, mock_append): self.maxDiff = None # create wallet wallet = self._create_wallet() diff --git a/tests/test_wallet.py b/tests/test_wallet.py index de1a278a6ce1..b8d4523cc952 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -10,14 +10,14 @@ from unittest import mock from pathlib import Path -from electrum.storage import WalletStorage from electrum.wallet_db import FINAL_SEED_VERSION from electrum.wallet import (Abstract_Wallet, Standard_Wallet, create_new_wallet, Imported_Wallet, Wallet) from electrum.exchange_rate import ExchangeBase, FxThread from electrum.util import TxMinedInfo, InvalidPassword from electrum.bitcoin import COIN -from electrum.wallet_db import WalletDB, JsonDB +from electrum.wallet_db import WalletDB +from electrum.stored_dict import DictStorage, PasswordType from electrum.simple_config import SimpleConfig from electrum import util, storage from electrum.daemon import Daemon @@ -66,15 +66,13 @@ def test_read_dictionary_from_file(self): with open(self.wallet_path, "w") as f: contents = f.write(contents) - storage = WalletStorage(self.wallet_path) - db = JsonDB(storage.read(), storage=storage) + db = DictStorage(self.wallet_path) self.assertEqual("b", db.get("a")) self.assertEqual("d", db.get("c")) def test_write_dictionary_to_file(self): - storage = WalletStorage(self.wallet_path) - db = JsonDB('', storage=storage) + db = DictStorage(self.wallet_path) some_dict = { u"a": u"b", @@ -82,7 +80,7 @@ def test_write_dictionary_to_file(self): u"seed_version": FINAL_SEED_VERSION} for key, value in some_dict.items(): - db.put(key, value) + db[key] = value db.write() with open(self.wallet_path, "r") as f: @@ -91,6 +89,26 @@ def test_write_dictionary_to_file(self): for key, value in some_dict.items(): self.assertEqual(d[key], value) + def test_add_update_remove_password(self): + storage = DictStorage(self.wallet_path) + pw1 = "123456" + pw2 = "789012" + pw3 = "tttttt" + storage.add_password(pw1, PasswordType.USER) + storage.add_password(pw2, PasswordType.USER) + self.assertTrue(storage.is_encrypted()) + with self.assertRaises(InvalidPassword): + storage.remove_password(pw3) + storage.remove_password(pw1) + self.assertTrue(storage.is_encrypted()) + with self.assertRaises(InvalidPassword): + storage.remove_password(pw1) + with self.assertRaises(InvalidPassword): + storage.update_password(pw1, pw3, PasswordType.USER) + storage.update_password(pw2, pw3, PasswordType.USER) + storage.remove_password(pw3) + self.assertFalse(storage.is_encrypted()) + async def test_storage_imported_add_privkeys_persistence_test(self): text = ' '.join([ 'p2wpkh:L4jkdiXszG26SUYvwwJhzGwg37H2nLhrbip7u6crmgNeJysv5FHL', @@ -172,7 +190,8 @@ class FakeWallet: def __init__(self, fiat_value): super().__init__() self.fiat_value = fiat_value - self.db = WalletDB('', storage=None, upgrade=False) + storage = DictStorage(None) + self.db = WalletDB(storage) self.adb = FakeADB() self.db.transactions = self.db.verified_tx = {'abc':'Tx'} @@ -234,8 +253,8 @@ def tearDown(self): time.tzset() @mock.patch('electrum.wallet.run_hook') - @mock.patch.object(storage.WalletStorage, 'write') - @mock.patch.object(storage.WalletStorage, 'append') + @mock.patch.object(storage.FileStorage, 'write') + @mock.patch.object(storage.FileStorage, 'append') async def test_export_history_to_file(self, _mock_append, _mock_write, mock_run_hook): # prepare wallet with realistic history c = self.config @@ -272,6 +291,8 @@ async def test_export_history_to_file(self, _mock_append, _mock_write, mock_run_ # compare line by line for more readable traceback on difference for reference, test in zip(reference_text, test_export_text): self.assertEqual(reference, test) + # stop wallet + await daemon._stop_wallet(wallet_path) class TestCreateRestoreWallet(WalletTestCase): @@ -325,7 +346,7 @@ async def test_restore_wallet_from_text_no_storage(self): config=self.config, ) wallet = d['wallet'] # type: Standard_Wallet - self.assertEqual(None, wallet.storage) + self.assertEqual(None, wallet.storage._db.storage) self.assertEqual(text, wallet.keystore.get_seed(None)) self.assertEqual('bc1q3g5tmkmlvxryhh843v4dz026avatc0zzr6h3af', wallet.get_receiving_addresses()[0]) @@ -378,8 +399,9 @@ class TestWalletPassword(WalletTestCase): async def test_update_password_of_imported_wallet(self): wallet_str = '{"addr_history":{"1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr":[],"15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA":[],"1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6":[]},"addresses":{"change":[],"receiving":["1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr","1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6","15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA"]},"keystore":{"keypairs":{"0344b1588589958b0bcab03435061539e9bcf54677c104904044e4f8901f4ebdf5":"L2sED74axVXC4H8szBJ4rQJrkfem7UMc6usLCPUoEWxDCFGUaGUM","0389508c13999d08ffae0f434a085f4185922d64765c0bff2f66e36ad7f745cc5f":"L3Gi6EQLvYw8gEEUckmqawkevfj9s8hxoQDFveQJGZHTfyWnbk1U","04575f52b82f159fa649d2a4c353eb7435f30206f0a6cb9674fbd659f45082c37d559ffd19bea9c0d3b7dcc07a7b79f4cffb76026d5d4dff35341efe99056e22d2":"5JyVyXU1LiRXATvRTQvR9Kp8Rx1X84j2x49iGkjSsXipydtByUq"},"type":"imported"},"pruned_txo":{},"seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[100,100,840,405]}' - storage = WalletStorage(self.wallet_path) - db = WalletDB(wallet_str, storage=storage, upgrade=True) + storage = DictStorage(self.wallet_path) + storage.set_data(wallet_str) + db = WalletDB(storage) wallet = Wallet(db, config=self.config) wallet.check_password(None) @@ -394,8 +416,9 @@ async def test_update_password_of_imported_wallet(self): async def test_update_password_of_standard_wallet(self): wallet_str = '''{"addr_history":{"12ECgkzK6gHouKAZ7QiooYBuk1CgJLJxes":[],"12iR43FPb5M7sw4Mcrr5y1nHKepg9EtZP1":[],"13HT1pfWctsSXVFzF76uYuVdQvcAQ2MAgB":[],"13kG9WH9JqS7hyCcVL1ssLdNv4aXocQY9c":[],"14Tf3qiiHJXStSU4KmienAhHfHq7FHpBpz":[],"14gmBxYV97mzYwWdJSJ3MTLbTHVegaKrcA":[],"15FGuHvRssu1r8fCw98vrbpfc3M4xs5FAV":[],"17oJzweA2gn6SDjsKgA9vUD5ocT1sSnr2Z":[],"18hNcSjZzRcRP6J2bfFRxp9UfpMoC4hGTv":[],"18n9PFxBjmKCGhd4PCDEEqYsi2CsnEfn2B":[],"19a98ZfEezDNbCwidVigV5PAJwrR2kw4Jz":[],"19z3j2ELqbg2pR87byCCt3BCyKR7rc3q8G":[],"1A3XSmvLQvePmvm7yctsGkBMX9ZKKXLrVq":[],"1CmhFe2BN1h9jheFpJf4v39XNPj8F9U6d":[],"1DuphhHUayKzbkdvjVjf5dtjn2ACkz4zEs":[],"1E4ygSNJpWL2uPXZHBptmU2LqwZTqb1Ado":[],"1GTDSjkVc9vaaBBBGNVqTANHJBcoT5VW9z":[],"1GWqgpThAuSq3tDg6uCoLQxPXQNnU8jZ52":[],"1GhmpwqSF5cqNgdr9oJMZx8dKxPRo4pYPP":[],"1J5TTUQKhwehEACw6Jjte1E22FVrbeDmpv":[],"1JWySzjzJhsETUUcqVZHuvQLA7pfFfmesb":[],"1KQHxcy3QUHAWMHKUtJjqD9cMKXcY2RTwZ":[],"1KoxZfc2KsgovjGDxwqanbFEA76uxgYH4G":[],"1KqVEPXdpbYvEbwsZcEKkrA4A2jsgj9hYN":[],"1N16yDSYe76c5A3CoVoWAKxHeAUc8Jhf9J":[],"1Pm8JBhzUJDqeQQKrmnop1Frr4phe1jbTt":[]},"addresses":{"change":["1GhmpwqSF5cqNgdr9oJMZx8dKxPRo4pYPP","1GTDSjkVc9vaaBBBGNVqTANHJBcoT5VW9z","15FGuHvRssu1r8fCw98vrbpfc3M4xs5FAV","1A3XSmvLQvePmvm7yctsGkBMX9ZKKXLrVq","19z3j2ELqbg2pR87byCCt3BCyKR7rc3q8G","1JWySzjzJhsETUUcqVZHuvQLA7pfFfmesb"],"receiving":["14gmBxYV97mzYwWdJSJ3MTLbTHVegaKrcA","13HT1pfWctsSXVFzF76uYuVdQvcAQ2MAgB","19a98ZfEezDNbCwidVigV5PAJwrR2kw4Jz","1J5TTUQKhwehEACw6Jjte1E22FVrbeDmpv","1Pm8JBhzUJDqeQQKrmnop1Frr4phe1jbTt","13kG9WH9JqS7hyCcVL1ssLdNv4aXocQY9c","1KQHxcy3QUHAWMHKUtJjqD9cMKXcY2RTwZ","12ECgkzK6gHouKAZ7QiooYBuk1CgJLJxes","12iR43FPb5M7sw4Mcrr5y1nHKepg9EtZP1","14Tf3qiiHJXStSU4KmienAhHfHq7FHpBpz","1KqVEPXdpbYvEbwsZcEKkrA4A2jsgj9hYN","17oJzweA2gn6SDjsKgA9vUD5ocT1sSnr2Z","1E4ygSNJpWL2uPXZHBptmU2LqwZTqb1Ado","18hNcSjZzRcRP6J2bfFRxp9UfpMoC4hGTv","1KoxZfc2KsgovjGDxwqanbFEA76uxgYH4G","18n9PFxBjmKCGhd4PCDEEqYsi2CsnEfn2B","1CmhFe2BN1h9jheFpJf4v39XNPj8F9U6d","1DuphhHUayKzbkdvjVjf5dtjn2ACkz4zEs","1GWqgpThAuSq3tDg6uCoLQxPXQNnU8jZ52","1N16yDSYe76c5A3CoVoWAKxHeAUc8Jhf9J"]},"keystore":{"seed":"cereal wise two govern top pet frog nut rule sketch bundle logic","type":"bip32","xprv":"xprv9s21ZrQH143K29XjRjUs6MnDB9wXjXbJP2kG1fnRk8zjdDYWqVkQYUqaDtgZp5zPSrH5PZQJs8sU25HrUgT1WdgsPU8GbifKurtMYg37d4v","xpub":"xpub661MyMwAqRbcEdcCXm1sTViwjBn28zK9kFfrp4C3JUXiW1sfP34f6HA45B9yr7EH5XGzWuTfMTdqpt9XPrVQVUdgiYb5NW9m8ij1FSZgGBF"},"pruned_txo":{},"seed_type":"standard","seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[619,310,840,405]}''' - storage = WalletStorage(self.wallet_path) - db = WalletDB(wallet_str, storage=storage, upgrade=True) + storage = DictStorage(path=self.wallet_path) + storage.set_data(wallet_str) + db = WalletDB(storage) wallet = Wallet(db, config=self.config) wallet.check_password(None) @@ -423,15 +446,16 @@ async def test_update_password_of_standard_wallet_oldseed(self): async def test_update_password_with_app_restarts(self): wallet_str = '{"addr_history":{"1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr":[],"15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA":[],"1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6":[]},"addresses":{"change":[],"receiving":["1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr","1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6","15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA"]},"keystore":{"keypairs":{"0344b1588589958b0bcab03435061539e9bcf54677c104904044e4f8901f4ebdf5":"L2sED74axVXC4H8szBJ4rQJrkfem7UMc6usLCPUoEWxDCFGUaGUM","0389508c13999d08ffae0f434a085f4185922d64765c0bff2f66e36ad7f745cc5f":"L3Gi6EQLvYw8gEEUckmqawkevfj9s8hxoQDFveQJGZHTfyWnbk1U","04575f52b82f159fa649d2a4c353eb7435f30206f0a6cb9674fbd659f45082c37d559ffd19bea9c0d3b7dcc07a7b79f4cffb76026d5d4dff35341efe99056e22d2":"5JyVyXU1LiRXATvRTQvR9Kp8Rx1X84j2x49iGkjSsXipydtByUq"},"type":"imported"},"pruned_txo":{},"seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[100,100,840,405]}' - storage = WalletStorage(self.wallet_path) - db = WalletDB(wallet_str, storage=storage, upgrade=True) + storage = DictStorage(self.wallet_path) + storage.set_data(wallet_str) + db = WalletDB(storage) wallet = Wallet(db, config=self.config) await wallet.stop() - storage = WalletStorage(self.wallet_path) + storage = DictStorage(self.wallet_path) # if storage.is_encrypted(): # storage.decrypt(password) - db = WalletDB(storage.read(), storage=storage, upgrade=True) + db = WalletDB(storage) wallet = Wallet(db, config=self.config) wallet.check_password(None) diff --git a/tests/test_wallet_vertical.py b/tests/test_wallet_vertical.py index 8dbb0d7c42f6..6484a96c3db7 100644 --- a/tests/test_wallet_vertical.py +++ b/tests/test_wallet_vertical.py @@ -8,7 +8,7 @@ from electrum import bitcoin, keystore, bip32, slip39 from electrum.wallet_db import WalletDB -from electrum.storage import WalletStorage +from electrum.stored_dict import DictStorage from electrum import SimpleConfig from electrum import util from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE @@ -55,7 +55,8 @@ def check_xpub_keystore_sanity(cls, test_obj, ks): @classmethod def create_standard_wallet(cls, ks, *, config: SimpleConfig, gap_limit=None, gap_limit_for_change=None): - db = WalletDB('', storage=None, upgrade=True) + storage = DictStorage(None) + db = WalletDB(storage) db.put('keystore', ks.dump()) db.put('gap_limit', gap_limit or cls.gap_limit) db.put('gap_limit_for_change', gap_limit_for_change or cls.gap_limit_for_change) @@ -65,7 +66,8 @@ def create_standard_wallet(cls, ks, *, config: SimpleConfig, gap_limit=None, gap @classmethod def create_imported_wallet(cls, *, config: SimpleConfig, privkeys: bool): - db = WalletDB('', storage=None, upgrade=True) + storage = DictStorage(None) + db = WalletDB(storage) if privkeys: k = keystore.Imported_KeyStore({}) db.put('keystore', k.dump()) @@ -79,12 +81,14 @@ def create_multisig_wallet( multisig_type: str, *, config: SimpleConfig, - storage: WalletStorage | None = None, + storage: DictStorage | None = None, gap_limit=None, gap_limit_for_change=None, ): """Creates a multisig wallet.""" - db = WalletDB('', storage=storage, upgrade=False) + if storage is None: + storage = DictStorage(None) + db = WalletDB(storage) for i, ks in enumerate(keystores): cosigner_index = i + 1 db.put('x%d' % cosigner_index, ks.dump()) diff --git a/tests/test_wizard.py b/tests/test_wizard.py index 7c5b740a8611..cd82152f2424 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -12,7 +12,7 @@ from electrum import slip39 from electrum.bip32 import KeyOriginInfo from electrum import keystore -from electrum.storage import WalletStorage +from electrum.stored_dict import DictStorage from . import ElectrumTestCase from .test_wallet_vertical import UNICODE_HORROR, WalletIntegrityHelper @@ -403,7 +403,7 @@ async def test_multisig(self): ], '2of2', config=self.config, - storage=WalletStorage(self.wallet_path), + storage=DictStorage(self.wallet_path), ) w, v = self._wizard_for(wallet_type=wallet.wallet_type)