diff --git a/src/jsonlogic/registry.py b/src/jsonlogic/registry.py index c0dbb03..f1a7a64 100644 --- a/src/jsonlogic/registry.py +++ b/src/jsonlogic/registry.py @@ -117,27 +117,22 @@ def remove(self, operator_id: str, /) -> None: self._registry.pop(operator_id, None) - def copy(self) -> Self: - """Create a new instance of the registry.""" - - new = self.__class__() - new._registry = self._registry.copy() - return new - - def with_operator(self, operator_id: str, operator_type: OperatorType, *, force: bool = False) -> Self: - """Create a new instance of the registry with the provided operator. + def copy(self, *, extend: OperatorRegistry | dict[str, OperatorType] | None = None, force: bool = False) -> Self: + """Create a new instance of the registry. Args: - operator_id: The ID to be used to register the operator. - operator_type: The class object of the operator. - force: Whether to override any existing operator under the provided ID. - - Raises: - AlreadyRegistered: If :paramref:`force` wasn't set and the ID already exists. + - extend: A registry or a mapping to use to register new operators + while doing the copy. + - force: Whether to override any existing operator under the provided ID. Returns: A new instance of the registry. """ - new = self.copy() - new.register(operator_id, operator_type, force=force) + + new = self.__class__() + new._registry = self._registry.copy() + overrides = extend._registry if isinstance(extend, OperatorRegistry) else extend + if overrides is not None: + for id, operator in overrides.items(): + new.register(operator_id=id, operator_type=operator, force=force) return new diff --git a/tests/test_registry.py b/tests/test_registry.py index 3eb45cc..b22dc3a 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -89,12 +89,19 @@ class Var(Operator): assert copy._registry == registry._registry -def test_with_operator(): +def test_copy_with_operator(): class Var(Operator): pass registry = OperatorRegistry() - copy = registry.with_operator("var", Var) + copy_dict = registry.copy(extend={"var": Var}) - assert copy._registry == {"var": Var} + assert copy_dict._registry == {"var": Var} + + other = OperatorRegistry() + other.register("var", Var) + + copy_registry = registry.copy(extend=other) + + assert copy_registry._registry == {"var": Var}