# Source code for astropy.modeling.parameters

# Licensed under a 3-clause BSD style license - see LICENSE.rst
# pylint: disable=invalid-name

"""
This module defines classes that deal with parameters.

It is unlikely users will need to work with these classes directly,
unless they define their own models.
"""

import functools
import numbers
import operator

import numpy as np

from astropy.units import Quantity
from astropy.utils import isiterable
from .utils import array_repr_oneline
from .utils import get_inputs_and_params

__all__ = ['Parameter', 'InputParameterError', 'ParameterError']

class ParameterError(Exception):
    """Generic exception class for all exceptions pertaining to Parameters."""
class InputParameterError(ValueError, ParameterError):
    """Used for incorrect input parameter values and definitions."""
class ParameterDefinitionError(ParameterError):
    """Exception in declaration of class-level Parameters."""


def _tofloat(value):
    """
    Convert a parameter to float or float array.

    Parameters
    ----------
    value : object
        The candidate parameter value.

    Returns
    -------
    value : float, `numpy.ndarray` or `~astropy.units.Quantity`
        Iterables are converted with ``np.asanyarray(..., dtype=float)``,
        non-iterable Quantities are returned unchanged, and non-iterable
        arrays and plain numbers are converted to a Python ``float``.

    Raises
    ------
    InputParameterError
        If the value is a boolean, cannot be converted to float, or is of
        a type this function does not know how to handle.
    """
    if isiterable(value):
        try:
            value = np.asanyarray(value, dtype=float)
        except (TypeError, ValueError) as exc:
            # catch arrays with strings or user errors like different
            # types of parameters in a parameter set
            raise InputParameterError(
                f"Parameter of {type(value)} could not be converted to float"
            ) from exc
    elif isinstance(value, Quantity):
        # Quantities are fine as is
        pass
    elif isinstance(value, np.ndarray):
        # A scalar/dimensionless array
        value = float(value.item())
    elif isinstance(value, (bool, np.bool_)):
        # This check has to precede the generic number check: ``bool`` is a
        # subclass of ``int`` (and hence a ``numbers.Number``), so testing
        # for numbers first would silently turn True/False into 1.0/0.0.
        raise InputParameterError(
            "Expected parameter to be of numerical type, not boolean")
    elif isinstance(value, (numbers.Number, np.number)):
        value = float(value)
    else:
        raise InputParameterError(
            f"Don't know how to convert parameter of {type(value)} to float")
    return value


# Helpers for implementing operator overloading on Parameter

def _operand_value(param):
    """Return ``param.value``, wrapped in a Quantity when it has a unit."""
    if param.unit is not None:
        return Quantity(param.value, param.unit)
    return param.value


def _binary_arithmetic_operation(op, reflected=False):
    """
    Build a binary arithmetic special method from the callable ``op``.

    With ``reflected=True`` the operands are swapped, i.e. the returned
    method computes ``op(val, self_value)`` instead of
    ``op(self_value, val)``.
    """
    @functools.wraps(op)
    def wrapper(self, val):
        self_value = _operand_value(self)
        if reflected:
            return op(val, self_value)
        return op(self_value, val)

    return wrapper


def _binary_comparison_operation(op):
    """Build a comparison special method computing ``op(self_value, val)``."""
    @functools.wraps(op)
    def wrapper(self, val):
        return op(_operand_value(self), val)

    return wrapper


def _unary_arithmetic_operation(op):
    """Build a unary special method computing ``op(self_value)``."""
    @functools.wraps(op)
    def wrapper(self):
        return op(_operand_value(self))

    return wrapper
