diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 26a6ebe97..9fd21fa36 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: - python-version: 3.9 numpy: "numpy" uncertainties: "uncertainties" - extras: "sparse xarray netCDF4 dask[complete]==2023.4.0 graphviz babel==2.8" + extras: "sparse xarray netCDF4 dask[complete]==2023.4.0 graphviz babel==2.8 mip>=1.13" runs-on: ubuntu-latest env: diff --git a/CHANGES b/CHANGES index 0b8d625d3..01f5644c0 100644 --- a/CHANGES +++ b/CHANGES @@ -1,15 +1,37 @@ Pint Changelog ============== -0.23 (unreleased) +0.24 (unreleased) ----------------- - Add `dim_sort` parameter to formatter. (PR #1864, fixes Issue #1841) +- Nothing changed yet. + + +0.23 (2023-12-08) ----------------- + +- Add _get_conversion_factor to registry with cache. +- Homogenize input and output of internal registry functions to + facilitate typing, subclassing and wrapping. + (_yield_unit_triplets, ) +- Generated downstream_status page to track the + state of downstream projects. +- Improve typing annotation. +- Updated to flexparser 0.2. +- Faster wraps + (PR #1862) +- Add codspeed github action. +- Move benchmarks to pytest-benchmarks. +- Support pytest on python 3.12 wrt Fraction formatting change + (#1818) - Fixed Transformation type protocol. (PR #1805, PR #1832) - Documented to_preferred and created added an autoautoconvert_to_preferred registry option. (PR #1803) +- Enable Pint to parse uncertainty numbers. + (See #1611, #1614) - Optimize matplotlib unit conversion for Quantity arrays (PR #1819) - Add numpy.linalg.norm implementation. diff --git a/docs/changes.rst b/docs/changes.rst new file mode 100644 index 000000000..d6c5f48c7 --- /dev/null +++ b/docs/changes.rst @@ -0,0 +1 @@ +.. 
include:: ../CHANGES diff --git a/docs/getting/index.rst b/docs/getting/index.rst index 41ffaf93f..95de7e5a5 100644 --- a/docs/getting/index.rst +++ b/docs/getting/index.rst @@ -8,7 +8,7 @@ The getting started guide aims to get you using pint productively as quickly as Installation ------------ -Pint has no dependencies except Python itself. In runs on Python 3.9+. +Pint has no dependencies except Python itself. It runs on Python 3.9+. .. grid:: 2 diff --git a/docs/index.rst b/docs/index.rst index 8c60992b9..a2bc6454c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -70,6 +70,7 @@ Pint: makes units easy Advanced topics ecosystem API Reference + changes .. toctree:: :maxdepth: 1 diff --git a/pint/_vendor/flexparser.py b/pint/_vendor/flexparser.py index 8945b6ed5..cac3c2b49 100644 --- a/pint/_vendor/flexparser.py +++ b/pint/_vendor/flexparser.py @@ -17,6 +17,7 @@ from __future__ import annotations +import sys import collections import dataclasses import enum @@ -27,19 +28,102 @@ import logging import pathlib import re -import sys import typing as ty -from collections.abc import Iterator from dataclasses import dataclass from functools import cached_property from importlib import resources -from typing import Optional, Tuple, Type +from typing import Any, Union, Optional, no_type_check + +if sys.version_info >= (3, 10): + from typing import TypeAlias # noqa +else: + from typing_extensions import TypeAlias # noqa + + +if sys.version_info >= (3, 11): + from typing import Self # noqa +else: + from typing_extensions import Self # noqa + _LOGGER = logging.getLogger("flexparser") _SENTINEL = object() +class HasherProtocol(ty.Protocol): + @property + def name(self) -> str: + ... + + def hexdigest(self) -> str: + ... 
+ + +class GenericInfo: + _specialized: Optional[ + dict[type, Optional[list[tuple[type, dict[ty.TypeVar, type]]]]] + ] = None + + @staticmethod + def _summarize(d: dict[ty.TypeVar, type]) -> dict[ty.TypeVar, type]: + d = d.copy() + while True: + for k, v in d.items(): + if isinstance(v, ty.TypeVar): + d[k] = d[v] + break + else: + return d + + del d[v] + + @classmethod + def _specialization(cls) -> dict[ty.TypeVar, type]: + if cls._specialized is None: + return dict() + + out: dict[ty.TypeVar, type] = {} + specialized = cls._specialized[cls] + + if specialized is None: + return {} + + for parent, content in specialized: + for tvar, typ in content.items(): + out[tvar] = typ + origin = getattr(parent, "__origin__", None) + if origin is not None and origin in cls._specialized: + out = {**origin._specialization(), **out} + + return out + + @classmethod + def specialization(cls) -> dict[ty.TypeVar, type]: + return GenericInfo._summarize(cls._specialization()) + + def __init_subclass__(cls) -> None: + if cls._specialized is None: + cls._specialized = {GenericInfo: None} + + tv: list[ty.TypeVar] = [] + entries: list[tuple[type, dict[ty.TypeVar, type]]] = [] + + for par in getattr(cls, "__parameters__", ()): + if isinstance(par, ty.TypeVar): + tv.append(par) + + for b in getattr(cls, "__orig_bases__", ()): + for k in cls._specialized.keys(): + if getattr(b, "__origin__", None) is k: + entries.append((b, {k: v for k, v in zip(tv, b.__args__)})) + break + + cls._specialized[cls] = entries + + return super().__init_subclass__() + + ################ # Exceptions ################ @@ -49,53 +133,66 @@ class Statement: """Base class for parsed elements within a source file.""" - start_line: int = dataclasses.field(init=False, default=None) - start_col: int = dataclasses.field(init=False, default=None) + is_position_set: bool = dataclasses.field(init=False, default=False, repr=False) + + start_line: int = dataclasses.field(init=False, default=0) + start_col: int = 
dataclasses.field(init=False, default=0) - end_line: int = dataclasses.field(init=False, default=None) - end_col: int = dataclasses.field(init=False, default=None) + end_line: int = dataclasses.field(init=False, default=0) + end_col: int = dataclasses.field(init=False, default=0) - raw: str = dataclasses.field(init=False, default=None) + raw: Optional[str] = dataclasses.field(init=False, default=None) @classmethod - def from_statement(cls, statement: Statement): + def from_statement(cls, statement: Statement) -> Self: out = cls() - out.set_position(*statement.get_position()) - out.set_raw(statement.raw) + if statement.is_position_set: + out.set_position(*statement.get_position()) + if statement.raw is not None: + out.set_raw(statement.raw) return out @classmethod - def from_statement_iterator_element(cls, values: ty.Tuple[int, int, int, int, str]): + def from_statement_iterator_element( + cls, values: tuple[int, int, int, int, str] + ) -> Self: out = cls() out.set_position(*values[:-1]) out.set_raw(values[-1]) return out @property - def format_position(self): - if self.start_line is None: + def format_position(self) -> str: + if not self.is_position_set: return "N/A" return "%d,%d-%d,%d" % self.get_position() @property - def raw_strip(self): + def raw_strip(self) -> Optional[str]: + if self.raw is None: + return None return self.raw.strip() - def get_position(self): - return self.start_line, self.start_col, self.end_line, self.end_col + def get_position(self) -> tuple[int, int, int, int]: + if self.is_position_set: + return self.start_line, self.start_col, self.end_line, self.end_col + return 0, 0, 0, 0 - def set_position(self, start_line, start_col, end_line, end_col): + def set_position( + self: Self, start_line: int, start_col: int, end_line: int, end_col: int + ) -> Self: + object.__setattr__(self, "is_position_set", True) object.__setattr__(self, "start_line", start_line) object.__setattr__(self, "start_col", start_col) object.__setattr__(self, "end_line", 
end_line) object.__setattr__(self, "end_col", end_col) return self - def set_raw(self, raw): + def set_raw(self: Self, raw: str) -> Self: object.__setattr__(self, "raw", raw) return self - def set_simple_position(self, line, col, width): + def set_simple_position(self: Self, line: int, col: int, width: int) -> Self: return self.set_position(line, col, line, col + width) @@ -103,7 +200,7 @@ def set_simple_position(self, line, col, width): class ParsingError(Statement, Exception): """Base class for all parsing exceptions in this package.""" - def __str__(self): + def __str__(self) -> str: return Statement.__str__(self) @@ -111,7 +208,7 @@ def __str__(self): class UnknownStatement(ParsingError): """A string statement could not bee parsed.""" - def __str__(self): + def __str__(self) -> str: return f"Could not parse '{self.raw}' ({self.format_position})" @@ -121,12 +218,12 @@ class UnhandledParsingError(ParsingError): ex: Exception - def __str__(self): + def __str__(self) -> str: return f"Unhandled exception while parsing '{self.raw}' ({self.format_position}): {self.ex}" @dataclass(frozen=True) -class UnexpectedEOF(ParsingError): +class UnexpectedEOS(ParsingError): """End of file was found within an open block.""" @@ -140,7 +237,7 @@ class Hash: algorithm_name: str hexdigest: str - def __eq__(self, other: Hash): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, Hash) and self.algorithm_name != "" @@ -149,22 +246,42 @@ def __eq__(self, other: Hash): ) @classmethod - def from_bytes(cls, algorithm, b: bytes): + def from_bytes( + cls, + algorithm: ty.Callable[ + [ + bytes, + ], + HasherProtocol, + ], + b: bytes, + ) -> Self: hasher = algorithm(b) return cls(hasher.name, hasher.hexdigest()) @classmethod - def from_file_pointer(cls, algorithm, fp: ty.BinaryIO): + def from_file_pointer( + cls, + algorithm: ty.Callable[ + [ + bytes, + ], + HasherProtocol, + ], + fp: ty.BinaryIO, + ) -> Self: return cls.from_bytes(algorithm, fp.read()) @classmethod - def 
nullhash(cls): + def nullhash(cls) -> Self: return cls("", "") def _yield_types( - obj, valid_subclasses=(object,), recurse_origin=(tuple, list, ty.Union) -): + obj: type, + valid_subclasses: tuple[type, ...] = (object,), + recurse_origin: tuple[Any, ...] = (tuple, list, Union), +) -> ty.Generator[type, None, None]: """Recursively transverse type annotation if the origin is any of the types in `recurse_origin` and yield those type which are subclasses of `valid_subclasses`. @@ -190,25 +307,11 @@ def myprop(self): """ - def __init__(self, fget): + def __init__(self, fget): # type: ignore self.fget = fget - def __get__(self, owner_self, owner_cls): - return self.fget(owner_cls) - - -def is_relative_to(self, *other): - """Return True if the path is relative to another path or False. - - In Python 3.9+ can be replaced by - - path.is_relative_to(other) - """ - try: - self.relative_to(*other) - return True - except ValueError: - return False + def __get__(self, owner_self, owner_cls): # type: ignore + return self.fget(owner_cls) # type: ignore class DelimiterInclude(enum.IntEnum): @@ -259,7 +362,7 @@ class DelimiterAction(enum.IntEnum): @functools.lru_cache -def _build_delimiter_pattern(delimiters: ty.Tuple[str, ...]) -> re.Pattern: +def _build_delimiter_pattern(delimiters: tuple[str, ...]) -> re.Pattern[str]: """Compile a tuple of delimiters into a regex expression with a capture group around the delimiter. """ @@ -270,13 +373,13 @@ def _build_delimiter_pattern(delimiters: ty.Tuple[str, ...]) -> re.Pattern: # Iterators ############ -DelimiterDictT = ty.Dict[str, ty.Tuple[DelimiterInclude, DelimiterAction]] +DelimiterDictT = dict[str, tuple[DelimiterInclude, DelimiterAction]] class Spliter: """Content iterator splitting according to given delimiters. 
- The pattern can be changed dynamically sending a new pattern to the generator, + The pattern can be changed dynamically sending a new pattern to the ty.Generator, see DelimiterInclude and DelimiterAction for more information. The current scanning position can be changed at any time. @@ -284,7 +387,7 @@ class Spliter: Parameters ---------- content : str - delimiters : ty.Dict[str, ty.Tuple[DelimiterInclude, DelimiterAction]] + delimiters : dict[str, tuple[DelimiterInclude, DelimiterAction]] Yields ------ @@ -300,26 +403,26 @@ class Spliter: part of the text between delimiters. """ - _pattern: ty.Optional[re.Pattern] + _pattern: Optional[re.Pattern[str]] _delimiters: DelimiterDictT - __stop_searching_in_line = False + __stop_searching_in_line: bool = False - __pending = "" - __first_line_col = None + __pending: str = "" + __first_line_col: Optional[tuple[int, int]] = None - __lines = () - __lineno = 0 - __colno = 0 + __lines: list[str] + __lineno: int = 0 + __colno: int = 0 def __init__(self, content: str, delimiters: DelimiterDictT): self.set_delimiters(delimiters) self.__lines = content.splitlines(keepends=True) - def set_position(self, lineno: int, colno: int): + def set_position(self, lineno: int, colno: int) -> None: self.__lineno, self.__colno = lineno, colno - def set_delimiters(self, delimiters: DelimiterDictT): + def set_delimiters(self, delimiters: DelimiterDictT) -> None: for k, v in delimiters.items(): if v == (DelimiterInclude.DO_NOT_SPLIT, DelimiterAction.STOP_PARSING): raise ValueError( @@ -334,10 +437,10 @@ def set_delimiters(self, delimiters: DelimiterDictT): # We add the end of line as delimiters if not present. 
self._delimiters = {**DO_NOT_SPLIT_EOL, **delimiters} - def __iter__(self): + def __iter__(self) -> Spliter: return self - def __next__(self): + def __next__(self) -> tuple[int, int, int, int, str]: if self.__lineno >= len(self.__lines): raise StopIteration @@ -378,23 +481,27 @@ def __next__(self): part = line[self.__colno : end_col] - include, action = self._delimiters.get( - dlm, (DelimiterInclude.SPLIT, DelimiterAction.STOP_PARSING) - ) + if dlm is None: + include, action = DelimiterInclude.SPLIT, DelimiterAction.STOP_PARSING + else: + include, action = self._delimiters[dlm] if include == DelimiterInclude.SPLIT: next_pending = "" - elif include == DelimiterInclude.SPLIT_AFTER: - end_col += len(dlm) - part = part + dlm - next_pending = "" - elif include == DelimiterInclude.SPLIT_BEFORE: - next_pending = dlm - elif include == DelimiterInclude.DO_NOT_SPLIT: - self.__pending += line[self.__colno : end_col] + dlm - next_pending = "" else: - raise ValueError(f"Unknown action {include}.") + # When dlm is None, DelimiterInclude.SPLIT + assert isinstance(dlm, str) + if include == DelimiterInclude.SPLIT_AFTER: + end_col += len(dlm) + part = part + dlm + next_pending = "" + elif include == DelimiterInclude.SPLIT_BEFORE: + next_pending = dlm + elif include == DelimiterInclude.DO_NOT_SPLIT: + self.__pending += line[self.__colno : end_col] + dlm + next_pending = "" + else: + raise ValueError(f"Unknown action {include}.") if action == DelimiterAction.STOP_PARSING: # this will raise a StopIteration in the next call. @@ -439,13 +546,13 @@ def __next__(self): class StatementIterator: """Content peekable iterator splitting according to given delimiters. - The pattern can be changed dynamically sending a new pattern to the generator, + The pattern can be changed dynamically sending a new pattern to the ty.Generator, see DelimiterInclude and DelimiterAction for more information. 
Parameters ---------- content : str - delimiters : dict[str, ty.Tuple[DelimiterInclude, DelimiterAction]] + delimiters : dict[str, tuple[DelimiterInclude, DelimiterAction]] Yields ------ @@ -464,7 +571,7 @@ def __init__( def __iter__(self): return self - def set_delimiters(self, delimiters: DelimiterDictT): + def set_delimiters(self, delimiters: DelimiterDictT) -> None: self._spliter.set_delimiters(delimiters) if self._cache: value = self.peek() @@ -485,7 +592,7 @@ def _get_next_strip(self) -> Statement: end_col -= lo - len(part) return Statement.from_statement_iterator_element( - (start_line + 1, start_col, end_line + 1, end_col, part) + (start_line + 1, start_col, end_line + 1, end_col, part) # type: ignore ) def _get_next(self) -> Statement: @@ -497,10 +604,10 @@ def _get_next(self) -> Statement: start_line, start_col, end_line, end_col, part = next(self._spliter) return Statement.from_statement_iterator_element( - (start_line + 1, start_col, end_line + 1, end_col, part) + (start_line + 1, start_col, end_line + 1, end_col, part) # type: ignore ) - def peek(self, default=_SENTINEL) -> Statement: + def peek(self, default: Any = _SENTINEL) -> Statement: """Return the item that will be next returned from ``next()``. Return ``default`` if there are no items left. 
If ``default`` is not @@ -519,8 +626,7 @@ def peek(self, default=_SENTINEL) -> Statement: def __next__(self) -> Statement: if self._cache: return self._cache.popleft() - else: - return self._get_next() + return self._get_next() ########### @@ -528,15 +634,41 @@ def __next__(self) -> Statement: ########### # Configuration type +T = ty.TypeVar("T") CT = ty.TypeVar("CT") -PST = ty.TypeVar("PST", bound="ParsedStatement") -LineColStr = Tuple[int, int, str] -FromString = ty.Union[None, PST, ParsingError] -Consume = ty.Union[PST, ParsingError] -NullableConsume = ty.Union[None, PST, ParsingError] +PST = ty.TypeVar("PST", bound="ParsedStatement[Any]") +LineColStr: TypeAlias = tuple[int, int, str] + +ParsedResult: TypeAlias = Union[T, ParsingError] +NullableParsedResult: TypeAlias = Union[T, ParsingError, None] + + +class ConsumeProtocol(ty.Protocol): + @property + def is_position_set(self) -> bool: + ... + + @property + def start_line(self) -> int: + ... + + @property + def start_col(self) -> int: + ... + + @property + def end_line(self) -> int: + ... -Single = ty.Union[PST, ParsingError] -Multi = ty.Tuple[ty.Union[PST, ParsingError], ...] + @property + def end_col(self) -> int: + ... + + @classmethod + def consume( + cls, statement_iterator: StatementIterator, config: Any + ) -> NullableParsedResult[Self]: + ... @dataclass(frozen=True) @@ -555,7 +687,7 @@ class ParsedStatement(ty.Generic[CT], Statement): """ @classmethod - def from_string(cls: Type[PST], s: str) -> FromString[PST]: + def from_string(cls, s: str) -> NullableParsedResult[Self]: """Parse a string into a ParsedStatement. Return files and their meaning: @@ -570,7 +702,7 @@ def from_string(cls: Type[PST], s: str) -> FromString[PST]: ) @classmethod - def from_string_and_config(cls: Type[PST], s: str, config: CT) -> FromString[PST]: + def from_string_and_config(cls, s: str, config: CT) -> NullableParsedResult[Self]: """Parse a string into a ParsedStatement. 
Return files and their meaning: @@ -583,10 +715,14 @@ def from_string_and_config(cls: Type[PST], s: str, config: CT) -> FromString[PST @classmethod def from_statement_and_config( - cls: Type[PST], statement: Statement, config: CT - ) -> FromString[PST]: + cls, statement: Statement, config: CT + ) -> NullableParsedResult[Self]: + raw = statement.raw + if raw is None: + return None + try: - out = cls.from_string_and_config(statement.raw, config) + out = cls.from_string_and_config(raw, config) except Exception as ex: out = UnhandledParsingError(ex) @@ -594,13 +730,13 @@ def from_statement_and_config( return None out.set_position(*statement.get_position()) - out.set_raw(statement.raw) + out.set_raw(raw) return out @classmethod def consume( - cls: Type[PST], statement_iterator: StatementIterator, config: CT - ) -> NullableConsume[PST]: + cls, statement_iterator: StatementIterator, config: CT + ) -> NullableParsedResult[Self]: """Peek into the iterator and try to parse. Return files and their meaning: @@ -617,64 +753,61 @@ def consume( return parsed_statement -OPST = ty.TypeVar("OPST", bound="ParsedStatement") -IPST = ty.TypeVar("IPST", bound="ParsedStatement") -CPST = ty.TypeVar("CPST", bound="ParsedStatement") -BT = ty.TypeVar("BT", bound="Block") -RBT = ty.TypeVar("RBT", bound="RootBlock") +OPST = ty.TypeVar("OPST", bound="ParsedStatement[Any]") +BPST = ty.TypeVar( + "BPST", bound="Union[ParsedStatement[Any], Block[Any, Any, Any, Any]]" +) +CPST = ty.TypeVar("CPST", bound="ParsedStatement[Any]") +RBT = ty.TypeVar("RBT", bound="RootBlock[Any, Any]") @dataclass(frozen=True) -class Block(ty.Generic[OPST, IPST, CPST, CT]): +class Block(ty.Generic[OPST, BPST, CPST, CT], GenericInfo): """A sequence of statements with an opening, body and closing.""" - opening: Consume[OPST] - body: Tuple[Consume[IPST], ...] - closing: Consume[CPST] + opening: ParsedResult[OPST] + body: tuple[ParsedResult[BPST], ...] 
+ closing: Union[ParsedResult[CPST], EOS[CT]] - delimiters = {} + delimiters: DelimiterDictT = dataclasses.field(default_factory=dict, init=False) + + def is_closed(self) -> bool: + return not isinstance(self.closing, EOS) @property - def start_line(self): + def is_position_set(self) -> bool: + return self.opening.is_position_set + + @property + def start_line(self) -> int: return self.opening.start_line @property - def start_col(self): + def start_col(self) -> int: return self.opening.start_col @property - def end_line(self): + def end_line(self) -> int: return self.closing.end_line @property - def end_col(self): + def end_col(self) -> int: return self.closing.end_col - def get_position(self): + def get_position(self) -> tuple[int, int, int, int]: return self.start_line, self.start_col, self.end_line, self.end_col @property - def format_position(self): - if self.start_line is None: + def format_position(self) -> str: + if not self.is_position_set: return "N/A" return "%d,%d-%d,%d" % self.get_position() - @classmethod - def subclass_with(cls, *, opening=None, body=None, closing=None): - @dataclass(frozen=True) - class CustomBlock(Block): - pass - - if opening: - CustomBlock.__annotations__["opening"] = Single[ty.Union[opening]] - if body: - CustomBlock.__annotations__["body"] = Multi[ty.Union[body]] - if closing: - CustomBlock.__annotations__["closing"] = Single[ty.Union[closing]] - - return CustomBlock - - def __iter__(self) -> Iterator[Statement]: + def __iter__( + self, + ) -> ty.Generator[ + ParsedResult[Union[OPST, BPST, Union[CPST, EOS[CT]]]], None, None + ]: yield self.opening for el in self.body: if isinstance(el, Block): @@ -683,7 +816,10 @@ def __iter__(self) -> Iterator[Statement]: yield el yield self.closing - def iter_blocks(self) -> Iterator[ty.Union[Block, Statement]]: + def iter_blocks( + self, + ) -> ty.Generator[ParsedResult[Union[OPST, BPST, CPST]], None, None]: + # raise RuntimeError("Is this used?") yield self.opening yield from self.body yield 
self.closing @@ -694,12 +830,14 @@ def iter_blocks(self) -> Iterator[ty.Union[Block, Statement]]: _ElementT = ty.TypeVar("_ElementT", bound=Statement) - def filter_by(self, *klass: Type[_ElementT]) -> Iterator[_ElementT]: + def filter_by( + self, klass1: type[_ElementT], *klass: type[_ElementT] + ) -> ty.Generator[_ElementT, None, None]: """Yield elements of a given class or classes.""" - yield from (el for el in self if isinstance(el, klass)) # noqa Bug in pycharm. + yield from (el for el in self if isinstance(el, (klass1,) + klass)) # type: ignore[misc] @cached_property - def errors(self) -> ty.Tuple[ParsingError, ...]: + def errors(self) -> tuple[ParsingError, ...]: """Tuple of errors found.""" return tuple(self.filter_by(ParsingError)) @@ -712,37 +850,46 @@ def has_errors(self) -> bool: # Statement classes #################### - @classproperty - def opening_classes(cls) -> Iterator[Type[OPST]]: + @classmethod + def opening_classes(cls) -> ty.Generator[type[OPST], None, None]: """Classes representing any of the parsed statement that can open this block.""" - opening = ty.get_type_hints(cls)["opening"] - yield from _yield_types(opening, ParsedStatement) + try: + opening = cls.specialization()[OPST] # type: ignore[misc] + except KeyError: + opening: type = ty.get_type_hints(cls)["opening"] # type: ignore[no-redef] + yield from _yield_types(opening, ParsedStatement) # type: ignore - @classproperty - def body_classes(cls) -> Iterator[Type[IPST]]: + @classmethod + def body_classes(cls) -> ty.Generator[type[BPST], None, None]: """Classes representing any of the parsed statement that can be in the body.""" - body = ty.get_type_hints(cls)["body"] - yield from _yield_types(body, (ParsedStatement, Block)) + try: + body = cls.specialization()[BPST] # type: ignore[misc] + except KeyError: + body: type = ty.get_type_hints(cls)["body"] # type: ignore[no-redef] + yield from _yield_types(body, (ParsedStatement, Block)) # type: ignore - @classproperty - def closing_classes(cls) 
-> Iterator[Type[CPST]]: + @classmethod + def closing_classes(cls) -> ty.Generator[type[CPST], None, None]: """Classes representing any of the parsed statement that can close this block.""" - closing = ty.get_type_hints(cls)["closing"] - yield from _yield_types(closing, ParsedStatement) + try: + closing = cls.specialization()[CPST] # type: ignore[misc] + except KeyError: + closing: type = ty.get_type_hints(cls)["closing"] # type: ignore[no-redef] + yield from _yield_types(closing, ParsedStatement) # type: ignore ########## - # Consume + # ParsedResult ########## @classmethod def consume_opening( - cls: Type[BT], statement_iterator: StatementIterator, config: CT - ) -> NullableConsume[OPST]: + cls, statement_iterator: StatementIterator, config: CT + ) -> NullableParsedResult[OPST]: """Peek into the iterator and try to parse with any of the opening classes. See `ParsedStatement.consume` for more details. """ - for c in cls.opening_classes: + for c in cls.opening_classes(): el = c.consume(statement_iterator, config) if el is not None: return el @@ -751,27 +898,27 @@ def consume_opening( @classmethod def consume_body( cls, statement_iterator: StatementIterator, config: CT - ) -> Consume[IPST]: + ) -> ParsedResult[BPST]: """Peek into the iterator and try to parse with any of the body classes. If the statement cannot be parsed, a UnknownStatement is returned. """ - for c in cls.body_classes: + for c in cls.body_classes(): el = c.consume(statement_iterator, config) if el is not None: return el - el = next(statement_iterator) - return UnknownStatement.from_statement(el) + unkel = next(statement_iterator) + return UnknownStatement.from_statement(unkel) @classmethod def consume_closing( - cls: Type[BT], statement_iterator: StatementIterator, config: CT - ) -> NullableConsume[CPST]: + cls, statement_iterator: StatementIterator, config: CT + ) -> NullableParsedResult[CPST]: """Peek into the iterator and try to parse with any of the opening classes. 
See `ParsedStatement.consume` for more details. """ - for c in cls.closing_classes: + for c in cls.closing_classes(): el = c.consume(statement_iterator, config) if el is not None: return el @@ -779,10 +926,10 @@ def consume_closing( @classmethod def consume_body_closing( - cls: Type[BT], opening: OPST, statement_iterator: StatementIterator, config: CT - ) -> BT: - body = [] - closing = None + cls, opening: OPST, statement_iterator: StatementIterator, config: CT + ) -> Self: + body: list[ParsedResult[BPST]] = [] + closing: ty.Union[CPST, ParsingError, None] = None last_line = opening.end_line while closing is None: try: @@ -793,15 +940,16 @@ def consume_body_closing( body.append(el) last_line = el.end_line except StopIteration: - closing = cls.on_stop_iteration(config) - closing.set_position(last_line + 1, 0, last_line + 1, 0) + unexpected_end = cls.on_stop_iteration(config) + unexpected_end.set_position(last_line + 1, 0, last_line + 1, 0) + return cls(opening, tuple(body), unexpected_end) return cls(opening, tuple(body), closing) @classmethod def consume( - cls: Type[BT], statement_iterator: StatementIterator, config: CT - ) -> Optional[BT]: + cls, statement_iterator: StatementIterator, config: CT + ) -> Union[Self, None]: """Try consume the block. 
Possible outcomes: @@ -812,22 +960,25 @@ def consume( if opening is None: return None + if isinstance(opening, ParsingError): + return None + return cls.consume_body_closing(opening, statement_iterator, config) @classmethod - def on_stop_iteration(cls, config): - return UnexpectedEOF() + def on_stop_iteration(cls, config: CT) -> ParsedResult[EOS[CT]]: + return UnexpectedEOS() @dataclass(frozen=True) -class BOS(ParsedStatement[CT]): +class BOS(ty.Generic[CT], ParsedStatement[CT]): """Beginning of source.""" # Hasher algorithm name and hexdigest content_hash: Hash @classmethod - def from_string_and_config(cls: Type[PST], s: str, config: CT) -> FromString[PST]: + def from_string_and_config(cls, s: str, config: CT) -> NullableParsedResult[Self]: raise RuntimeError("BOS cannot be constructed from_string_and_config") @property @@ -836,7 +987,7 @@ def location(self) -> SourceLocationT: @dataclass(frozen=True) -class BOF(BOS): +class BOF(ty.Generic[CT], BOS[CT]): """Beginning of file.""" path: pathlib.Path @@ -850,7 +1001,7 @@ def location(self) -> SourceLocationT: @dataclass(frozen=True) -class BOR(BOS): +class BOR(ty.Generic[CT], BOS[CT]): """Beginning of resource.""" package: str @@ -862,43 +1013,29 @@ def location(self) -> SourceLocationT: @dataclass(frozen=True) -class EOS(ParsedStatement[CT]): +class EOS(ty.Generic[CT], ParsedStatement[CT]): """End of sequence.""" @classmethod - def from_string_and_config(cls: Type[PST], s: str, config: CT) -> FromString[PST]: + def from_string_and_config( + cls: type[PST], s: str, config: CT + ) -> NullableParsedResult[PST]: return cls() -class RootBlock(ty.Generic[IPST, CT], Block[BOS, IPST, EOS, CT]): +class RootBlock(ty.Generic[BPST, CT], Block[BOS[CT], BPST, EOS[CT], CT]): """A sequence of statement flanked by the beginning and ending of stream.""" - opening: Single[BOS] - closing: Single[EOS] - - @classmethod - def subclass_with(cls, *, body=None): - @dataclass(frozen=True) - class CustomRootBlock(RootBlock): - pass - - if 
body: - CustomRootBlock.__annotations__["body"] = Multi[ty.Union[body]] - - return CustomRootBlock - @classmethod def consume_opening( - cls: Type[RBT], statement_iterator: StatementIterator, config: CT - ) -> NullableConsume[BOS]: + cls, statement_iterator: StatementIterator, config: CT + ) -> NullableParsedResult[BOS[CT]]: raise RuntimeError( "Implementation error, 'RootBlock.consume_opening' should never be called" ) @classmethod - def consume( - cls: Type[RBT], statement_iterator: StatementIterator, config: CT - ) -> RBT: + def consume(cls, statement_iterator: StatementIterator, config: CT) -> Self: block = super().consume(statement_iterator, config) if block is None: raise RuntimeError( @@ -908,41 +1045,42 @@ def consume( @classmethod def consume_closing( - cls: Type[RBT], statement_iterator: StatementIterator, config: CT - ) -> NullableConsume[EOS]: + cls, statement_iterator: StatementIterator, config: CT + ) -> NullableParsedResult[EOS[CT]]: return None @classmethod - def on_stop_iteration(cls, config): - return EOS() + def on_stop_iteration(cls, config: CT) -> ParsedResult[EOS[CT]]: + return EOS[CT]() ################# # Source parsing ################# -ResourceT = ty.Tuple[str, str] # package name, resource name -StrictLocationT = ty.Union[pathlib.Path, ResourceT] -SourceLocationT = ty.Union[str, StrictLocationT] +ResourceT: TypeAlias = tuple[str, str] # package name, resource name +StrictLocationT: TypeAlias = Union[pathlib.Path, ResourceT] +SourceLocationT: TypeAlias = Union[str, StrictLocationT] @dataclass(frozen=True) class ParsedSource(ty.Generic[RBT, CT]): - parsed_source: RBT # Parser configuration. 
config: CT @property - def location(self) -> StrictLocationT: + def location(self) -> SourceLocationT: + if isinstance(self.parsed_source.opening, ParsingError): + raise self.parsed_source.opening return self.parsed_source.opening.location @cached_property def has_errors(self) -> bool: return self.parsed_source.has_errors - def errors(self): + def errors(self) -> ty.Generator[ParsingError, None, None]: yield from self.parsed_source.errors @@ -956,22 +1094,19 @@ class CannotParseResourceAsFile(Exception): resource_name: str -class Parser(ty.Generic[RBT, CT]): +class Parser(ty.Generic[RBT, CT], GenericInfo): """Parser class.""" #: class to iterate through statements in a source unit. - _statement_iterator_class: Type[StatementIterator] = StatementIterator + _statement_iterator_class: type[StatementIterator] = StatementIterator #: Delimiters. _delimiters: DelimiterDictT = SPLIT_EOL _strip_spaces: bool = True - #: root block class containing statements and blocks can be parsed. - _root_block_class: Type[RBT] - #: source file text encoding. - _encoding = "utf-8" + _encoding: str = "utf-8" #: configuration passed to from_string functions. _config: CT @@ -980,12 +1115,25 @@ class Parser(ty.Generic[RBT, CT]): _prefer_resource_as_file: bool #: parser algorithm to us. 
Must be a callable member of hashlib - _hasher = hashlib.blake2b - - def __init__(self, config: CT, prefer_resource_as_file=True): + _hasher: ty.Callable[ + [ + bytes, + ], + HasherProtocol, + ] = hashlib.blake2b + + def __init__(self, config: CT, prefer_resource_as_file: bool = True): self._config = config self._prefer_resource_as_file = prefer_resource_as_file + @classmethod + def root_boot_class(cls) -> type[RBT]: + """Class representing the root block class.""" + try: + return cls.specialization()[RBT] # type: ignore[misc] + except KeyError: + return ty.get_type_hints(cls)["root_boot_class"] # type: ignore[no-redef] + def parse(self, source_location: SourceLocationT) -> ParsedSource[RBT, CT]: """Parse a file into a ParsedSourceFile or ParsedResource. @@ -1016,15 +1164,17 @@ def parse(self, source_location: SourceLocationT) -> ParsedSource[RBT, CT]: "for a resource." ) - def parse_bytes(self, b: bytes, bos: BOS = None) -> ParsedSource[RBT, CT]: + def parse_bytes( + self, b: bytes, bos: Optional[BOS[CT]] = None + ) -> ParsedSource[RBT, CT]: if bos is None: - bos = BOS(Hash.from_bytes(self._hasher, b)).set_simple_position(0, 0, 0) + bos = BOS[CT](Hash.from_bytes(self._hasher, b)).set_simple_position(0, 0, 0) sic = self._statement_iterator_class( b.decode(self._encoding), self._delimiters, self._strip_spaces ) - parsed = self._root_block_class.consume_body_closing(bos, sic, self._config) + parsed = self.root_boot_class().consume_body_closing(bos, sic, self._config) return ParsedSource( parsed, @@ -1042,7 +1192,7 @@ def parse_file(self, path: pathlib.Path) -> ParsedSource[RBT, CT]: with path.open(mode="rb") as fi: content = fi.read() - bos = BOF( + bos = BOF[CT]( Hash.from_bytes(self._hasher, content), path, path.stat().st_mtime ).set_simple_position(0, 0, 0) return self.parse_bytes(content, bos) @@ -1059,15 +1209,8 @@ def parse_resource_from_file( resource_name name of the resource """ - if sys.version_info < (3, 9): - # Remove when Python 3.8 is dropped - with 
resources.path(package, resource_name) as p: - path = p.resolve() - else: - with resources.as_file( - resources.files(package).joinpath(resource_name) - ) as p: - path = p.resolve() + with resources.as_file(resources.files(package).joinpath(resource_name)) as p: + path = p.resolve() if path.exists(): return self.parse_file(path) @@ -1084,15 +1227,10 @@ def parse_resource(self, package: str, resource_name: str) -> ParsedSource[RBT, resource_name name of the resource """ - if sys.version_info < (3, 9): - # Remove when Python 3.8 is dropped - with resources.open_binary(package, resource_name) as fi: - content = fi.read() - else: - with resources.files(package).joinpath(resource_name).open("rb") as fi: - content = fi.read() + with resources.files(package).joinpath(resource_name).open("rb") as fi: + content = fi.read() - bos = BOR( + bos = BOR[CT]( Hash.from_bytes(self._hasher, content), package, resource_name ).set_simple_position(0, 0, 0) @@ -1104,7 +1242,7 @@ def parse_resource(self, package: str, resource_name: str) -> ParsedSource[RBT, ########## -class IncludeStatement(ParsedStatement): +class IncludeStatement(ty.Generic[CT], ParsedStatement[CT]): """ "Include statements allow to merge files.""" @property @@ -1115,10 +1253,11 @@ def target(self) -> str: class ParsedProject( - ty.Dict[ - ty.Optional[ty.Tuple[StrictLocationT, str]], - ParsedSource, - ] + ty.Generic[RBT, CT], + dict[ + Optional[tuple[StrictLocationT, str]], + ParsedSource[RBT, CT], + ], ): """Collection of files, independent or connected via IncludeStatement. 
@@ -1132,11 +1271,16 @@ class ParsedProject( def has_errors(self) -> bool: return any(el.has_errors for el in self.values()) - def errors(self): + def errors(self) -> ty.Generator[ParsingError, None, None]: for el in self.values(): yield from el.errors() - def _iter_statements(self, items, seen, include_only_once): + def _iter_statements( + self, + items: ty.Iterable[tuple[Any, Any]], + seen: set[Any], + include_only_once: bool, + ) -> ty.Generator[ParsedStatement[CT], None, None]: """Iter all definitions in the order they appear, going into the included files. """ @@ -1153,7 +1297,9 @@ def _iter_statements(self, items, seen, include_only_once): else: yield parsed_statement - def iter_statements(self, include_only_once=True): + def iter_statements( + self, include_only_once: bool = True + ) -> ty.Generator[ParsedStatement[CT], None, None]: """Iter all definitions in the order they appear, going into the included files. @@ -1164,7 +1310,12 @@ def iter_statements(self, include_only_once=True): """ yield from self._iter_statements([(None, self[None])], set(), include_only_once) - def _iter_blocks(self, items, seen, include_only_once): + def _iter_blocks( + self, + items: ty.Iterable[tuple[Any, Any]], + seen: set[Any], + include_only_once: bool, + ) -> ty.Generator[ParsedStatement[CT], None, None]: """Iter all definitions in the order they appear, going into the included files. """ @@ -1181,7 +1332,9 @@ def _iter_blocks(self, items, seen, include_only_once): else: yield parsed_statement - def iter_blocks(self, include_only_once=True): + def iter_blocks( + self, include_only_once: bool = True + ) -> ty.Generator[ParsedStatement[CT], None, None]: """Iter all definitions in the order they appear, going into the included files. 
@@ -1211,7 +1364,7 @@ def default_locator(source_location: StrictLocationT, target: str) -> StrictLoca ) tmp = (current_path / target_path).resolve() - if not is_relative_to(tmp, current_path): + if not tmp.is_relative_to(current_path): raise ValueError( f"Cannot refer to locations above the current location ({source_location}, {target})" ) @@ -1229,27 +1382,90 @@ def default_locator(source_location: StrictLocationT, target: str) -> StrictLoca ) -DefinitionT = ty.Union[ty.Type[Block], ty.Type[ParsedStatement]] +@no_type_check +def _build_root_block_class_parsed_statement( + spec: type[ParsedStatement[CT]], config: type[CT] +) -> type[RootBlock[ParsedStatement[CT], CT]]: + """Build root block class from a single ParsedStatement.""" + + @dataclass(frozen=True) + class CustomRootBlockA(RootBlock[spec, config]): # type: ignore + pass + + return CustomRootBlockA + + +@no_type_check +def _build_root_block_class_block( + spec: type[Block[OPST, BPST, CPST, CT]], + config: type[CT], +) -> type[RootBlock[Block[OPST, BPST, CPST, CT], CT]]: + """Build root block class from a single ParsedStatement.""" -SpecT = ty.Union[ - ty.Type[Parser], - DefinitionT, - ty.Iterable[DefinitionT], - ty.Type[RootBlock], -] + @dataclass(frozen=True) + class CustomRootBlockA(RootBlock[spec, config]): # type: ignore + pass + return CustomRootBlockA -def build_parser_class(spec: SpecT, *, strip_spaces: bool = True, delimiters=None): + +@no_type_check +def _build_root_block_class_parsed_statement_it( + spec: tuple[type[Union[ParsedStatement[CT], Block[OPST, BPST, CPST, CT]]]], + config: type[CT], +) -> type[RootBlock[ParsedStatement[CT], CT]]: + """Build root block class from iterable ParsedStatement.""" + + @dataclass(frozen=True) + class CustomRootBlockA(RootBlock[Union[spec], config]): # type: ignore + pass + + return CustomRootBlockA + + +@no_type_check +def _build_parser_class_root_block( + spec: type[RootBlock[BPST, CT]], + *, + strip_spaces: bool = True, + delimiters: Optional[DelimiterDictT] 
= None, +) -> type[Parser[RootBlock[BPST, CT], CT]]: + class CustomParser(Parser[spec, spec.specialization()[CT]]): # type: ignore + _delimiters: DelimiterDictT = delimiters or SPLIT_EOL + _strip_spaces: bool = strip_spaces + + return CustomParser + + +@no_type_check +def build_parser_class( + spec: Union[ + type[ + Union[ + Parser[RBT, CT], + RootBlock[BPST, CT], + Block[OPST, BPST, CPST, CT], + ParsedStatement[CT], + ] + ], + ty.Iterable[type[ParsedStatement[CT]]], + ], + config: CT = None, + strip_spaces: bool = True, + delimiters: Optional[DelimiterDictT] = None, +) -> type[ + Union[ + Parser[RBT, CT], + Parser[RootBlock[BPST, CT], CT], + Parser[RootBlock[Block[OPST, BPST, CPST, CT], CT], CT], + ] +]: """Build a custom parser class. Parameters ---------- spec - specification of the content to parse. Can be one of the following things: - - Parser class. - - Block or ParsedStatement derived class. - - Iterable of Block or ParsedStatement derived class. - - RootBlock derived class. + RootBlock derived class. strip_spaces : bool if True, spaces will be stripped for each statement before calling ``from_string_and_config``. @@ -1267,65 +1483,71 @@ def build_parser_class(spec: SpecT, *, strip_spaces: bool = True, delimiters=Non encountering this delimiter. 
""" - if delimiters is None: - delimiters = SPLIT_EOL - - if isinstance(spec, type) and issubclass(spec, Parser): - CustomParser = spec - else: - if isinstance(spec, (tuple, list)): - - for el in spec: - if not issubclass(el, (Block, ParsedStatement)): - raise TypeError( - "Elements in root_block_class must be of type Block or ParsedStatement, " - f"not {el}" - ) - - @dataclass(frozen=True) - class CustomRootBlock(RootBlock): - pass - - CustomRootBlock.__annotations__["body"] = Multi[ty.Union[spec]] + if isinstance(spec, type): + if issubclass(spec, Parser): + CustomParser = spec - elif isinstance(spec, type) and issubclass(spec, RootBlock): - - CustomRootBlock = spec - - elif isinstance(spec, type) and issubclass(spec, (Block, ParsedStatement)): + elif issubclass(spec, RootBlock): + CustomParser = _build_parser_class_root_block( + spec, strip_spaces=strip_spaces, delimiters=delimiters + ) - @dataclass(frozen=True) - class CustomRootBlock(RootBlock): - pass + elif issubclass(spec, Block): + CustomRootBlock = _build_root_block_class_block(spec, config.__class__) + CustomParser = _build_parser_class_root_block( + CustomRootBlock, strip_spaces=strip_spaces, delimiters=delimiters + ) - CustomRootBlock.__annotations__["body"] = Multi[spec] + elif issubclass(spec, ParsedStatement): + CustomRootBlock = _build_root_block_class_parsed_statement( + spec, config.__class__ + ) + CustomParser = _build_parser_class_root_block( + CustomRootBlock, strip_spaces=strip_spaces, delimiters=delimiters + ) else: raise TypeError( - "`spec` must be of type RootBlock or tuple of type Block or ParsedStatement, " + "`spec` must be of type Parser, Block, RootBlock or tuple of type Block or ParsedStatement, " f"not {type(spec)}" ) - class CustomParser(Parser): + elif isinstance(spec, (tuple, list)): + CustomRootBlock = _build_root_block_class_parsed_statement_it( + spec, config.__class__ + ) + CustomParser = _build_parser_class_root_block( + CustomRootBlock, strip_spaces=strip_spaces, 
delimiters=delimiters + ) - _delimiters = delimiters - _root_block_class = CustomRootBlock - _strip_spaces = strip_spaces + else: + raise return CustomParser +@no_type_check def parse( entry_point: SourceLocationT, - spec: SpecT, - config=None, + spec: Union[ + type[ + Union[ + Parser[RBT, CT], + RootBlock[BPST, CT], + Block[OPST, BPST, CPST, CT], + ParsedStatement[CT], + ] + ], + ty.Iterable[type[ParsedStatement[CT]]], + ], + config: CT = None, *, strip_spaces: bool = True, - delimiters=None, - locator: ty.Callable[[StrictLocationT, str], StrictLocationT] = default_locator, + delimiters: Optional[DelimiterDictT] = None, + locator: ty.Callable[[SourceLocationT, str], StrictLocationT] = default_locator, prefer_resource_as_file: bool = True, - **extra_parser_kwargs, -) -> ParsedProject: + **extra_parser_kwargs: Any, +) -> Union[ParsedProject[RBT, CT], ParsedProject[RootBlock[BPST, CT], CT]]: """Parse sources into a ParsedProject dictionary. Parameters @@ -1336,7 +1558,7 @@ def parse( specification of the content to parse. Can be one of the following things: - Parser class. - Block or ParsedStatement derived class. - - Iterable of Block or ParsedStatement derived class. + - ty.Iterable of Block or ParsedStatement derived class. - RootBlock derived class. config a configuration object that will be passed to `from_string_and_config` @@ -1366,17 +1588,14 @@ def parse( encountering this delimiter. """ - CustomParser = build_parser_class( - spec, strip_spaces=strip_spaces, delimiters=delimiters - ) + CustomParser = build_parser_class(spec, config, strip_spaces, delimiters) parser = CustomParser( config, prefer_resource_as_file=prefer_resource_as_file, **extra_parser_kwargs ) pp = ParsedProject() - # : ty.List[Optional[ty.Union[LocatorT, str]], ...] 
- pending: ty.List[ty.Tuple[StrictLocationT, str]] = [] + pending: list[tuple[SourceLocationT, str]] = [] if isinstance(entry_point, (str, pathlib.Path)): entry_point = pathlib.Path(entry_point) if not entry_point.is_absolute(): @@ -1409,15 +1628,28 @@ def parse( return pp +@no_type_check def parse_bytes( content: bytes, - spec: SpecT, - config=None, + spec: Union[ + type[ + Union[ + Parser[RBT, CT], + RootBlock[BPST, CT], + Block[OPST, BPST, CPST, CT], + ParsedStatement[CT], + ] + ], + ty.Iterable[type[ParsedStatement[CT]]], + ], + config: Optional[CT] = None, *, - strip_spaces: bool = True, - delimiters=None, - **extra_parser_kwargs, -) -> ParsedProject: + strip_spaces: bool, + delimiters: Optional[DelimiterDictT], + **extra_parser_kwargs: Any, +) -> ParsedProject[ + Union[RBT, RootBlock[BPST, CT], RootBlock[ParsedStatement[CT], CT]], CT +]: """Parse sources into a ParsedProject dictionary. Parameters @@ -1428,7 +1660,7 @@ def parse_bytes( specification of the content to parse. Can be one of the following things: - Parser class. - Block or ParsedStatement derived class. - - Iterable of Block or ParsedStatement derived class. + - ty.Iterable of Block or ParsedStatement derived class. - RootBlock derived class. config a configuration object that will be passed to `from_string_and_config` @@ -1440,9 +1672,8 @@ def parse_bytes( Specify how the source file is split into statements (See below). 
""" - CustomParser = build_parser_class( - spec, strip_spaces=strip_spaces, delimiters=delimiters - ) + CustomParser = build_parser_class(spec, config, strip_spaces, delimiters) + parser = CustomParser(config, prefer_resource_as_file=False, **extra_parser_kwargs) pp = ParsedProject() diff --git a/pint/converters.py b/pint/converters.py index daf25bc88..249cbbf89 100644 --- a/pint/converters.py +++ b/pint/converters.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from dataclasses import fields as dc_fields -from typing import Any, Optional +from typing import Any, Optional, ClassVar from ._typing import Magnitude @@ -24,10 +24,8 @@ class Converter: """Base class for value converters.""" - # list[type[Converter]] - _subclasses = [] - # dict[frozenset[str], type[Converter]] - _param_names_to_subclass = {} + _subclasses: ClassVar[list[type[Converter]]] = [] + _param_names_to_subclass: ClassVar[dict[frozenset[str], type[Converter]]] = {} @property def is_multiplicative(self) -> bool: diff --git a/pint/delegates/__init__.py b/pint/delegates/__init__.py index b2eb9a3ef..e663a10c5 100644 --- a/pint/delegates/__init__.py +++ b/pint/delegates/__init__.py @@ -10,5 +10,6 @@ from . import txt_defparser from .base_defparser import ParserConfig, build_disk_cache_class +from .formatter import Formatter -__all__ = ["txt_defparser", "ParserConfig", "build_disk_cache_class"] +__all__ = ["txt_defparser", "ParserConfig", "build_disk_cache_class", "Formatter"] diff --git a/pint/delegates/formatter/__init__.py b/pint/delegates/formatter/__init__.py new file mode 100644 index 000000000..c30f3657b --- /dev/null +++ b/pint/delegates/formatter/__init__.py @@ -0,0 +1,21 @@ +""" + pint.delegates.formatter + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Formats quantities and units. + :copyright: 2022 by Pint Authors, see AUTHORS for more details. + :license: BSD, see LICENSE for more details. 
+""" + + +from .base_formatter import BaseFormatter + + +class Formatter(BaseFormatter): + # TODO: this should derive from all relevant formaters to + # reproduce the current behavior of Pint. + pass + + +__all__ = [ + "Formatter", +] diff --git a/pint/delegates/formatter/base_formatter.py b/pint/delegates/formatter/base_formatter.py new file mode 100644 index 000000000..6f9df55bb --- /dev/null +++ b/pint/delegates/formatter/base_formatter.py @@ -0,0 +1,27 @@ +""" + pint.delegates.formatter.base_formatter + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Common class and function for all formatters. + :copyright: 2022 by Pint Authors, see AUTHORS for more details. + :license: BSD, see LICENSE for more details. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...facets.plain import PlainQuantity, PlainUnit, MagnitudeT + + +class BaseFormatter: + def format_quantity( + self, quantity: PlainQuantity[MagnitudeT], spec: str = "" + ) -> str: + # TODO Fill the proper functions + return str(quantity.magnitude) + " " + self.format_unit(quantity.units, spec) + + def format_unit(self, unit: PlainUnit, spec: str = "") -> str: + # TODO Fill the proper functions and discuss + # how to make it that _units is not accessible directly + return " ".join(k if v == 1 else f"{k} ** {v}" for k, v in unit._units.items()) diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py index a5ccb08ee..e89863d00 100644 --- a/pint/delegates/txt_defparser/defparser.py +++ b/pint/delegates/txt_defparser/defparser.py @@ -139,6 +139,8 @@ def parse_file( _PintParser, cfg or self._default_config, diskcache=self._diskcache, + strip_spaces=True, + delimiters=_PintParser._delimiters, ) def parse_string(self, content: str, cfg: Optional[ParserConfig] = None): @@ -147,4 +149,6 @@ def parse_string(self, content: str, cfg: Optional[ParserConfig] = None): _PintParser, cfg or self._default_config, 
diskcache=self._diskcache, + strip_spaces=True, + delimiters=_PintParser._delimiters, ) diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py index 85682d198..3bfb3fd25 100644 --- a/pint/facets/context/registry.py +++ b/pint/facets/context/registry.py @@ -35,6 +35,7 @@ def __init__(self, registry_cache) -> None: self.root_units = {} self.dimensionality = registry_cache.dimensionality self.parse_unit = registry_cache.parse_unit + self.conversion_factor = {} class GenericContextRegistry( diff --git a/pint/facets/nonmultiplicative/registry.py b/pint/facets/nonmultiplicative/registry.py index 7d783de11..67250ea48 100644 --- a/pint/facets/nonmultiplicative/registry.py +++ b/pint/facets/nonmultiplicative/registry.py @@ -57,7 +57,7 @@ def __init__( # plain units on multiplication and division. self.autoconvert_offset_to_baseunit = autoconvert_offset_to_baseunit - def _parse_units( + def parse_units_as_container( self, input_string: str, as_delta: Optional[bool] = None, @@ -67,7 +67,7 @@ def _parse_units( if as_delta is None: as_delta = self.default_as_delta - return super()._parse_units(input_string, as_delta, case_sensitive) + return super().parse_units_as_container(input_string, as_delta, case_sensitive) def _add_unit(self, definition: UnitDefinition) -> None: super()._add_unit(definition) diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py index 7c31de0c3..57dc5123d 100644 --- a/pint/facets/numpy/numpy_func.py +++ b/pint/facets/numpy/numpy_func.py @@ -741,23 +741,23 @@ def _base_unit_if_needed(a): @implements("trapz", "function") -def _trapz(a, x=None, dx=1.0, **kwargs): - a = _base_unit_if_needed(a) - units = a.units +def _trapz(y, x=None, dx=1.0, **kwargs): + y = _base_unit_if_needed(y) + units = y.units if x is not None: if hasattr(x, "units"): x = _base_unit_if_needed(x) units *= x.units x = x._magnitude - ret = np.trapz(a._magnitude, x, **kwargs) + ret = np.trapz(y._magnitude, x, **kwargs) else: if hasattr(dx, 
"units"): dx = _base_unit_if_needed(dx) units *= dx.units dx = dx._magnitude - ret = np.trapz(a._magnitude, dx=dx, **kwargs) + ret = np.trapz(y._magnitude, dx=dx, **kwargs) - return a.units._REGISTRY.Quantity(ret, units) + return y.units._REGISTRY.Quantity(ret, units) def implement_mul_func(func): diff --git a/pint/facets/numpy/quantity.py b/pint/facets/numpy/quantity.py index 5257766bc..08d7adf9f 100644 --- a/pint/facets/numpy/quantity.py +++ b/pint/facets/numpy/quantity.py @@ -16,7 +16,7 @@ from ..plain import PlainQuantity, MagnitudeT from ..._typing import Shape -from ...compat import _to_magnitude, np +from ...compat import _to_magnitude, np, HAS_NUMPY from ...errors import DimensionalityError, PintTypeError, UnitStrippedWarning from .numpy_func import ( HANDLED_UFUNCS, @@ -115,11 +115,12 @@ def _numpy_method_wrap(self, func, *args, **kwargs): return value def __array__(self, t=None) -> np.ndarray: - warnings.warn( - "The unit of the quantity is stripped when downcasting to ndarray.", - UnitStrippedWarning, - stacklevel=2, - ) + if HAS_NUMPY and isinstance(self._magnitude, np.ndarray): + warnings.warn( + "The unit of the quantity is stripped when downcasting to ndarray.", + UnitStrippedWarning, + stacklevel=2, + ) return _to_magnitude(self._magnitude, force_ndarray=True) def clip(self, min=None, max=None, out=None, **kwargs): @@ -266,7 +267,7 @@ def __setitem__(self, key, value): isinstance(self._magnitude, np.ma.MaskedArray) and np.ma.is_masked(value) and getattr(value, "size", 0) == 1 - ) or math.isnan(value): + ) or (getattr(value, "ndim", 0) == 0 and math.isnan(value)): self._magnitude[key] = value return except TypeError: diff --git a/pint/facets/plain/qto.py b/pint/facets/plain/qto.py index 9cd8a780a..726523763 100644 --- a/pint/facets/plain/qto.py +++ b/pint/facets/plain/qto.py @@ -100,7 +100,9 @@ def to_compact( """ - if not isinstance(quantity.magnitude, numbers.Number): + if not isinstance(quantity.magnitude, numbers.Number) and not hasattr( + 
quantity.magnitude, "nominal_value" + ): msg = "to_compact applied to non numerical types " "has an undefined behavior." w = RuntimeWarning(msg) warnings.warn(w, stacklevel=2) @@ -137,6 +139,9 @@ def to_compact( q_base = quantity.to(unit) magnitude = q_base.magnitude + # Support uncertainties + if hasattr(magnitude, "nominal_value"): + magnitude = magnitude.nominal_value units = list(q_base._units.items()) units_numerator = [a for a in units if a[1] > 0] diff --git a/pint/facets/plain/quantity.py b/pint/facets/plain/quantity.py index 4115175cf..2bcd40d9b 100644 --- a/pint/facets/plain/quantity.py +++ b/pint/facets/plain/quantity.py @@ -263,7 +263,7 @@ def __deepcopy__(self, memo) -> PlainQuantity[MagnitudeT]: return ret def __str__(self) -> str: - return str(self.magnitude) + " " + str(self.units) + return self._REGISTRY.formatter.format_quantity(self) def __bytes__(self) -> bytes: return str(self).encode(locale.getpreferredencoding()) diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py index c9c7d94d2..9e796fed9 100644 --- a/pint/facets/plain/registry.py +++ b/pint/facets/plain/registry.py @@ -33,7 +33,6 @@ from collections import defaultdict from decimal import Decimal from fractions import Fraction -from numbers import Number from token import NAME, NUMBER from tokenize import TokenInfo @@ -132,6 +131,10 @@ def __init__(self) -> None: #: Cache the unit name associated to user input. 
('mV' -> 'millivolt') self.parse_unit: dict[str, UnitsContainer] = {} + self.conversion_factor: dict[ + tuple[UnitsContainer, UnitsContainer], Scalar | DimensionalityError + ] = {} + def __eq__(self, other: Any): if not isinstance(other, self.__class__): return False @@ -140,6 +143,7 @@ def __eq__(self, other: Any): "root_units", "dimensionality", "parse_unit", + "conversion_factor", ) return all(getattr(self, attr) == getattr(other, attr) for attr in attrs) @@ -149,14 +153,14 @@ class RegistryMeta(type): instead of asking the developer to do it when subclassing. """ - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any): obj = super().__call__(*args, **kwargs) obj._after_init() return obj # Generic types used to mark types associated to Registries. -QuantityT = TypeVar("QuantityT", bound=PlainQuantity) +QuantityT = TypeVar("QuantityT", bound=PlainQuantity[Any]) UnitT = TypeVar("UnitT", bound=PlainUnit) @@ -251,6 +255,7 @@ def __init__( delegates.ParserConfig(non_int_type), diskcache=self._diskcache ) + self.formatter = delegates.Formatter() self._filename = filename self.force_ndarray = force_ndarray self.force_ndarray_like = force_ndarray_like @@ -739,7 +744,9 @@ def _get_dimensionality_recurse( if reg.reference is not None: self._get_dimensionality_recurse(reg.reference, exp2, accumulator) - def _get_dimensionality_ratio(self, unit1: UnitLike, unit2: UnitLike): + def _get_dimensionality_ratio( + self, unit1: UnitLike, unit2: UnitLike + ) -> Scalar | None: """Get the exponential ratio between two units, i.e. solve unit2 = unit1**x for x. Parameters @@ -773,7 +780,7 @@ def _get_dimensionality_ratio(self, unit1: UnitLike, unit2: UnitLike): def get_root_units( self, input_units: UnitLike, check_nonmult: bool = True - ) -> tuple[Number, UnitT]: + ) -> tuple[Scalar, UnitT]: """Convert unit or dict of units to the root units. 
If any unit is non multiplicative and check_converter is True, @@ -800,6 +807,43 @@ def get_root_units( return f, self.Unit(units) + def _get_conversion_factor( + self, src: UnitsContainer, dst: UnitsContainer + ) -> Scalar | DimensionalityError: + """Get conversion factor in non-multiplicative units. + + Parameters + ---------- + src + Source units + dst + Target units + + Returns + ------- + Conversion factor or DimensionalityError + """ + cache = self._cache.conversion_factor + try: + return cache[(src, dst)] + except KeyError: + pass + + src_dim = self._get_dimensionality(src) + dst_dim = self._get_dimensionality(dst) + + # If the source and destination dimensionality are different, + # then the conversion cannot be performed. + if src_dim != dst_dim: + return DimensionalityError(src, dst, src_dim, dst_dim) + + # Here src and dst have only multiplicative units left. Thus we can + # convert with a factor. + factor, _ = self._get_root_units(src / dst) + + cache[(src, dst)] = factor + return factor + def _get_root_units( self, input_units: UnitsContainer, check_nonmult: bool = True ) -> tuple[Scalar, UnitsContainer]: @@ -854,7 +898,7 @@ def get_base_units( input_units: Union[UnitsContainer, str], check_nonmult: bool = True, system=None, - ) -> tuple[Number, UnitT]: + ) -> tuple[Scalar, UnitT]: """Convert unit or dict of units to the plain units. If any unit is non multiplicative and check_converter is True, @@ -1014,18 +1058,10 @@ def _convert( """ - if check_dimensionality: - src_dim = self._get_dimensionality(src) - dst_dim = self._get_dimensionality(dst) + factor = self._get_conversion_factor(src, dst) - # If the source and destination dimensionality are different, - # then the conversion cannot be performed. - if src_dim != dst_dim: - raise DimensionalityError(src, dst, src_dim, dst_dim) - - # Here src and dst have only multiplicative units left. Thus we can - # convert with a factor. 
- factor, _ = self._get_root_units(src / dst) + if isinstance(factor, DimensionalityError): + raise factor # factor is type float and if our magnitude is type Decimal then # must first convert to Decimal before we can '*' the values @@ -1062,17 +1098,19 @@ def parse_unit_name( tuple of tuples (str, str, str) all non-equivalent combinations of (prefix, unit name, suffix) """ + + case_sensitive = ( + self.case_sensitive if case_sensitive is None else case_sensitive + ) return self._dedup_candidates( - self._parse_unit_name(unit_name, case_sensitive=case_sensitive) + self._yield_unit_triplets(unit_name, case_sensitive) ) - def _parse_unit_name( - self, unit_name: str, case_sensitive: Optional[bool] = None + def _yield_unit_triplets( + self, unit_name: str, case_sensitive: bool ) -> Generator[tuple[str, str, str], None, None]: """Helper of parse_unit_name.""" - case_sensitive = ( - self.case_sensitive if case_sensitive is None else case_sensitive - ) + stw = unit_name.startswith edw = unit_name.endswith for suffix, prefix in itertools.product(self._suffixes, self._prefixes): @@ -1097,6 +1135,9 @@ def _parse_unit_name( self._suffixes[suffix], ) + # TODO: keep this for backward compatibility + _parse_unit_name = _yield_unit_triplets + @staticmethod def _dedup_candidates( candidates: Iterable[tuple[str, str, str]] @@ -1145,14 +1186,29 @@ def parse_units( """ - units = self._parse_units(input_string, as_delta, case_sensitive) - return self.Unit(units) + return self.Unit( + self.parse_units_as_container(input_string, as_delta, case_sensitive) + ) - def _parse_units( + def parse_units_as_container( self, input_string: str, - as_delta: bool = True, + as_delta: Optional[bool] = None, case_sensitive: Optional[bool] = None, + ) -> UnitsContainer: + as_delta = ( + as_delta if as_delta is not None else True + ) # TODO This only exists in nonmultiplicative + case_sensitive = ( + case_sensitive if case_sensitive is not None else self.case_sensitive + ) + return 
self._parse_units_as_container(input_string, as_delta, case_sensitive) + + def _parse_units_as_container( + self, + input_string: str, + as_delta: bool = True, + case_sensitive: bool = True, ) -> UnitsContainer: """Parse a units expression and returns a UnitContainer with the canonical names. diff --git a/pint/facets/plain/unit.py b/pint/facets/plain/unit.py index 4c5c04ac3..227c97b1b 100644 --- a/pint/facets/plain/unit.py +++ b/pint/facets/plain/unit.py @@ -59,7 +59,7 @@ def __deepcopy__(self, memo) -> PlainUnit: return ret def __str__(self) -> str: - return " ".join(k if v == 1 else f"{k} ** {v}" for k, v in self._units.items()) + return self._REGISTRY.formatter.format_unit(self) def __bytes__(self) -> bytes: return str(self).encode(locale.getpreferredencoding()) diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py index 6b2f0e0b6..37c539e35 100644 --- a/pint/registry_helpers.py +++ b/pint/registry_helpers.py @@ -11,7 +11,7 @@ from __future__ import annotations import functools -from inspect import signature +from inspect import signature, Parameter from itertools import zip_longest from typing import TYPE_CHECKING, Callable, TypeVar, Any, Union, Optional from collections.abc import Iterable @@ -119,8 +119,13 @@ def _parse_wrap_args(args, registry=None): "Not all variable referenced in %s are defined using !" 
% args[ndx] ) - def _converter(ureg, values, strict): - new_values = list(value for value in values) + def _converter(ureg, sig, values, kw, strict): + len_initial_values = len(values) + + # pack kwargs + for i, param_name in enumerate(sig.parameters): + if i >= len_initial_values: + values.append(kw[param_name]) values_by_name = {} @@ -128,13 +133,13 @@ def _converter(ureg, values, strict): for ndx in defs_args_ndx: value = values[ndx] values_by_name[args_as_uc[ndx][0]] = value - new_values[ndx] = getattr(value, "_magnitude", value) + values[ndx] = getattr(value, "_magnitude", value) # second pass: calculate derived values based on named values for ndx in dependent_args_ndx: value = values[ndx] assert _replace_units(args_as_uc[ndx][0], values_by_name) is not None - new_values[ndx] = ureg._convert( + values[ndx] = ureg._convert( getattr(value, "_magnitude", value), getattr(value, "_units", UnitsContainer({})), _replace_units(args_as_uc[ndx][0], values_by_name), @@ -143,7 +148,7 @@ def _converter(ureg, values, strict): # third pass: convert other arguments for ndx in unit_args_ndx: if isinstance(values[ndx], ureg.Quantity): - new_values[ndx] = ureg._convert( + values[ndx] = ureg._convert( values[ndx]._magnitude, values[ndx]._units, args_as_uc[ndx][0] ) else: @@ -151,7 +156,7 @@ def _converter(ureg, values, strict): if isinstance(values[ndx], str): # if the value is a string, we try to parse it tmp_value = ureg.parse_expression(values[ndx]) - new_values[ndx] = ureg._convert( + values[ndx] = ureg._convert( tmp_value._magnitude, tmp_value._units, args_as_uc[ndx][0] ) else: @@ -159,29 +164,35 @@ def _converter(ureg, values, strict): "A wrapped function using strict=True requires " "quantity or a string for all arguments with not None units. 
" "(error found for {}, {})".format( - args_as_uc[ndx][0], new_values[ndx] + args_as_uc[ndx][0], values[ndx] ) ) - return new_values, values_by_name + # unpack kwargs + for i, param_name in enumerate(sig.parameters): + if i >= len_initial_values: + kw[param_name] = values[i] + + return values[:len_initial_values], kw, values_by_name return _converter -def _apply_defaults(func, args, kwargs): +def _apply_defaults(sig, args, kwargs): """Apply default keyword arguments. Named keywords may have been left blank. This function applies the default values so that every argument is defined. """ - sig = signature(func) - bound_arguments = sig.bind(*args, **kwargs) - for param in sig.parameters.values(): - if param.name not in bound_arguments.arguments: - bound_arguments.arguments[param.name] = param.default - args = [bound_arguments.arguments[key] for key in sig.parameters.keys()] - return args, {} + for i, param in enumerate(sig.parameters.values()): + if ( + i >= len(args) + and param.default != Parameter.empty + and param.name not in kwargs + ): + kwargs[param.name] = param.default + return list(args), kwargs def wraps( @@ -254,7 +265,8 @@ def wraps( ret = _to_units_container(ret, ureg) def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]: - count_params = len(signature(func).parameters) + sig = signature(func) + count_params = len(sig.parameters) if len(args) != count_params: raise TypeError( "%s takes %i parameters, but %i units were passed" @@ -270,13 +282,15 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]: @functools.wraps(func, assigned=assigned, updated=updated) def wrapper(*values, **kw) -> Quantity: - values, kw = _apply_defaults(func, values, kw) + values, kw = _apply_defaults(sig, values, kw) # In principle, the values are used as is # When then extract the magnitudes when needed. 
- new_values, values_by_name = converter(ureg, values, strict) + new_values, new_kw, values_by_name = converter( + ureg, sig, values, kw, strict + ) - result = func(*new_values, **kw) + result = func(*new_values, **new_kw) if is_ret_container: out_units = ( @@ -335,7 +349,8 @@ def check( ] def decorator(func): - count_params = len(signature(func).parameters) + sig = signature(func) + count_params = len(sig.parameters) if len(dimensions) != count_params: raise TypeError( "%s takes %i parameters, but %i dimensions were passed" @@ -351,7 +366,11 @@ def decorator(func): @functools.wraps(func, assigned=assigned, updated=updated) def wrapper(*args, **kwargs): - list_args, empty = _apply_defaults(func, args, kwargs) + list_args, kw = _apply_defaults(sig, args, kwargs) + + for i, param_name in enumerate(sig.parameters): + if i >= len(args): + list_args.append(kw[param_name]) for dim, value in zip(dimensions, list_args): if dim is None: diff --git a/pint/testsuite/benchmarks/test_20_quantity.py b/pint/testsuite/benchmarks/test_20_quantity.py index 36c0f92ba..1ec7cbb60 100644 --- a/pint/testsuite/benchmarks/test_20_quantity.py +++ b/pint/testsuite/benchmarks/test_20_quantity.py @@ -53,3 +53,39 @@ def test_op2(benchmark, setup, keys, op): _, data = setup key1, key2 = keys benchmark(op, data[key1], data[key2]) + + +@pytest.mark.parametrize("key", ALL_VALUES_Q) +def test_wrapper(benchmark, setup, key): + ureg, data = setup + value, unit = key.split("_") + + @ureg.wraps(None, (unit,)) + def f(a): + pass + + benchmark(f, data[key]) + + +@pytest.mark.parametrize("key", ALL_VALUES_Q) +def test_wrapper_nonstrict(benchmark, setup, key): + ureg, data = setup + value, unit = key.split("_") + + @ureg.wraps(None, (unit,), strict=False) + def f(a): + pass + + benchmark(f, data[value]) + + +@pytest.mark.parametrize("key", ALL_VALUES_Q) +def test_wrapper_ret(benchmark, setup, key): + ureg, data = setup + value, unit = key.split("_") + + @ureg.wraps(unit, (unit,)) + def f(a): + return a + + 
benchmark(f, data[key]) diff --git a/pint/testsuite/helpers.py b/pint/testsuite/helpers.py index 191f4c3f5..4121e09eb 100644 --- a/pint/testsuite/helpers.py +++ b/pint/testsuite/helpers.py @@ -36,6 +36,10 @@ _unit_re = re.compile(r"") +def internal(ureg): + return ureg + + class PintOutputChecker(doctest.OutputChecker): def check_output(self, want, got, optionflags): check = super().check_output(want, got, optionflags) diff --git a/pint/testsuite/test_contexts.py b/pint/testsuite/test_contexts.py index ea6525d16..1a5bab237 100644 --- a/pint/testsuite/test_contexts.py +++ b/pint/testsuite/test_contexts.py @@ -17,7 +17,10 @@ from pint.util import UnitsContainer -def add_ctxs(ureg): +from .helpers import internal + + +def add_ctxs(ureg: UnitRegistry): a, b = UnitsContainer({"[length]": 1}), UnitsContainer({"[time]": -1}) d = Context("lc") d.add_transformation(a, b, lambda ureg, x: ureg.speed_of_light / x) @@ -33,7 +36,7 @@ def add_ctxs(ureg): ureg.add_context(d) -def add_arg_ctxs(ureg): +def add_arg_ctxs(ureg: UnitRegistry): a, b = UnitsContainer({"[length]": 1}), UnitsContainer({"[time]": -1}) d = Context("lc") d.add_transformation(a, b, lambda ureg, x, n: ureg.speed_of_light / x / n) @@ -49,7 +52,7 @@ def add_arg_ctxs(ureg): ureg.add_context(d) -def add_argdef_ctxs(ureg): +def add_argdef_ctxs(ureg: UnitRegistry): a, b = UnitsContainer({"[length]": 1}), UnitsContainer({"[time]": -1}) d = Context("lc", defaults=dict(n=1)) assert d.defaults == dict(n=1) @@ -67,7 +70,7 @@ def add_argdef_ctxs(ureg): ureg.add_context(d) -def add_sharedargdef_ctxs(ureg): +def add_sharedargdef_ctxs(ureg: UnitRegistry): a, b = UnitsContainer({"[length]": 1}), UnitsContainer({"[time]": -1}) d = Context("lc", defaults=dict(n=1)) assert d.defaults == dict(n=1) @@ -90,37 +93,37 @@ def test_known_context(self, func_registry): ureg = func_registry add_ctxs(ureg) with ureg.context("lc"): - assert ureg._active_ctx - assert ureg._active_ctx.graph + assert internal(ureg)._active_ctx + assert 
internal(ureg)._active_ctx.graph - assert not ureg._active_ctx - assert not ureg._active_ctx.graph + assert not internal(ureg)._active_ctx + assert not internal(ureg)._active_ctx.graph with ureg.context("lc", n=1): - assert ureg._active_ctx - assert ureg._active_ctx.graph + assert internal(ureg)._active_ctx + assert internal(ureg)._active_ctx.graph - assert not ureg._active_ctx - assert not ureg._active_ctx.graph + assert not internal(ureg)._active_ctx + assert not internal(ureg)._active_ctx.graph def test_known_context_enable(self, func_registry): ureg = func_registry add_ctxs(ureg) ureg.enable_contexts("lc") - assert ureg._active_ctx - assert ureg._active_ctx.graph + assert internal(ureg)._active_ctx + assert internal(ureg)._active_ctx.graph ureg.disable_contexts(1) - assert not ureg._active_ctx - assert not ureg._active_ctx.graph + assert not internal(ureg)._active_ctx + assert not internal(ureg)._active_ctx.graph ureg.enable_contexts("lc", n=1) - assert ureg._active_ctx - assert ureg._active_ctx.graph + assert internal(ureg)._active_ctx + assert internal(ureg)._active_ctx.graph ureg.disable_contexts(1) - assert not ureg._active_ctx - assert not ureg._active_ctx.graph + assert not internal(ureg)._active_ctx + assert not internal(ureg)._active_ctx.graph def test_graph(self, func_registry): ureg = func_registry @@ -139,27 +142,27 @@ def test_graph(self, func_registry): g.update({l: {t, c}, t: {l}, c: {l}}) with ureg.context("lc"): - assert ureg._active_ctx.graph == g_sp + assert internal(ureg)._active_ctx.graph == g_sp with ureg.context("lc", n=1): - assert ureg._active_ctx.graph == g_sp + assert internal(ureg)._active_ctx.graph == g_sp with ureg.context("ab"): - assert ureg._active_ctx.graph == g_ab + assert internal(ureg)._active_ctx.graph == g_ab with ureg.context("lc"): with ureg.context("ab"): - assert ureg._active_ctx.graph == g + assert internal(ureg)._active_ctx.graph == g with ureg.context("ab"): with ureg.context("lc"): - assert ureg._active_ctx.graph == 
g + assert internal(ureg)._active_ctx.graph == g with ureg.context("lc", "ab"): - assert ureg._active_ctx.graph == g + assert internal(ureg)._active_ctx.graph == g with ureg.context("ab", "lc"): - assert ureg._active_ctx.graph == g + assert internal(ureg)._active_ctx.graph == g def test_graph_enable(self, func_registry): ureg = func_registry @@ -178,33 +181,33 @@ def test_graph_enable(self, func_registry): g.update({l: {t, c}, t: {l}, c: {l}}) ureg.enable_contexts("lc") - assert ureg._active_ctx.graph == g_sp + assert internal(ureg)._active_ctx.graph == g_sp ureg.disable_contexts(1) ureg.enable_contexts("lc", n=1) - assert ureg._active_ctx.graph == g_sp + assert internal(ureg)._active_ctx.graph == g_sp ureg.disable_contexts(1) ureg.enable_contexts("ab") - assert ureg._active_ctx.graph == g_ab + assert internal(ureg)._active_ctx.graph == g_ab ureg.disable_contexts(1) ureg.enable_contexts("lc") ureg.enable_contexts("ab") - assert ureg._active_ctx.graph == g + assert internal(ureg)._active_ctx.graph == g ureg.disable_contexts(2) ureg.enable_contexts("ab") ureg.enable_contexts("lc") - assert ureg._active_ctx.graph == g + assert internal(ureg)._active_ctx.graph == g ureg.disable_contexts(2) ureg.enable_contexts("lc", "ab") - assert ureg._active_ctx.graph == g + assert internal(ureg)._active_ctx.graph == g ureg.disable_contexts(2) ureg.enable_contexts("ab", "lc") - assert ureg._active_ctx.graph == g + assert internal(ureg)._active_ctx.graph == g ureg.disable_contexts(2) def test_known_nested_context(self, func_registry): @@ -212,22 +215,22 @@ def test_known_nested_context(self, func_registry): add_ctxs(ureg) with ureg.context("lc"): - x = dict(ureg._active_ctx) - y = dict(ureg._active_ctx.graph) - assert ureg._active_ctx - assert ureg._active_ctx.graph + x = dict(internal(ureg)._active_ctx) + y = dict(internal(ureg)._active_ctx.graph) + assert internal(ureg)._active_ctx + assert internal(ureg)._active_ctx.graph with ureg.context("ab"): - assert ureg._active_ctx - assert 
ureg._active_ctx.graph - assert x != ureg._active_ctx - assert y != ureg._active_ctx.graph + assert internal(ureg)._active_ctx + assert internal(ureg)._active_ctx.graph + assert x != internal(ureg)._active_ctx + assert y != internal(ureg)._active_ctx.graph - assert x == ureg._active_ctx - assert y == ureg._active_ctx.graph + assert x == internal(ureg)._active_ctx + assert y == internal(ureg)._active_ctx.graph - assert not ureg._active_ctx - assert not ureg._active_ctx.graph + assert not internal(ureg)._active_ctx + assert not internal(ureg)._active_ctx.graph def test_unknown_context(self, func_registry): ureg = func_registry @@ -235,25 +238,25 @@ def test_unknown_context(self, func_registry): with pytest.raises(KeyError): with ureg.context("la"): pass - assert not ureg._active_ctx - assert not ureg._active_ctx.graph + assert not internal(ureg)._active_ctx + assert not internal(ureg)._active_ctx.graph def test_unknown_nested_context(self, func_registry): ureg = func_registry add_ctxs(ureg) with ureg.context("lc"): - x = dict(ureg._active_ctx) - y = dict(ureg._active_ctx.graph) + x = dict(internal(ureg)._active_ctx) + y = dict(internal(ureg)._active_ctx.graph) with pytest.raises(KeyError): with ureg.context("la"): pass - assert x == ureg._active_ctx - assert y == ureg._active_ctx.graph + assert x == internal(ureg)._active_ctx + assert y == internal(ureg)._active_ctx.graph - assert not ureg._active_ctx - assert not ureg._active_ctx.graph + assert not internal(ureg)._active_ctx + assert not internal(ureg)._active_ctx.graph def test_one_context(self, func_registry): ureg = func_registry @@ -498,21 +501,21 @@ def _test_ctx(self, ctx, ureg): q = 500 * ureg.meter s = (ureg.speed_of_light / q).to("Hz") - nctx = len(ureg._contexts) + nctx = len(internal(ureg)._contexts) - assert ctx.name not in ureg._contexts + assert ctx.name not in internal(ureg)._contexts ureg.add_context(ctx) - assert ctx.name in ureg._contexts - assert len(ureg._contexts) == nctx + 1 + len(ctx.aliases) 
+ assert ctx.name in internal(ureg)._contexts + assert len(internal(ureg)._contexts) == nctx + 1 + len(ctx.aliases) with ureg.context(ctx.name): assert q.to("Hz") == s assert s.to("meter") == q ureg.remove_context(ctx.name) - assert ctx.name not in ureg._contexts - assert len(ureg._contexts) == nctx + assert ctx.name not in internal(ureg)._contexts + assert len(internal(ureg)._contexts) == nctx @pytest.mark.parametrize( "badrow", @@ -661,11 +664,11 @@ def test_defined(self, class_registry): b = Context.__keytransform__( UnitsContainer({"[length]": 1.0}), UnitsContainer({"[time]": -1.0}) ) - assert a in ureg._contexts["sp"].funcs - assert b in ureg._contexts["sp"].funcs + assert a in internal(ureg)._contexts["sp"].funcs + assert b in internal(ureg)._contexts["sp"].funcs with ureg.context("sp"): - assert a in ureg._active_ctx - assert b in ureg._active_ctx + assert a in internal(ureg)._active_ctx + assert b in internal(ureg)._active_ctx def test_spectroscopy(self, class_registry): ureg = class_registry @@ -681,7 +684,7 @@ def test_spectroscopy(self, class_registry): da, db = Context.__keytransform__( a.dimensionality, b.dimensionality ) - p = find_shortest_path(ureg._active_ctx.graph, da, db) + p = find_shortest_path(internal(ureg)._active_ctx.graph, da, db) assert p msg = f"{a} <-> {b}" # assertAlmostEqualRelError converts second to first @@ -703,7 +706,7 @@ def test_textile(self, class_registry): a = qty_direct.to_base_units() b = qty_indirect.to_base_units() da, db = Context.__keytransform__(a.dimensionality, b.dimensionality) - p = find_shortest_path(ureg._active_ctx.graph, da, db) + p = find_shortest_path(internal(ureg)._active_ctx.graph, da, db) assert p msg = f"{a} <-> {b}" helpers.assert_quantity_almost_equal(b, a, rtol=0.01, msg=msg) diff --git a/pint/testsuite/test_diskcache.py b/pint/testsuite/test_diskcache.py index 399f9f765..060d3f56c 100644 --- a/pint/testsuite/test_diskcache.py +++ b/pint/testsuite/test_diskcache.py @@ -11,13 +11,16 @@ FS_SLEEP = 
0.010 +from .helpers import internal + + @pytest.fixture def float_cache_filename(tmp_path): ureg = pint.UnitRegistry(cache_folder=tmp_path / "cache_with_float") - assert ureg._diskcache - assert ureg._diskcache.cache_folder + assert internal(ureg)._diskcache + assert internal(ureg)._diskcache.cache_folder - return tuple(ureg._diskcache.cache_folder.glob("*.pickle")) + return tuple(internal(ureg)._diskcache.cache_folder.glob("*.pickle")) def test_must_be_three_files(float_cache_filename): @@ -30,7 +33,7 @@ def test_must_be_three_files(float_cache_filename): def test_no_cache(): ureg = pint.UnitRegistry(cache_folder=None) - assert ureg._diskcache is None + assert internal(ureg)._diskcache is None assert ureg.cache_folder is None @@ -38,11 +41,11 @@ def test_decimal(tmp_path, float_cache_filename): ureg = pint.UnitRegistry( cache_folder=tmp_path / "cache_with_decimal", non_int_type=decimal.Decimal ) - assert ureg._diskcache - assert ureg._diskcache.cache_folder == tmp_path / "cache_with_decimal" + assert internal(ureg)._diskcache + assert internal(ureg)._diskcache.cache_folder == tmp_path / "cache_with_decimal" assert ureg.cache_folder == tmp_path / "cache_with_decimal" - files = tuple(ureg._diskcache.cache_folder.glob("*.pickle")) + files = tuple(internal(ureg)._diskcache.cache_folder.glob("*.pickle")) assert len(files) == 3 # check that the filenames with decimal are different to the ones with float @@ -66,9 +69,11 @@ def test_auto(float_cache_filename): float_filenames = tuple(p.name for p in float_cache_filename) ureg = pint.UnitRegistry(cache_folder=":auto:") - assert ureg._diskcache - assert ureg._diskcache.cache_folder - auto_files = tuple(p.name for p in ureg._diskcache.cache_folder.glob("*.pickle")) + assert internal(ureg)._diskcache + assert internal(ureg)._diskcache.cache_folder + auto_files = tuple( + p.name for p in internal(ureg)._diskcache.cache_folder.glob("*.pickle") + ) for file in float_filenames: assert file in auto_files @@ -82,7 +87,7 @@ def 
test_change_file(tmp_path): # (this will create two cache files, one for the file another for RegistryCache) ureg = pint.UnitRegistry(dfile, cache_folder=tmp_path) assert ureg.x == 1234 - files = tuple(ureg._diskcache.cache_folder.glob("*.pickle")) + files = tuple(internal(ureg)._diskcache.cache_folder.glob("*.pickle")) assert len(files) == 2 # Modify the definition file @@ -93,5 +98,5 @@ def test_change_file(tmp_path): # Verify that the definiton file was loaded (the cache was invalidated). ureg = pint.UnitRegistry(dfile, cache_folder=tmp_path) assert ureg.x == 1235 - files = tuple(ureg._diskcache.cache_folder.glob("*.pickle")) + files = tuple(internal(ureg)._diskcache.cache_folder.glob("*.pickle")) assert len(files) == 4 diff --git a/pint/testsuite/test_infer_base_unit.py b/pint/testsuite/test_infer_base_unit.py index 9a273622c..b40e5d6e2 100644 --- a/pint/testsuite/test_infer_base_unit.py +++ b/pint/testsuite/test_infer_base_unit.py @@ -3,107 +3,131 @@ import pytest -from pint import Quantity as Q from pint import UnitRegistry from pint.testsuite import helpers from pint.util import infer_base_unit -class TestInferBaseUnit: - def test_infer_base_unit(self): - from pint.util import infer_base_unit +def test_infer_base_unit(sess_registry): + test_units = sess_registry.Quantity(1, "meter**2").units + registry = sess_registry - test_units = Q(1, "meter**2").units - registry = Q(1, "meter**2")._REGISTRY + assert ( + infer_base_unit(sess_registry.Quantity(1, "millimeter * nanometer")) + == test_units + ) - assert infer_base_unit(Q(1, "millimeter * nanometer")) == test_units + assert infer_base_unit("millimeter * nanometer", registry) == test_units - assert infer_base_unit("millimeter * nanometer", registry) == test_units - - assert ( - infer_base_unit(Q(1, "millimeter * nanometer").units, registry) - == test_units - ) - - with pytest.raises(ValueError, match=r"No registry provided."): - infer_base_unit("millimeter") - - def test_infer_base_unit_decimal(self): - from 
pint.util import infer_base_unit - - ureg = UnitRegistry(non_int_type=Decimal) - QD = ureg.Quantity - - ibu_d = infer_base_unit(QD(Decimal(1), "millimeter * nanometer")) - - assert ibu_d == QD(Decimal(1), "meter**2").units - - assert all(isinstance(v, Decimal) for v in ibu_d.values()) - - def test_infer_base_unit_fraction(self): - from pint.util import infer_base_unit - - ureg = UnitRegistry(non_int_type=Fraction) - QD = ureg.Quantity - - ibu_d = infer_base_unit(QD(Fraction("1"), "millimeter * nanometer")) - - assert ibu_d == QD(Fraction("1"), "meter**2").units - - assert all(isinstance(v, Fraction) for v in ibu_d.values()) - - def test_units_adding_to_zero(self): - assert infer_base_unit(Q(1, "m * mm / m / um * s")) == Q(1, "s").units - - def test_to_compact(self): - r = Q(1000000000, "m") * Q(1, "mm") / Q(1, "s") / Q(1, "ms") - compact_r = r.to_compact() - expected = Q(1000.0, "kilometer**2 / second**2") - helpers.assert_quantity_almost_equal(compact_r, expected) - - r = (Q(1, "m") * Q(1, "mm") / Q(1, "m") / Q(2, "um") * Q(2, "s")).to_compact() - helpers.assert_quantity_almost_equal(r, Q(1000, "s")) - - def test_to_compact_decimal(self): - ureg = UnitRegistry(non_int_type=Decimal) - Q = ureg.Quantity - r = ( - Q(Decimal("1000000000.0"), "m") - * Q(Decimal(1), "mm") - / Q(Decimal(1), "s") - / Q(Decimal(1), "ms") - ) - compact_r = r.to_compact() - expected = Q(Decimal("1000.0"), "kilometer**2 / second**2") - assert compact_r == expected - - r = ( - Q(Decimal(1), "m") * Q(1, "mm") / Q(1, "m**2") / Q(2, "um") * Q(2, "s") - ).to_compact() - assert r == Q(1000, "s/m") - - def test_to_compact_fraction(self): - ureg = UnitRegistry(non_int_type=Fraction) - Q = ureg.Quantity - r = ( - Q(Fraction("10000000000/10"), "m") - * Q(Fraction("1"), "mm") - / Q(Fraction("1"), "s") - / Q(Fraction("1"), "ms") + assert ( + infer_base_unit( + sess_registry.Quantity(1, "millimeter * nanometer").units, registry ) - compact_r = r.to_compact() - expected = Q(Fraction("1000.0"), 
"kilometer**2 / second**2") - assert compact_r == expected - - r = ( - Q(Fraction(1), "m") * Q(1, "mm") / Q(1, "m**2") / Q(2, "um") * Q(2, "s") - ).to_compact() - assert r == Q(1000, "s/m") - - def test_volts(self): - from pint.util import infer_base_unit - - r = Q(1, "V") * Q(1, "mV") / Q(1, "kV") - b = infer_base_unit(r) - assert b == Q(1, "V").units - helpers.assert_quantity_almost_equal(r, Q(1, "uV")) + == test_units + ) + + with pytest.raises(ValueError, match=r"No registry provided."): + infer_base_unit("millimeter") + + +def test_infer_base_unit_decimal(sess_registry): + ureg = UnitRegistry(non_int_type=Decimal) + QD = ureg.Quantity + + ibu_d = infer_base_unit(QD(Decimal(1), "millimeter * nanometer")) + + assert ibu_d == QD(Decimal(1), "meter**2").units + + assert all(isinstance(v, Decimal) for v in ibu_d.values()) + + +def test_infer_base_unit_fraction(sess_registry): + ureg = UnitRegistry(non_int_type=Fraction) + QD = ureg.Quantity + + ibu_d = infer_base_unit(QD(Fraction("1"), "millimeter * nanometer")) + + assert ibu_d == QD(Fraction("1"), "meter**2").units + + assert all(isinstance(v, Fraction) for v in ibu_d.values()) + + +def test_units_adding_to_zero(sess_registry): + assert ( + infer_base_unit(sess_registry.Quantity(1, "m * mm / m / um * s")) + == sess_registry.Quantity(1, "s").units + ) + + +def test_to_compact(sess_registry): + r = ( + sess_registry.Quantity(1000000000, "m") + * sess_registry.Quantity(1, "mm") + / sess_registry.Quantity(1, "s") + / sess_registry.Quantity(1, "ms") + ) + compact_r = r.to_compact() + expected = sess_registry.Quantity(1000.0, "kilometer**2 / second**2") + helpers.assert_quantity_almost_equal(compact_r, expected) + + r = ( + sess_registry.Quantity(1, "m") + * sess_registry.Quantity(1, "mm") + / sess_registry.Quantity(1, "m") + / sess_registry.Quantity(2, "um") + * sess_registry.Quantity(2, "s") + ).to_compact() + helpers.assert_quantity_almost_equal(r, sess_registry.Quantity(1000, "s")) + + +def 
test_to_compact_decimal(sess_registry): + ureg = UnitRegistry(non_int_type=Decimal) + Q = ureg.Quantity + r = ( + Q(Decimal("1000000000.0"), "m") + * Q(Decimal(1), "mm") + / Q(Decimal(1), "s") + / Q(Decimal(1), "ms") + ) + compact_r = r.to_compact() + expected = Q(Decimal("1000.0"), "kilometer**2 / second**2") + assert compact_r == expected + + r = ( + Q(Decimal(1), "m") * Q(1, "mm") / Q(1, "m**2") / Q(2, "um") * Q(2, "s") + ).to_compact() + assert r == Q(1000, "s/m") + + +def test_to_compact_fraction(sess_registry): + ureg = UnitRegistry(non_int_type=Fraction) + Q = ureg.Quantity + r = ( + Q(Fraction("10000000000/10"), "m") + * Q(Fraction("1"), "mm") + / Q(Fraction("1"), "s") + / Q(Fraction("1"), "ms") + ) + compact_r = r.to_compact() + expected = Q(Fraction("1000.0"), "kilometer**2 / second**2") + assert compact_r == expected + + r = ( + sess_registry.Quantity(Fraction(1), "m") + * sess_registry.Quantity(1, "mm") + / sess_registry.Quantity(1, "m**2") + / sess_registry.Quantity(2, "um") + * sess_registry.Quantity(2, "s") + ).to_compact() + assert r == Q(1000, "s/m") + + +def test_volts(sess_registry): + r = ( + sess_registry.Quantity(1, "V") + * sess_registry.Quantity(1, "mV") + / sess_registry.Quantity(1, "kV") + ) + b = infer_base_unit(r) + assert b == sess_registry.Quantity(1, "V").units + helpers.assert_quantity_almost_equal(r, sess_registry.Quantity(1, "uV")) diff --git a/pint/testsuite/test_issues.py b/pint/testsuite/test_issues.py index 7702ea2eb..10700ffce 100644 --- a/pint/testsuite/test_issues.py +++ b/pint/testsuite/test_issues.py @@ -13,6 +13,9 @@ from pint.util import ParserHelper +from .helpers import internal + + # TODO: do not subclass from QuantityTestCase class TestIssues(QuantityTestCase): kwargs = dict(autoconvert_offset_to_baseunit=False) @@ -727,7 +730,7 @@ def test_issue1058(self, module_registry): def test_issue1062_issue1097(self): # Must not be used by any other tests ureg = UnitRegistry() - assert "nanometer" not in ureg._units + assert 
"nanometer" not in internal(ureg)._units for i in range(5): ctx = Context.from_lines(["@context _", "cal = 4 J"]) with ureg.context("sp", ctx): diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py index 7efe74f80..f13aaf868 100644 --- a/pint/testsuite/test_quantity.py +++ b/pint/testsuite/test_quantity.py @@ -12,7 +12,6 @@ from pint import ( DimensionalityError, OffsetUnitCalculusError, - Quantity, UnitRegistry, get_application_registry, ) @@ -79,8 +78,11 @@ def test_quantity_comparison(self): j = self.Q_(5, "meter*meter") # Include a comparison to the application registry - k = 5 * get_application_registry().meter - m = Quantity(5, "meter") # Include a comparison to a directly created Quantity + 5 * get_application_registry().meter + # Include a comparison to a directly created Quantity + from pint import Quantity + + Quantity(5, "meter") # identity for single object assert x == x @@ -99,11 +101,12 @@ def test_quantity_comparison(self): assert x != z assert x < z + # TODO: Reinstate this in the near future. 
# Compare with items to the separate application registry - assert k >= m # These should both be from application registry - if z._REGISTRY != m._REGISTRY: - with pytest.raises(ValueError): - z > m # One from local registry, one from application registry + # assert k >= m # These should both be from application registry + # if z._REGISTRY._subregistry != m._REGISTRY._subregistry: + # with pytest.raises(ValueError): + # z > m # One from local registry, one from application registry assert z != j @@ -371,8 +374,8 @@ def test_convert(self): @helpers.requires_mip def test_to_preferred(self): - ureg = UnitRegistry() - Q_ = ureg.Quantity + ureg = self.ureg + Q_ = self.Q_ ureg.define("pound_force_per_square_foot = 47.8803 pascals = psf") ureg.define("pound_mass = 0.45359237 kg = lbm") @@ -409,9 +412,9 @@ def test_to_preferred(self): @helpers.requires_mip def test_to_preferred_registry(self): - ureg = UnitRegistry() - Q_ = ureg.Quantity - ureg.preferred_units = [ + ureg = self.ureg + Q_ = self.Q_ + ureg.default_preferred_units = [ ureg.m, # distance L ureg.kg, # mass M ureg.s, # duration T @@ -424,9 +427,10 @@ def test_to_preferred_registry(self): @helpers.requires_mip def test_autoconvert_to_preferred(self): - ureg = UnitRegistry() - Q_ = ureg.Quantity - ureg.preferred_units = [ + ureg = self.ureg + Q_ = self.Q_ + ureg.autoconvert_to_preferred = True + ureg.default_preferred_units = [ ureg.m, # distance L ureg.kg, # mass M ureg.s, # duration T diff --git a/pint/testsuite/test_systems.py b/pint/testsuite/test_systems.py index 5b3f1ce2e..49da32c52 100644 --- a/pint/testsuite/test_systems.py +++ b/pint/testsuite/test_systems.py @@ -4,6 +4,9 @@ from pint.testsuite import QuantityTestCase +from .helpers import internal + + class TestGroup: def _build_empty_reg_root(self): ureg = UnitRegistry(None) @@ -13,7 +16,7 @@ def _build_empty_reg_root(self): def test_units_programmatically(self): ureg, root = self._build_empty_reg_root() - d = ureg._groups + d = internal(ureg)._groups 
assert root._used_groups == set() assert root._used_by == set() @@ -38,7 +41,7 @@ def test_cyclic(self): def test_groups_programmatically(self): ureg, root = self._build_empty_reg_root() - d = ureg._groups + d = internal(ureg)._groups g2 = ureg.Group("g2") assert d.keys() == {"root", "g2"} @@ -53,7 +56,7 @@ def test_simple(self): lines = ["@group mygroup", "meter = 3", "second = 2"] ureg, root = self._build_empty_reg_root() - d = ureg._groups + d = internal(ureg)._groups grp = ureg.Group.from_lines(lines, lambda x: None) @@ -221,7 +224,7 @@ def test_get_base_units(self): lines = ["@system %s using test-imperial" % sysname, "inch"] s = ureg.System.from_lines(lines, ureg.get_base_units) - ureg._systems[s.name] = s + internal(ureg)._systems[s.name] = s # base_factor, destination_units c = ureg.get_base_units("inch", system=sysname) @@ -243,7 +246,7 @@ def test_get_base_units_different_exponent(self): lines = ["@system %s using test-imperial" % sysname, "pint:meter"] s = ureg.System.from_lines(lines, ureg.get_base_units) - ureg._systems[s.name] = s + internal(ureg)._systems[s.name] = s # base_factor, destination_units c = ureg.get_base_units("inch", system=sysname) @@ -272,7 +275,7 @@ def test_get_base_units_relation(self): lines = ["@system %s using test-imperial" % sysname, "mph:meter"] s = ureg.System.from_lines(lines, ureg.get_base_units) - ureg._systems[s.name] = s + internal(ureg)._systems[s.name] = s # base_factor, destination_units c = ureg.get_base_units("inch", system=sysname) assert round(abs(c[0] - 0.056), 2) == 0 diff --git a/pint/testsuite/test_testing.py b/pint/testsuite/test_testing.py index 3116dd8aa..eab04fcb9 100644 --- a/pint/testsuite/test_testing.py +++ b/pint/testsuite/test_testing.py @@ -1,12 +1,17 @@ import pytest -from pint import Quantity +from typing import Any from .. 
import testing np = pytest.importorskip("numpy") +class QuantityToBe(tuple[Any]): + def from_many(*args): + return QuantityToBe(args) + + @pytest.mark.parametrize( ["first", "second", "error", "message"], ( @@ -14,7 +19,7 @@ np.array([0, 1]), np.array([0, 1]), False, "", id="ndarray-None-None-equal" ), pytest.param( - Quantity(1, "m"), + QuantityToBe.from_many(1, "m"), 1, True, "The first is not dimensionless", @@ -22,73 +27,81 @@ ), pytest.param( 1, - Quantity(1, "m"), + QuantityToBe.from_many(1, "m"), True, "The second is not dimensionless", id="mixed2-int-not equal-equal", ), pytest.param( - Quantity(1, "m"), Quantity(1, "m"), False, "", id="Quantity-int-equal-equal" + QuantityToBe.from_many(1, "m"), + QuantityToBe.from_many(1, "m"), + False, + "", + id="QuantityToBe.from_many-int-equal-equal", ), pytest.param( - Quantity(1, "m"), - Quantity(1, "s"), + QuantityToBe.from_many(1, "m"), + QuantityToBe.from_many(1, "s"), True, "Units are not equal", - id="Quantity-int-equal-not equal", + id="QuantityToBe.from_many-int-equal-not equal", ), pytest.param( - Quantity(1, "m"), - Quantity(2, "m"), + QuantityToBe.from_many(1, "m"), + QuantityToBe.from_many(2, "m"), True, "Magnitudes are not equal", - id="Quantity-int-not equal-equal", + id="QuantityToBe.from_many-int-not equal-equal", ), pytest.param( - Quantity(1, "m"), - Quantity(2, "s"), + QuantityToBe.from_many(1, "m"), + QuantityToBe.from_many(2, "s"), True, "Units are not equal", - id="Quantity-int-not equal-not equal", + id="QuantityToBe.from_many-int-not equal-not equal", ), pytest.param( - Quantity(1, "m"), - Quantity(float("nan"), "m"), + QuantityToBe.from_many(1, "m"), + QuantityToBe.from_many(float("nan"), "m"), True, "Magnitudes are not equal", - id="Quantity-float-not equal-equal", + id="QuantityToBe.from_many-float-not equal-equal", ), pytest.param( - Quantity([1, 2], "m"), - Quantity([1, 2], "m"), + QuantityToBe.from_many([1, 2], "m"), + QuantityToBe.from_many([1, 2], "m"), False, "", - 
id="Quantity-ndarray-equal-equal", + id="QuantityToBe.from_many-ndarray-equal-equal", ), pytest.param( - Quantity([1, 2], "m"), - Quantity([1, 2], "s"), + QuantityToBe.from_many([1, 2], "m"), + QuantityToBe.from_many([1, 2], "s"), True, "Units are not equal", - id="Quantity-ndarray-equal-not equal", + id="QuantityToBe.from_many-ndarray-equal-not equal", ), pytest.param( - Quantity([1, 2], "m"), - Quantity([2, 2], "m"), + QuantityToBe.from_many([1, 2], "m"), + QuantityToBe.from_many([2, 2], "m"), True, "Magnitudes are not equal", - id="Quantity-ndarray-not equal-equal", + id="QuantityToBe.from_many-ndarray-not equal-equal", ), pytest.param( - Quantity([1, 2], "m"), - Quantity([2, 2], "s"), + QuantityToBe.from_many([1, 2], "m"), + QuantityToBe.from_many([2, 2], "s"), True, "Units are not equal", - id="Quantity-ndarray-not equal-not equal", + id="QuantityToBe.from_many-ndarray-not equal-not equal", ), ), ) -def test_assert_equal(first, second, error, message): +def test_assert_equal(sess_registry, first, second, error, message): + if isinstance(first, QuantityToBe): + first = sess_registry.Quantity(*first) + if isinstance(second, QuantityToBe): + second = sess_registry.Quantity(*second) if error: with pytest.raises(AssertionError, match=message): testing.assert_equal(first, second) diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py index c1a2704b5..d0f335357 100644 --- a/pint/testsuite/test_unit.py +++ b/pint/testsuite/test_unit.py @@ -14,6 +14,8 @@ from pint.testsuite import QuantityTestCase, assert_no_warnings, helpers from pint.util import ParserHelper, UnitsContainer +from .helpers import internal + # TODO: do not subclass from QuantityTestCase class TestUnit(QuantityTestCase): @@ -293,11 +295,11 @@ def test_define(self): assert len(dir(ureg)) > 0 def test_load(self): - import pkg_resources + from importlib.resources import files from .. 
import compat - data = pkg_resources.resource_filename(compat.__name__, "default_en.txt") + data = files(compat.__package__).joinpath("default_en.txt") ureg1 = UnitRegistry() ureg2 = UnitRegistry(data) assert dir(ureg1) == dir(ureg2) @@ -593,6 +595,23 @@ def hfunc(x, y): h3 = ureg.wraps((None,), (None, None))(hfunc) assert h3(3, 1) == (3, 1) + def kfunc(a, /, b, c=5, *, d=6): + return a, b, c, d + + k1 = ureg.wraps((None,), (None, None, None, None))(kfunc) + assert k1(1, 2, 3, d=4) == (1, 2, 3, 4) + assert k1(1, 2, c=3, d=4) == (1, 2, 3, 4) + assert k1(1, b=2, c=3, d=4) == (1, 2, 3, 4) + assert k1(1, d=4, b=2, c=3) == (1, 2, 3, 4) + assert k1(1, 2, c=3) == (1, 2, 3, 6) + assert k1(1, 2, d=4) == (1, 2, 5, 4) + assert k1(1, 2) == (1, 2, 5, 6) + + k2 = ureg.wraps((None,), ("meter", "centimeter", "meter", "centimeter"))(kfunc) + assert k2( + 1 * ureg.meter, 2 * ureg.centimeter, 3 * ureg.meter, d=4 * ureg.centimeter + ) == (1, 2, 3, 4) + def test_wrap_referencing(self): ureg = self.ureg @@ -641,6 +660,7 @@ def func(x): assert f0(3.0 * ureg.centimeter) == 0.03 * ureg.meter with pytest.raises(DimensionalityError): f0(3.0 * ureg.kilogram) + assert f0(x=3.0 * ureg.centimeter) == 0.03 * ureg.meter f0b = ureg.check(ureg.meter)(func) with pytest.raises(DimensionalityError): @@ -677,13 +697,13 @@ def test_to_ref_vs_to(self): q = 8.0 * self.ureg.inch t = 8.0 * self.ureg.degF dt = 8.0 * self.ureg.delta_degF - assert q.to("yard").magnitude == self.ureg._units[ + assert q.to("yard").magnitude == internal(self.ureg)._units[ "inch" ].converter.to_reference(8.0) - assert t.to("kelvin").magnitude == self.ureg._units[ + assert t.to("kelvin").magnitude == internal(self.ureg)._units[ "degF" ].converter.to_reference(8.0) - assert dt.to("kelvin").magnitude == self.ureg._units[ + assert dt.to("kelvin").magnitude == internal(self.ureg)._units[ "delta_degF" ].converter.to_reference(8.0) @@ -881,13 +901,6 @@ def test_get_compatible_units(self): class 
TestRegistryWithDefaultRegistry(TestRegistry): - @classmethod - def setup_class(cls): - from pint import _DEFAULT_REGISTRY - - cls.ureg = _DEFAULT_REGISTRY - cls.Q_ = cls.ureg.Quantity - def test_lazy(self): x = LazyRegistry() x.test = "test" @@ -896,8 +909,10 @@ def test_lazy(self): y("meter") assert isinstance(y, UnitRegistry) - def test_redefinition(self): - d = self.ureg.define + def test_redefinition(self, func_registry): + ureg = UnitRegistry(on_redefinition="raise") + d = ureg.define + assert "meter" in internal(self.ureg)._units with pytest.raises(RedefinitionError): d("meter = [time]") with pytest.raises(RedefinitionError): @@ -908,7 +923,7 @@ def test_redefinition(self): d("[velocity] = [length]") # aliases - assert "inch" in self.ureg._units + assert "inch" in internal(self.ureg)._units with pytest.raises(RedefinitionError): d("bla = 3.2 meter = inch") with pytest.raises(RedefinitionError): @@ -1007,7 +1022,7 @@ def test_alias(self): assert ureg.Unit(a) == ureg.Unit("canonical") # Test that aliases defined multiple times are not duplicated - assert ureg._units["canonical"].aliases == ( + assert internal(ureg)._units["canonical"].aliases == ( "alias1", "alias2", ) diff --git a/pint/util.py b/pint/util.py index d14722a04..1f7defc50 100644 --- a/pint/util.py +++ b/pint/util.py @@ -1043,8 +1043,7 @@ def to_units_container( # TODO: document how to whether to lift preprocessing loop out to caller for p in registry.preprocessors: unit_like = p(unit_like) - # TODO: Why not parse.units here? - return registry._parse_units(unit_like) + return registry.parse_units_as_container(unit_like) else: return ParserHelper.from_string(unit_like) elif dict in mro: