定义新的装配工类

本节介绍如何向该包中添加新的非线性拟合算法或编写用户定义的拟合器。简而言之,我们需要定义一个错误函数和一个 __call__ 方法并定义与此装配工一起使用的约束类型(如果有)。

下面以scipy的SLSQP算法为例进行详细说明。所有装配工的基类是 Fitter ::

class SLSQPFitter(Fitter):
    supported_constraints = ['bounds', 'eqcons', 'ineqcons', 'fixed',
                             'tied']

    def __init__(self):
        # Most currently defined fitters take no arguments in their
        # __init__, but the option certainly exists for custom fitters
        super().__init__()

所有的装配工都要取一个模型(他们的 __call__ 方法修改模型的参数)作为它们的第一个参数。

接下来,error函数获取拟合算法和输入坐标迭代返回的参数列表,用它们评估模型并返回拟合的某种类型的度量。在本例中,残差平方和用作拟合度量:

def objective_function(self, fps, *args):
    model = args[0]
    meas = args[-1]
    model.fitparams(fps)
    res = self.model(*args[1:-1]) - meas
    return np.sum(res**2)

这个 __call__ 方法执行拟合。作为最低要求,它将所有坐标作为单独的参数。根据需要传递其他参数:

def __call__(self, model, x, y , maxiter=MAXITER, epsilon=EPS):
    if model.linear:
            raise ModelLinearityException(
                'Model is linear in parameters; '
                'non-linear fitting methods should not be used.')
    model_copy = model.copy()
    init_values, _ = _model_to_fit_params(model_copy)
    self.fitparams = optimize.fmin_slsqp(self.errorfunc, p0=init_values,
                                         args=(y, x),
                                         bounds=self.bounds,
                                         eqcons=self.eqcons,
                                         ineqcons=self.ineqcons)
    return model_copy

定义插件装配工

astropy.modeling 包括一个插件机制,允许在astropy内核外部定义的装配工插入到 astropy.modeling.fitting namespace through the use of entry points. Entry points are references to importable objects. A tutorial on defining entry points can be found in setuptools' documentation . 插件装配工必须从 Fitter 基类。以便发现并插入装配工 astropy.modeling.fitting 入口点必须插入 astropy.modeling 入口点组

setup(
      # ...
      entry_points = {'astropy.modeling': 'PluginFitterName = fitter_module:PlugFitterClass'}
)

这将允许用户导入 PlugFitterName 通过 astropy.modeling.fitting 通过

from astropy.modeling.fitting import PlugFitterName

使用此功能的一个项目是 Saba 可作为参考。

使用自定义统计函数

本节介绍如何使用用户定义的统计函数编写新的装配工。下面的例子显示了一个专门的类,它与两个变量中的不确定性拟合一条直线。

需要以下导入语句::

import numpy as np
from astropy.modeling.fitting import (_validate_model,
                                      _fitter_to_model_params,
                                      _model_to_fit_params, Fitter,
                                      _convert_input)
from astropy.modeling.optimizers import Simplex

首先需要定义一个统计数据。这可以是函数或可调用类。::

def chi_line(measured_vals, updated_model, x_sigma, y_sigma, x):
    """
    Chi^2 statistic for fitting a straight line with uncertainties in x and
    y.

    Parameters
    ----------
    measured_vals : array
    updated_model : `~astropy.modeling.ParametricModel`
        model with parameters set by the current iteration of the optimizer
    x_sigma : array
        uncertainties in x
    y_sigma : array
        uncertainties in y

    """
    model_vals = updated_model(x)
    if x_sigma is None and y_sigma is None:
        return np.sum((model_vals - measured_vals) ** 2)
    elif x_sigma is not None and y_sigma is not None:
        weights = 1 / (y_sigma ** 2 + updated_model.parameters[1] ** 2 *
                       x_sigma ** 2)
        return np.sum((weights * (model_vals - measured_vals)) ** 2)
    else:
        if x_sigma is not None:
            weights = 1 / x_sigma ** 2
        else:
            weights = 1 / y_sigma ** 2
        return np.sum((weights * (model_vals - measured_vals)) ** 2)

一般来说,要定义一个新的装配工,只需提供一个统计函数和一个优化器。在本例中,我们将让优化器作为fitter的可选参数,并将统计信息设置为 chi_line 以上:

class LineFitter(Fitter):
    """
    Fit a straight line with uncertainties in both variables

    Parameters
    ----------
    optimizer : class or callable
        one of the classes in optimizers.py (default: Simplex)
    """

    def __init__(self, optimizer=Simplex):
        self.statistic = chi_line
        super().__init__(optimizer, statistic=self.statistic)

最后要定义的是 __call__ 方法:

def __call__(self, model, x, y, x_sigma=None, y_sigma=None, **kwargs):
    """
    Fit data to this model.

    Parameters
    ----------
    model : `~astropy.modeling.core.ParametricModel`
        model to fit to x, y
    x : array
        input coordinates
    y : array
        input coordinates
    x_sigma : array
        uncertainties in x
    y_sigma : array
        uncertainties in y
    kwargs : dict
        optional keyword arguments to be passed to the optimizer

    Returns
    ------
    model_copy : `~astropy.modeling.core.ParametricModel`
        a copy of the input model with parameters set by the fitter

    """
    model_copy = _validate_model(model,
                                 self._opt_method.supported_constraints)

    farg = _convert_input(x, y)
    farg = (model_copy, x_sigma, y_sigma) + farg
    p0, _ = _model_to_fit_params(model_copy)

    fitparams, self.fit_info = self._opt_method(
        self.objective_function, p0, farg, **kwargs)
    _fitter_to_model_params(model_copy, fitparams)

    return model_copy