class Parameter:
    """
    Wraps individual parameters.

    Since 4.0 Parameters are no longer descriptors and are based on a new
    implementation of the Parameter class. Parameters now (as of 4.0) store
    values locally (instead of, as previously, in the associated model).

    This class represents a model's parameter (in a somewhat broad sense). It
    serves a number of purposes:

    1) A type to be recognized by models and treated specially at class
       initialization (i.e., if it is found that there is a class definition
       of a Parameter, the model initializer makes a copy at the instance
       level).

    2) Managing the handling of allowable parameter values and once defined,
       ensuring updates are consistent with the Parameter definition. This
       includes the optional use of units and quantities as well as
       transforming values to an internally consistent representation (e.g.,
       from degrees to radians through the use of getters and setters).

    3) Holding attributes of parameters relevant to fitting, such as whether
       the parameter may be varied in fitting, or whether there are
       constraints that must be satisfied.

    See :ref:`modeling-parameters` for more details.

    Parameters
    ----------
    name : str
        parameter name

        .. warning::

            The fact that `Parameter` accepts ``name`` as an argument is an
            implementation detail, and should not be used directly.  When
            defining a new `Model` class, parameter names are always
            automatically defined by the class attribute they're assigned to.
    description : str
        parameter description
    default : float or array
        default value to use for this parameter
    unit : `~astropy.units.Unit`
        if specified, the parameter will be in these units, and when the
        parameter is updated in future, it should be set to a
        :class:`~astropy.units.Quantity` that has equivalent units.
    getter : callable
        a function that wraps the raw (internal) value of the parameter
        when returning the value through the parameter proxy (eg. a
        parameter may be stored internally as radians but returned to the
        user as degrees)
    setter : callable
        a function that wraps any values assigned to this parameter; should
        be the inverse of getter
    fixed : bool
        if True the parameter is not varied during fitting
    tied : callable or False
        if callable is supplied it provides a way to link the value of this
        parameter to another parameter (or some other arbitrary function)
    min : float
        the lower bound of a parameter
    max : float
        the upper bound of a parameter
    bounds : tuple
        specify min and max as a single tuple--bounds may not be specified
        simultaneously with min or max
    """

    constraints = ('fixed', 'tied', 'bounds')
    """
    Types of constraints a parameter can have.  Excludes 'min' and 'max'
    which are just aliases for the first and second elements of the 'bounds'
    constraint (which is represented as a 2-tuple). 'prior' and 'posterior'
    are available for use by user fitters but are not used by any built-in
    fitters as of this writing.
    """

    def __init__(self, name='', description='', default=None, unit=None,
                 getter=None, setter=None, fixed=False, tied=False, min=None,
                 max=None, bounds=None):
        super().__init__()

        self._model = None
        self._model_required = False
        # Wrapping the setter/getter may flip ``_model_required`` to True, so
        # this has to happen after the flag has been initialized above.
        self._setter = self._create_value_wrapper(setter, None)
        self._getter = self._create_value_wrapper(getter, None)
        self._name = name
        self.__doc__ = self._description = description.strip()

        # We only need to perform this check on unbound parameters
        if isinstance(default, Quantity):
            if unit is not None and not unit.is_equivalent(default.unit):
                raise ParameterDefinitionError(
                    "parameter default {0} does not have units equivalent to "
                    "the required unit {1}".format(default, unit))
            unit = default.unit
            default = default.value

        self._default = default
        self._unit = unit
        # Internal units correspond to raw_units held by the model in the
        # previous implementation. The private _getter and _setter methods
        # use this to convert to and from the public unit defined for the
        # parameter.
        self._internal_unit = None
        if not self._model_required:
            if self._default is not None:
                self.value = self._default
            else:
                self._value = None

        # NOTE: These are *default* constraints--on model instances constraints
        # are taken from the model if set, otherwise the defaults set here are
        # used
        if bounds is not None:
            if min is not None or max is not None:
                raise ValueError(
                    'bounds may not be specified simultaneously with min '
                    'or max when instantiating Parameter {}'.format(name))
        else:
            bounds = (min, max)

        self._fixed = fixed
        self._tied = tied
        self._bounds = bounds
        self._order = None

        self._validator = None
        self._prior = None
        self._posterior = None

        self._std = None

    def __set_name__(self, owner, name):
        self._name = name

    def __len__(self):
        val = self.value
        if val.shape == ():
            return 1
        else:
            return val.shape[0]

    def __getitem__(self, key):
        value = self.value
        if len(value.shape) == 0:
            # Wrap the value in a list so that getitem can work for sensible
            # indices like [0] and [-1]
            value = [value]
        return value[key]

    def __setitem__(self, key, value):
        # Get the existing value and check whether it even makes sense to
        # apply this index
        oldvalue = self.value

        if isinstance(key, slice):
            if len(oldvalue[key]) == 0:
                raise InputParameterError(
                    "Slice assignment outside the parameter dimensions for "
                    "'{}'".format(self.name))
            for idx, val in zip(range(*key.indices(len(self))), value):
                self.__setitem__(idx, val)
        else:
            try:
                oldvalue[key] = value
            except IndexError:
                # Report the parameter's own length: the assigned ``value``
                # may be a plain scalar with no ``shape`` attribute at all.
                raise InputParameterError(
                    "Input dimension {} invalid for {!r} parameter with "
                    "dimension {}".format(key, self.name, len(self)))

    def __repr__(self):
        args = f"'{self._name}'"
        args += f', value={self.value}'

        if self.unit is not None:
            args += f', unit={self.unit}'

        for cons in self.constraints:
            val = getattr(self, cons)
            if val not in (None, False, (None, None)):
                # Maybe non-obvious, but False is the default for the fixed and
                # tied constraints
                args += f', {cons}={val}'

        return f"{self.__class__.__name__}({args})"

    @property
    def name(self):
        """Parameter name"""
        return self._name

    @property
    def default(self):
        """Parameter default value"""
        return self._default

    @property
    def value(self):
        """The unadorned value proxied by this parameter."""
        if self._getter is None and self._setter is None:
            return np.float64(self._value)
        else:
            # This new implementation uses the names of internal_unit
            # in place of raw_unit used previously. The contrast between
            # internal values and units is that between the public
            # units that the parameter advertises to what it actually
            # uses internally.
            if self.internal_unit:
                return np.float64(self._getter(self._internal_value,
                                               self.internal_unit,
                                               self.unit).value)
            elif self._getter:
                return np.float64(self._getter(self._internal_value))
            elif self._setter:
                return np.float64(self._internal_value)

    @value.setter
    def value(self, value):
        if isinstance(value, Quantity):
            raise TypeError("The .value property on parameters should be set"
                            " to unitless values, not Quantity objects. To set"
                            " a parameter to a quantity simply set the "
                            "parameter directly without using .value")
        if self._setter is None:
            self._value = np.array(value, dtype=np.float64)
        else:
            self._internal_value = np.array(self._setter(value),
                                            dtype=np.float64)

    @property
    def unit(self):
        """
        The unit attached to this parameter, if any.

        On unbound parameters (i.e. parameters accessed through the
        model class, rather than a model instance) this is the required/
        default unit for the parameter.
        """
        return self._unit

    @unit.setter
    def unit(self, unit):
        # Direct assignment is never allowed; only the message differs.
        if self.unit is None:
            raise ValueError('Cannot attach units to parameters that were '
                             'not initially specified with units')
        else:
            raise ValueError('Cannot change the unit attribute directly, '
                             'instead change the parameter to a new quantity')

    def _set_unit(self, unit, force=False):
        """Set the unit when ``force`` is true; otherwise raise ValueError."""
        if force:
            self._unit = unit
        else:
            if unit is None:
                raise ValueError('Cannot attach units to parameters that were '
                                 'not initially specified with units')
            else:
                raise ValueError('Cannot change the unit attribute directly, '
                                 'instead change the parameter to a new '
                                 'quantity')

    @property
    def internal_unit(self):
        """
        Return the internal unit the parameter uses for the internal value
        stored
        """
        return self._internal_unit

    @internal_unit.setter
    def internal_unit(self, internal_unit):
        """
        Set the unit the parameter will convert the supplied value to the
        representation used internally.
        """
        self._internal_unit = internal_unit

    @property
    def quantity(self):
        """
        This parameter, as a :class:`~astropy.units.Quantity` instance.
        """
        if self.unit is None:
            return None
        return self.value * self.unit

    @quantity.setter
    def quantity(self, quantity):
        if not isinstance(quantity, Quantity):
            raise TypeError("The .quantity attribute should be set "
                            "to a Quantity object")
        self.value = quantity.value
        self._unit = quantity.unit

    @property
    def shape(self):
        """The shape of this parameter's value array."""
        if self._setter is None:
            return self._value.shape
        return self._internal_value.shape

    @shape.setter
    def shape(self, value):
        if isinstance(self.value, np.generic):
            if value not in ((), (1,)):
                raise ValueError("Cannot assign this shape to a scalar "
                                 "quantity")
        else:
            self.value.shape = value

    @property
    def size(self):
        """The size of this parameter's value array."""
        return np.size(self.value)

    @property
    def std(self):
        """Standard deviation, if available from fit."""
        return self._std

    @std.setter
    def std(self, value):
        self._std = value

    @property
    def prior(self):
        return self._prior

    @prior.setter
    def prior(self, val):
        self._prior = val

    @property
    def posterior(self):
        return self._posterior

    @posterior.setter
    def posterior(self, val):
        self._posterior = val

    @property
    def fixed(self):
        """
        Boolean indicating if the parameter is kept fixed during fitting.
        """
        return self._fixed

    @fixed.setter
    def fixed(self, value):
        """ Fix a parameter. """
        if not isinstance(value, bool):
            raise ValueError("Value must be boolean")
        self._fixed = value

    @property
    def tied(self):
        """
        Indicates that this parameter is linked to another one.

        A callable which provides the relationship of the two parameters.
        """
        return self._tied

    @tied.setter
    def tied(self, value):
        """Tie a parameter"""
        if not callable(value) and value not in (False, None):
            raise TypeError("Tied must be a callable or set to False or None")
        self._tied = value

    @property
    def bounds(self):
        """The minimum and maximum values of a parameter as a tuple"""
        return self._bounds

    @bounds.setter
    def bounds(self, value):
        """Set the minimum and maximum values of a parameter from a tuple"""
        _min, _max = value
        if _min is not None:
            if not isinstance(_min, (numbers.Number, Quantity)):
                raise TypeError("Min value must be a number or a Quantity")
            if isinstance(_min, Quantity):
                _min = float(_min.value)
            else:
                _min = float(_min)

        if _max is not None:
            if not isinstance(_max, (numbers.Number, Quantity)):
                raise TypeError("Max value must be a number or a Quantity")
            if isinstance(_max, Quantity):
                _max = float(_max.value)
            else:
                _max = float(_max)

        self._bounds = (_min, _max)

    @property
    def min(self):
        """A value used as a lower bound when fitting a parameter"""
        return self.bounds[0]

    @min.setter
    def min(self, value):
        """Set a minimum value of a parameter"""
        self.bounds = (value, self.max)

    @property
    def max(self):
        """A value used as an upper bound when fitting a parameter"""
        return self.bounds[1]

    @max.setter
    def max(self, value):
        """Set a maximum value of a parameter."""
        self.bounds = (self.min, value)

    @property
    def validator(self):
        """
        Used as a decorator to set the validator method for a `Parameter`.
        The validator method validates any value set for that parameter.
        It takes two arguments--``self``, which refers to the `Model`
        instance (remember, this is a method defined on a `Model`), and
        the value being set for this parameter.  The validator method's
        return value is ignored, but it may raise an exception if the value
        set on the parameter is invalid (typically an `InputParameterError`
        should be raised, though this is not currently a requirement).
        """

        def validator(func, self=self):
            if callable(func):
                self._validator = func
                return self
            else:
                raise ValueError("This decorator method expects a callable.\n"
                                 "The use of this method as a direct validator is\n"
                                 "deprecated; use the new validate method instead\n")
        return validator

    def validate(self, value):
        """ Run the validator on this parameter"""
        if self._validator is not None and self._model is not None:
            self._validator(self._model, value)

    def copy(self, name=None, description=None, default=None, unit=None,
             getter=None, setter=None, fixed=False, tied=False, min=None,
             max=None, bounds=None, prior=None, posterior=None):
        """
        Make a copy of this `Parameter`, overriding any of its core attributes
        in the process (or an exact copy).

        The arguments to this method are the same as those for the `Parameter`
        initializer, plus ``prior`` and ``posterior``.  This simply returns a
        new `Parameter` instance with any or all of the attributes overridden,
        and so returns the equivalent of:

        .. code:: python

            Parameter(self.name, self.description, ...)

        Arguments left as `None` are taken from this parameter.

        Raises
        ------
        ValueError
            If ``bounds`` is given together with ``min`` or ``max``.
        """
        overrides = {
            'name': name,
            'description': description,
            'default': default,
            'unit': unit,
            'getter': getter,
            'setter': setter,
            'fixed': fixed,
            'tied': tied,
        }

        kwargs = {}
        for key, value in overrides.items():
            if value is None:
                # Fall back to this parameter's public attribute, or to the
                # private one when no public attribute of that name exists
                if hasattr(self, key):
                    value = getattr(self, key)
                elif hasattr(self, '_' + key):
                    value = getattr(self, '_' + key)
            kwargs[key] = value

        # min/max are just aliases for the components of bounds, and the
        # initializer refuses bounds combined with min/max.  Only inherit
        # the bounds tuple when neither alias is being overridden.
        if bounds is None:
            if min is None and max is None:
                kwargs['bounds'] = self.bounds
            else:
                kwargs['min'] = self.min if min is None else min
                kwargs['max'] = self.max if max is None else max
        else:
            # Pass everything through so the initializer can reject an
            # inconsistent combination.
            kwargs.update(bounds=bounds, min=min, max=max)

        new_param = self.__class__(**kwargs)

        # prior/posterior are not initializer arguments, so they are set on
        # the new instance after construction.
        new_param.prior = self._prior if prior is None else prior
        new_param.posterior = (self._posterior if posterior is None
                               else posterior)
        return new_param

    @property
    def model(self):
        """ Return the model this parameter is associated with."""
        return self._model

    @model.setter
    def model(self, value):
        self._model = value
        self._setter = self._create_value_wrapper(self._setter, value)
        self._getter = self._create_value_wrapper(self._getter, value)
        if self._model_required:
            if self._default is not None:
                self.value = self._default
            else:
                self._value = None

    @property
    def _raw_value(self):
        """
        Currently for internal use only.

        Like Parameter.value but does not pass the result through
        Parameter.getter.  By design this should only be used from bound
        parameters.

        This will probably be removed or retweaked at some point in the
        process of rethinking how parameter values are stored/updated.
        """
        if self._setter:
            return self._internal_value
        return self.value

    def _create_value_wrapper(self, wrapper, model):
        """Wraps a getter/setter function to support optionally passing in
        a reference to the model object as the second argument.

        If a model is tied to this parameter and its getter/setter supports
        a second argument then this creates a partial function using the model
        instance as the second argument.
        """

        if isinstance(wrapper, np.ufunc):
            if wrapper.nin != 1:
                raise TypeError("A numpy.ufunc used for Parameter "
                                "getter/setter may only take one input "
                                "argument")
        elif wrapper is None:
            # Just allow non-wrappers to fall through silently, for convenience
            return None
        else:
            inputs, _ = get_inputs_and_params(wrapper)
            nargs = len(inputs)

            if nargs == 1:
                pass
            elif nargs == 2:
                self._model_required = True
                if model is not None:
                    # Don't make a partial function unless we're tied to a
                    # specific model instance
                    model_arg = inputs[1].name
                    wrapper = functools.partial(wrapper, **{model_arg: model})
            else:
                raise TypeError("Parameter getter/setter must be a function "
                                "of either one or two arguments")

        return wrapper

    def __array__(self, dtype=None):
        # Make np.asarray(self) work a little more straightforwardly
        arr = np.asarray(self.value, dtype=dtype)

        if self.unit is not None:
            arr = Quantity(arr, self.unit, copy=False)

        return arr

    def __bool__(self):
        return bool(self.value)

    __add__ = _binary_arithmetic_operation(operator.add)
    __radd__ = _binary_arithmetic_operation(operator.add, reflected=True)
    __sub__ = _binary_arithmetic_operation(operator.sub)
    __rsub__ = _binary_arithmetic_operation(operator.sub, reflected=True)
    __mul__ = _binary_arithmetic_operation(operator.mul)
    __rmul__ = _binary_arithmetic_operation(operator.mul, reflected=True)
    __pow__ = _binary_arithmetic_operation(operator.pow)
    __rpow__ = _binary_arithmetic_operation(operator.pow, reflected=True)
    __div__ = _binary_arithmetic_operation(operator.truediv)
    __rdiv__ = _binary_arithmetic_operation(operator.truediv, reflected=True)
    __truediv__ = _binary_arithmetic_operation(operator.truediv)
    __rtruediv__ = _binary_arithmetic_operation(operator.truediv,
                                                reflected=True)
    __eq__ = _binary_comparison_operation(operator.eq)
    __ne__ = _binary_comparison_operation(operator.ne)
    __lt__ = _binary_comparison_operation(operator.lt)
    __gt__ = _binary_comparison_operation(operator.gt)
    __le__ = _binary_comparison_operation(operator.le)
    __ge__ = _binary_comparison_operation(operator.ge)
    __neg__ = _unary_arithmetic_operation(operator.neg)
    __abs__ = _unary_arithmetic_operation(operator.abs)
def param_repr_oneline(param):
    """
    Render a `Parameter` on one line via ``array_repr_oneline``.

    The parameter's value is formatted with ``array_repr_oneline``; when the
    parameter has a unit, the unit's string form is appended after a single
    space, so the result reads like a quantity.
    """
    rendered = array_repr_oneline(param.value)
    unit = param.unit
    if unit is None:
        return rendered
    return f'{rendered} {unit!s}'