8. Regression

Chinese proverb

A journey of a thousand miles begins with a single step. -- old Chinese proverb

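In statistical modeling, regression analysis focuses on investigating the relationship between a dependent variable and one or more independent variables. (Wikipedia, Regression analysis)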

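In data mining, regression is a model that represents the relationship between the value of the label (or target, which is a numerical variable) and one or more features (or predictors, which can be numerical or categorical variables).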

8.1. Linear Regression

8.1.1. Introduction

Given a data set

$$\{x_{i1}, \ldots, x_{in}, y_i\}_{i=1}^{m}$$

that contains $n$ features (variables) and $m$ samples (data points), the linear regression model for the $m$ data points with $n$ independent variables $x_{ij}$ (simple linear regression is the case $n = 1$) is given by

$$y_i = \beta_0 + \sum_{j=1}^{n} \beta_j x_{ij}, \quad \text{where } i = 1, \cdots, m \text{ and } j = 1, \cdots, n.$$

In matrix notation, the data set is written as $\mathbf{X} = [\mathbf{x}_1, \cdots, \mathbf{x}_n]$ with $\mathbf{x}_j = \{x_{ij}\}_{i=1}^{m}$, $\mathbf{y} = \{y_i\}_{i=1}^{m}$ (see the figure Feature matrix and label) and $\boldsymbol{\beta}^\top = \{\beta_j\}_{j=1}^{n}$. The equation in matrix form is then written as

$$\mathbf{y} = \mathbf{X}\boldsymbol{\beta}. \tag{1}$$

In this form the intercept $\beta_0$ is either taken to be zero or absorbed into $\boldsymbol{\beta}$ by adding a column of ones to $\mathbf{X}$.

images/fm.png

Feature matrix and label
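As a minimal sketch of this notation, the feature matrix, label vector, and matrix-form prediction might look as follows. NumPy is an assumption here (the text does not name a library), and the variable names and toy numbers are made up purely for illustration.

```python
import numpy as np

# Toy data: m = 4 samples, n = 2 features (illustrative values only).
X = np.array([[1.0, 2.0],
              [2.0, 0.0],
              [3.0, 1.0],
              [4.0, 3.0]])
beta0 = 0.5                     # intercept (hypothetical value)
beta = np.array([2.0, -1.0])    # one coefficient per feature (hypothetical)

m, n = X.shape

# Element-wise form: y_i = beta_0 + sum_j beta_j * x_ij
y_loop = np.array([beta0 + sum(beta[j] * X[i, j] for j in range(n))
                   for i in range(m)])

# Matrix form: prepend a column of ones so the intercept is absorbed
# into the coefficient vector, giving y = X_aug @ beta_aug.
X_aug = np.hstack([np.ones((m, 1)), X])
beta_aug = np.concatenate([[beta0], beta])
y_mat = X_aug @ beta_aug

print(np.allclose(y_loop, y_mat))
```

Both forms produce the same label vector; the matrix form is the one used in the rest of this section.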

8.1.2. How to solve it?

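  1. Direct methods (for more information, please refer to my Prelim Notes for Numerical Analysis)

    • For square or rectangular matrices

      • Singular Value Decomposition

      • Gram-Schmidt orthogonalization

      • QR Decomposition

    • For square matrices

      • LU Decomposition

      • Cholesky Decomposition

      • Regular Splittings

  2. Iterative methods

    • Stationary iterative methods

      • Jacobi Method

      • Gauss-Seidel Method

      • Richardson Method

      • Successive Over Relaxation (SOR) Method

    • Dynamic iterative methods

      • Chebyshev iterative Method

      • Minimal residuals Method

      • Minimal correction iterative Method

      • Steepest Descent Method

      • Conjugate Gradients Method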
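To make the two families concrete, here is a small sketch that solves the same least squares problem once with a direct method (QR decomposition) and once with an iterative method (steepest descent with an exact line search, applied to the normal equations). It assumes NumPy and synthetic, noise-free data. The problem size, seed, iteration cap, and tolerance are arbitrary illustrative choices, not values from the text.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 50, 3                          # more samples than features
X = rng.normal(size=(m, n))
beta_true = np.array([1.5, -2.0, 0.7])
y = X @ beta_true                     # noise-free, so X beta = y is consistent

# Direct method: reduced QR, X = Q R, then solve R beta = Q^T y.
Q, R = np.linalg.qr(X)                # Q is (m, n), R is (n, n) upper triangular
beta_qr = np.linalg.solve(R, Q.T @ y)

# Iterative method: steepest descent on f(beta) = 0.5 * ||X beta - y||^2,
# whose negative gradient is r = X^T y - X^T X beta.
A = X.T @ X
b = X.T @ y
beta_sd = np.zeros(n)
for _ in range(10_000):
    r = b - A @ beta_sd               # negative gradient at the current iterate
    rr = r @ r
    if rr < 1e-20:                    # stop once ||r|| < 1e-10
        break
    alpha = rr / (r @ (A @ r))        # exact line search step for a quadratic
    beta_sd = beta_sd + alpha * r

print(beta_qr)
print(beta_sd)
```

On a small, well-conditioned problem like this one both approaches recover the same coefficients. For ill-conditioned feature matrices, QR or SVD is generally the more robust choice, because forming X^T X squares the condition number.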

8.1.3. Ordinary Least Squares

In mathematics, (1) is an overdetermined system whenever there are more samples than features ($m > n$). The method of ordinary least squares can be used to find an approximate solution to overdetermined systems. For the overdetermined system (1), the least squares formulation is obtained from the problem

$$\min_{\boldsymbol{\beta}} \|\mathbf{X}\boldsymbol{\beta} - \mathbf{y}\|_2, \tag{2}$$

the solution of which can be written with the normal equations:

$$\boldsymbol{\beta} = (\mathbf{X}^{\mathrm{T}}\mathbf{X})^{-1}\mathbf{X}^{\mathrm{T}}\mathbf{y}, \tag{3}$$

where $\mathrm{T}$ indicates a matrix transpose, provided $(\mathbf{X}^{\mathrm{T}}\mathbf{X})^{-1}$ exists (that is, provided $\mathbf{X}$ has full column rank).
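The normal equations can be checked numerically with a short sketch. NumPy, the synthetic data, and the variable names below are assumptions made for illustration. The linear system is solved directly rather than forming the explicit inverse, which is the usual numerical practice, and `np.linalg.lstsq` serves only as an independent reference.

```python
import numpy as np

rng = np.random.default_rng(42)
m, n = 100, 2
X = rng.normal(size=(m, n))
# Labels from hypothetical coefficients plus a little noise.
y = X @ np.array([3.0, -1.0]) + 0.1 * rng.normal(size=m)

# The closed form requires X to have full column rank.
rank = np.linalg.matrix_rank(X)
print("full column rank:", rank == n)

# Normal equations: solve (X^T X) beta = X^T y without inverting X^T X.
beta_ne = np.linalg.solve(X.T @ X, X.T @ y)

# Reference solution from NumPy's least squares routine.
beta_ls, *_ = np.linalg.lstsq(X, y, rcond=None)

print(np.allclose(beta_ne, beta_ls))
```

At the solution the residual is orthogonal to the columns of X, i.e. `X.T @ (X @ beta_ne - y)` is numerically zero, which is exactly what the normal equations state.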

Note

事实上, (3) 可通过以下方式推导:乘法

System Message: WARNING/2 (\X^T)

latex exited with error [stdout] This is pdfTeX, Version 3.14159265-2.6-1.40.17 (TeX Live 2016/Debian) (preloaded format=latex) restricted \write18 enabled. entering extended mode (./math.tex LaTeX2e <2017/01/01> patch level 3 Babel <3.9r> and hyphenation patterns for 12 language(s) loaded. (/usr/share/texlive/texmf-dist/tex/latex/base/article.cls Document Class: article 2014/09/29 v1.4h Standard LaTeX document class (/usr/share/texlive/texmf-dist/tex/latex/base/size12.clo)) (/usr/share/texlive/texmf-dist/tex/latex/base/inputenc.sty (/usr/share/texlive/texmf-dist/tex/latex/ucs/utf8x.def)) (/usr/share/texlive/texmf-dist/tex/latex/ucs/ucs.sty (/usr/share/texlive/texmf-dist/tex/latex/ucs/data/uni-global.def)) (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsmath.sty For additional information on amsmath, use the `?' option. (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amstext.sty (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsgen.sty)) (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsbsy.sty) (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsopn.sty)) (/usr/share/texlive/texmf-dist/tex/latex/amscls/amsthm.sty) (/usr/share/texlive/texmf-dist/tex/latex/amsfonts/amssymb.sty (/usr/share/texlive/texmf-dist/tex/latex/amsfonts/amsfonts.sty)) (/usr/share/texlive/texmf-dist/tex/latex/anyfontsize/anyfontsize.sty) (/usr/share/texlive/texmf-dist/tex/latex/tools/bm.sty) (/usr/share/texlive/texmf-dist/tex/latex/mathtools/mathtools.sty (/usr/share/texlive/texmf-dist/tex/latex/graphics/keyval.sty) (/usr/share/texlive/texmf-dist/tex/latex/tools/calc.sty) (/usr/share/texlive/texmf-dist/tex/latex/mathtools/mhsetup.sty)) ! LaTeX Error: File `dsfont.sty' not found. Type X to quit or <RETURN> to proceed, or enter new name. (Default extension: sty) Enter file name: ! Emergency stop. <read *> l.16 \def \Z{\mathbb{Z}}^^M No pages of output. Transcript written on math.log.

在旁边 (1) 然后乘以

System Message: WARNING/2 ((\X^T\X)^{{-1}})

latex exited with error [stdout] This is pdfTeX, Version 3.14159265-2.6-1.40.17 (TeX Live 2016/Debian) (preloaded format=latex) restricted \write18 enabled. entering extended mode (./math.tex LaTeX2e <2017/01/01> patch level 3 Babel <3.9r> and hyphenation patterns for 12 language(s) loaded. (/usr/share/texlive/texmf-dist/tex/latex/base/article.cls Document Class: article 2014/09/29 v1.4h Standard LaTeX document class (/usr/share/texlive/texmf-dist/tex/latex/base/size12.clo)) (/usr/share/texlive/texmf-dist/tex/latex/base/inputenc.sty (/usr/share/texlive/texmf-dist/tex/latex/ucs/utf8x.def)) (/usr/share/texlive/texmf-dist/tex/latex/ucs/ucs.sty (/usr/share/texlive/texmf-dist/tex/latex/ucs/data/uni-global.def)) (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsmath.sty For additional information on amsmath, use the `?' option. (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amstext.sty (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsgen.sty)) (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsbsy.sty) (/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsopn.sty)) (/usr/share/texlive/texmf-dist/tex/latex/amscls/amsthm.sty) (/usr/share/texlive/texmf-dist/tex/latex/amsfonts/amssymb.sty (/usr/share/texlive/texmf-dist/tex/latex/amsfonts/amsfonts.sty)) (/usr/share/texlive/texmf-dist/tex/latex/anyfontsize/anyfontsize.sty) (/usr/share/texlive/texmf-dist/tex/latex/tools/bm.sty) (/usr/share/texlive/texmf-dist/tex/latex/mathtools/mathtools.sty (/usr/share/texlive/texmf-dist/tex/latex/graphics/keyval.sty) (/usr/share/texlive/texmf-dist/tex/latex/tools/calc.sty) (/usr/share/texlive/texmf-dist/tex/latex/mathtools/mhsetup.sty)) ! LaTeX Error: File `dsfont.sty' not found. Type X to quit or <RETURN> to proceed, or enter new name. (Default extension: sty) Enter file name: ! Emergency stop. <read *> l.16 \def \Z{\mathbb{Z}}^^M No pages of output. Transcript written on math.log.

在前一个结果的两边。你也可以申请 Extreme Value Theorem(2) 找到解决方案 (3) .
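
Written out, the derivation in the note takes two lines. The sketch below assumes that (1) is the linear system $\mathbf{y} = \mathbf{X}\boldsymbol{\beta}$ and that (3) is the normal-equation solution; both forms are assumptions here, since those equations are defined earlier in the chapter.

```latex
\begin{align*}
\mathbf{X}^T\mathbf{X}\,\boldsymbol{\beta} &= \mathbf{X}^T\mathbf{y}
  && \text{multiply (1) on the left by } \mathbf{X}^T \\
\hat{\boldsymbol{\beta}} &= (\mathbf{X}^T\mathbf{X})^{-1}\mathbf{X}^T\mathbf{y}
  && \text{multiply on the left by } (\mathbf{X}^T\mathbf{X})^{-1}
\end{align*}
```

The inverse in the second line exists precisely because $\mathbf{X}$ has full column rank, which makes $\mathbf{X}^T\mathbf{X}$ invertible.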

8.1.4. Gradient Descent

Let's use the following hypothesis:

$h_{\boldsymbol{\beta}}(\mathbf{x}) = \beta_0 + \sum_{j=1}^{n} \beta_j x_j$

Then solving (2) is equivalent to minimizing the following cost function:

8.1.5. Cost Function

$J(\boldsymbol{\beta}) = \frac{1}{2m}\sum_{i=1}^{m}\left(h_{\boldsymbol{\beta}}(x^{(i)}) - y^{(i)}\right)^2 \qquad (4)$

Note

The reason we prefer to solve (4) rather than (2) is that (4) is convex and has some nice properties, such as unique solvability and energy stability for a small enough learning rate. Readers with a strong interest in the non-convex cost function (energy) case are referred to [Feng2016PSD] for more details.
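
To make the cost function concrete, here is a minimal pure-Python sketch of the least-squares cost $J(\boldsymbol{\beta}) = \frac{1}{2m}\sum_i (h_{\boldsymbol{\beta}}(x^{(i)}) - y^{(i)})^2$. The function names `hypothesis` and `cost` and the toy data are assumptions made for illustration; they are not part of the tutorial's code.

```python
def hypothesis(beta, x):
    """h_beta(x) = beta[0] + sum_j beta[j] * x[j-1] (beta[0] is the intercept)."""
    return beta[0] + sum(b * xj for b, xj in zip(beta[1:], x))


def cost(beta, X, y):
    """Least-squares cost J(beta) = 1/(2m) * sum_i (h_beta(x_i) - y_i)^2."""
    m = len(y)
    return sum((hypothesis(beta, xi) - yi) ** 2 for xi, yi in zip(X, y)) / (2 * m)


# Toy data generated from y = 1 + 2*x (illustrative values, not from the text).
X = [[0.0], [1.0], [2.0], [3.0]]
y = [1.0, 3.0, 5.0, 7.0]

exact_cost = cost([1.0, 2.0], X, y)   # the generating parameters fit exactly
worse_cost = cost([0.0, 0.0], X, y)   # an all-zero guess fits poorly
```

The cost is zero at the parameters that generated the data and grows as the guess moves away from them, which is the quantity gradient descent drives down.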

images/gradient1d.png

Gradient Descent in 1D

images/gradient2d.png

Gradient Descent in 2D

8.1.6. Batch Gradient Descent

Gradient descent is a first-order iterative optimization algorithm for finding the minimum of a function. It searches in the direction of steepest descent, which is given by the negative of the gradient (see the figures Gradient Descent in 1D and Gradient Descent in 2D for the 1D and 2D cases, respectively), with a step size set by the learning rate (search step) $\alpha$.
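
The idea can be sketched in a few lines of pure Python. Each iteration computes the gradient of the least-squares cost over the whole data set (the "batch") and applies the update $\beta_j \leftarrow \beta_j - \alpha\,\partial J/\partial \beta_j$. The function name `batch_gradient_descent`, the defaults `alpha=0.1` and `n_iters=2000`, and the toy data are assumptions chosen for illustration; this is not the solver Spark uses.

```python
def batch_gradient_descent(X, y, alpha=0.1, n_iters=2000):
    """Minimize J(beta) = 1/(2m) * sum_i (h_beta(x_i) - y_i)^2 by batch gradient descent."""
    m, n = len(X), len(X[0])
    beta = [0.0] * (n + 1)              # beta[0] is the intercept
    for _ in range(n_iters):
        # Residuals h_beta(x_i) - y_i computed on the full data set.
        residuals = [
            beta[0] + sum(b * xj for b, xj in zip(beta[1:], xi)) - yi
            for xi, yi in zip(X, y)
        ]
        # Gradient of J: mean residual for the intercept,
        # mean of residual * x_j for the j-th feature.
        grad = [sum(residuals) / m]
        grad += [sum(r * xi[j] for r, xi in zip(residuals, X)) / m for j in range(n)]
        # Step against the gradient, scaled by the learning rate alpha.
        beta = [b - alpha * g for b, g in zip(beta, grad)]
    return beta


# Toy data generated from y = 1 + 2*x (illustrative values, not from the text).
X = [[0.0], [1.0], [2.0], [3.0]]
y = [1.0, 3.0, 5.0, 7.0]
beta = batch_gradient_descent(X, y)
```

On this toy problem the iterates approach the intercept 1 and slope 2. A learning rate that is too large makes the iteration diverge, so `alpha` usually needs tuning to the scale of the features.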

8.1.7. Stochastic Gradient Descent

8.1.8. Mini-batch Gradient Descent

8.1.9. Demo

  1. Set up the Spark context and SparkSession

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark regression example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()
  1. Load the dataset

df = spark.read.format('com.databricks.spark.csv').\
                       options(header='true', \
                       inferschema='true').\
            load("../data/Advertising.csv",header=True);

Check the dataset

df.show(5,True)
df.printSchema()

Then you will get

+-----+-----+---------+-----+
|   TV|Radio|Newspaper|Sales|
+-----+-----+---------+-----+
|230.1| 37.8|     69.2| 22.1|
| 44.5| 39.3|     45.1| 10.4|
| 17.2| 45.9|     69.3|  9.3|
|151.5| 41.3|     58.5| 18.5|
|180.8| 10.8|     58.4| 12.9|
+-----+-----+---------+-----+
only showing top 5 rows

root
 |-- TV: double (nullable = true)
 |-- Radio: double (nullable = true)
 |-- Newspaper: double (nullable = true)
 |-- Sales: double (nullable = true)

You can also get summary statistics from the DataFrame (unfortunately, this only works for numerical columns).

df.describe().show()

Then you will get

+-------+-----------------+------------------+------------------+------------------+
|summary|               TV|             Radio|         Newspaper|             Sales|
+-------+-----------------+------------------+------------------+------------------+
|  count|              200|               200|               200|               200|
|   mean|         147.0425|23.264000000000024|30.553999999999995|14.022500000000003|
| stddev|85.85423631490805|14.846809176168728| 21.77862083852283| 5.217456565710477|
|    min|              0.7|               0.0|               0.3|               1.6|
|    max|            296.4|              49.6|             114.0|              27.0|
+-------+-----------------+------------------+------------------+------------------+
images/ad.png

Sales distribution

  1. Convert the data to dense vectors (features and label)

from pyspark.sql import Row
from pyspark.ml.linalg import Vectors

# I provide two ways to build the features and labels

# method 1 (good for small feature):
#def transData(row):
#    return Row(label=row["Sales"],
#               features=Vectors.dense([row["TV"],
#                                       row["Radio"],
#                                       row["Newspaper"]]))

# Method 2 (good for large features):
def transData(data):
    # The last column is the label; all preceding columns form the feature vector
    return data.rdd.map(lambda r: [Vectors.dense(r[:-1]),r[-1]]).toDF(['features','label'])

Note

You are strongly encouraged to try my get_dummy function for dealing with the categorical data in complex datasets.

Supervised learning version:

def get_dummy(df,indexCol,categoricalCols,continuousCols,labelCol):

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    from pyspark.sql.functions import col

    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    data = data.withColumn('label',col(labelCol))

    return data.select(indexCol,'features','label')

Unsupervised learning version:

def get_dummy(df,indexCol,categoricalCols,continuousCols):
    '''
    Get dummy variables and concat with continuous variables for unsupervised learning.
    :param df: the dataframe
    :param categoricalCols: the name list of the categorical data
    :param continuousCols:  the name list of the numerical data
    :return k: feature matrix

    :author: Wenqiang Feng
    :email:  von198@gmail.com
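    '''
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler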

    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    return data.select(indexCol,'features')
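
To see what the indexing and encoding stages inside get_dummy do to a categorical column, here is a pure-Python sketch. It mimics the documented default behavior of StringIndexer (most frequent label gets index 0) and of OneHotEncoder with dropLast=True (the last category becomes the all-zero vector). The helper names `string_index` and `one_hot_drop_last`, the alphabetical tie-break, and the sample column are assumptions for illustration; this is not Spark's implementation.

```python
from collections import Counter


def string_index(values):
    """Map each label to an index, most frequent label first (ties broken alphabetically)."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: idx for idx, label in enumerate(ordered)}


def one_hot_drop_last(index, n_categories):
    """One-hot encode an index, dropping the last category so it maps to all zeros."""
    vec = [0.0] * (n_categories - 1)
    if index < n_categories - 1:
        vec[index] = 1.0
    return vec


colors = ["red", "blue", "red", "green", "red", "blue"]   # hypothetical column
mapping = string_index(colors)
encoded = [one_hot_drop_last(mapping[c], len(mapping)) for c in colors]
```

Dropping the last category keeps the encoded columns from summing to one, which would otherwise be collinear with the intercept in a linear model.
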
  1. Transform the dataset to a DataFrame

transformed= transData(df)
transformed.show(5)
+-----------------+-----+
|         features|label|
+-----------------+-----+
|[230.1,37.8,69.2]| 22.1|
| [44.5,39.3,45.1]| 10.4|
| [17.2,45.9,69.3]|  9.3|
|[151.5,41.3,58.5]| 18.5|
|[180.8,10.8,58.4]| 12.9|
+-----------------+-----+
only showing top 5 rows

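Note

You will find that all of the supervised machine learning algorithms in Spark are based on the features and label (unsupervised machine learning algorithms in Spark are based on the features only). That is to say, once you have the features and label ready in the pipeline architecture, you can try any of these algorithms.

  1. Deal with categorical variables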

from pyspark.ml import Pipeline
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator

# Automatically identify categorical features, and index them.
# We specify maxCategories so features with > 4 distinct values are treated as continuous.

featureIndexer = VectorIndexer(inputCol="features", \
                               outputCol="indexedFeatures",\
                               maxCategories=4).fit(transformed)

data = featureIndexer.transform(transformed)

Now you can check your dataset with

data.show(5,True)

You will get

+-----------------+-----+-----------------+
|         features|label|  indexedFeatures|
+-----------------+-----+-----------------+
|[230.1,37.8,69.2]| 22.1|[230.1,37.8,69.2]|
| [44.5,39.3,45.1]| 10.4| [44.5,39.3,45.1]|
| [17.2,45.9,69.3]|  9.3| [17.2,45.9,69.3]|
|[151.5,41.3,58.5]| 18.5|[151.5,41.3,58.5]|
|[180.8,10.8,58.4]| 12.9|[180.8,10.8,58.4]|
+-----------------+-----+-----------------+
only showing top 5 rows
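
The indexedFeatures column is identical to features here because no feature is treated as categorical. VectorIndexer declares a feature categorical only when its number of distinct values is at most maxCategories. The pure-Python sketch below illustrates that rule on the five rows shown; the function name `categorical_feature_indices` is an assumption for illustration and is not part of Spark.

```python
def categorical_feature_indices(rows, max_categories):
    """Return indices of features with at most max_categories distinct values."""
    n_features = len(rows[0])
    categorical = []
    for j in range(n_features):
        distinct = {row[j] for row in rows}
        if len(distinct) <= max_categories:
            categorical.append(j)
    return categorical


# The five feature rows shown in the table (TV, Radio, Newspaper).
rows = [
    [230.1, 37.8, 69.2],
    [44.5, 39.3, 45.1],
    [17.2, 45.9, 69.3],
    [151.5, 41.3, 58.5],
    [180.8, 10.8, 58.4],
]
categorical = categorical_feature_indices(rows, max_categories=4)

# Hypothetical rows in which the first feature only takes the values 0 and 1.
mixed = categorical_feature_indices([[0.0, 1.5], [1.0, 2.5], [0.0, 3.5]], max_categories=2)
```

Every advertising feature already has five distinct values in these rows, so none falls under the threshold of 4 and all are left as continuous.
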
  1. Split the data into training and test sets (40% held out for testing)

# Split the data into training and test sets (40% held out for testing)
(trainingData, testData) = transformed.randomSplit([0.6, 0.4])

You can check your training and test data as follows (in my opinion, it is always good to keep track of your data during the prototype phase):

trainingData.show(5)
testData.show(5)

Then you will get

+---------------+-----+---------------+
|       features|label|indexedFeatures|
+---------------+-----+---------------+
| [4.1,11.6,5.7]|  3.2| [4.1,11.6,5.7]|
| [5.4,29.9,9.4]|  5.3| [5.4,29.9,9.4]|
|[7.3,28.1,41.4]|  5.5|[7.3,28.1,41.4]|
|[7.8,38.9,50.6]|  6.6|[7.8,38.9,50.6]|
|  [8.6,2.1,1.0]|  4.8|  [8.6,2.1,1.0]|
+---------------+-----+---------------+
only showing top 5 rows

+----------------+-----+----------------+
|        features|label| indexedFeatures|
+----------------+-----+----------------+
|  [0.7,39.6,8.7]|  1.6|  [0.7,39.6,8.7]|
|  [8.4,27.2,2.1]|  5.7|  [8.4,27.2,2.1]|
|[11.7,36.9,45.2]|  7.3|[11.7,36.9,45.2]|
|[13.2,15.9,49.6]|  5.6|[13.2,15.9,49.6]|
|[16.9,43.7,89.4]|  8.7|[16.9,43.7,89.4]|
+----------------+-----+----------------+
only showing top 5 rows
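
A weighted random split assigns each row to a split at random, so the 60/40 proportions hold only approximately and the rows differ from run to run unless a seed is fixed (in Spark, `randomSplit` accepts an optional `seed` argument). The pure-Python sketch below illustrates the idea; the function name `random_split` and the seed value are assumptions for illustration, and this is not Spark's sampling code.

```python
import random


def random_split(rows, weights, seed=None):
    """Assign each row independently to a split with probability proportional to weights."""
    rng = random.Random(seed)
    total = float(sum(weights))
    # Cumulative boundaries in [0, 1], e.g. [0.6, 1.0] for weights [0.6, 0.4].
    bounds = []
    acc = 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        for k, b in enumerate(bounds):
            # The last split also catches any floating-point rounding in the bounds.
            if u < b or k == len(bounds) - 1:
                splits[k].append(row)
                break
    return splits


rows = list(range(200))                      # stand-in for the 200 data rows
train, test = random_split(rows, [0.6, 0.4], seed=42)
```

Fixing the seed makes the split reproducible, which helps when comparing models across runs.
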
  1. Fit an Ordinary Least Squares regression model

For more details about the parameters, please visit the Linear Regression API.

# Import LinearRegression class
from pyspark.ml.regression import LinearRegression

# Define LinearRegression algorithm
lr = LinearRegression()
  1. Pipeline architecture

# Chain the indexer and the linear regression model in a Pipeline
pipeline = Pipeline(stages=[featureIndexer, lr])

model = pipeline.fit(trainingData)
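  1. Summary of the model

Spark has poor summary functions for data and models. I wrote a summary function for linear regression in PySpark whose output is similar in format to R's.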

def modelsummary(model):
    import numpy as np
    print ("Note: the last rows are the information for Intercept")
    print ("##","-------------------------------------------------")
    print ("##","  Estimate   |   Std.Error | t Values  |  P-value")
    coef = np.append(list(model.coefficients),model.intercept)
    Summary=model.summary

    for i in range(len(Summary.pValues)):
        print ("##",'{:10.6f}'.format(coef[i]),\
        '{:10.6f}'.format(Summary.coefficientStandardErrors[i]),\
        '{:8.3f}'.format(Summary.tValues[i]),\
        '{:10.6f}'.format(Summary.pValues[i]))

    print ("##",'---')
    print ("##","Mean squared error: % .6f" \
           % Summary.meanSquaredError, ", RMSE: % .6f" \
           % Summary.rootMeanSquaredError )
    print ("##","Multiple R-squared: %f" % Summary.r2, ", \
            Total iterations: %i"% Summary.totalIterations)
modelsummary(model.stages[-1])

You will get the following summary results:

Note: the last rows are the information for Intercept
('##', '-------------------------------------------------')
('##', '  Estimate   |   Std.Error | t Values  |  P-value')
('##', '  0.044186', '  0.001663', '  26.573', '  0.000000')
('##', '  0.206311', '  0.010846', '  19.022', '  0.000000')
('##', '  0.001963', '  0.007467', '   0.263', '  0.793113')
('##', '  2.596154', '  0.379550', '   6.840', '  0.000000')
('##', '---')
('##', 'Mean squared error:  2.588230', ', RMSE:  1.608798')
('##', 'Multiple R-squared: 0.911869', ',             Total iterations: 1')
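
The columns of this summary are linked: each t value is the estimate divided by its standard error. The short check below recomputes the t values from the numbers printed in the summary; the variable names are illustrative, and small differences come from the rounding of the printed figures.

```python
# (estimate, standard error, reported t value) copied from the summary output;
# the last row is the intercept.
rows = [
    (0.044186, 0.001663, 26.573),
    (0.206311, 0.010846, 19.022),
    (0.001963, 0.007467, 0.263),
    (2.596154, 0.379550, 6.840),
]

# t value = estimate / standard error.
t_values = [estimate / std_err for estimate, std_err, _ in rows]

for (estimate, std_err, reported), t in zip(rows, t_values):
    print("{:10.6f} {:10.6f} {:8.3f} (reported {:8.3f})".format(
        estimate, std_err, t, reported))
```

The third coefficient has a t value near zero and a p-value of about 0.79, so on this training split its estimate is not distinguishable from zero.
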
  1. Make predictions

# Make predictions.
predictions = model.transform(testData)
# Select example rows to display.
predictions.select("features","label","prediction").show(5)
+----------------+-----+------------------+
|        features|label|        prediction|
+----------------+-----+------------------+
|  [0.7,39.6,8.7]|  1.6| 10.81405928637388|
|  [8.4,27.2,2.1]|  5.7| 8.583086404079918|
|[11.7,36.9,45.2]|  7.3|10.814712818232422|
|[13.2,15.9,49.6]|  5.6| 6.557106943899219|
|[16.9,43.7,89.4]|  8.7|12.534151375058645|
+----------------+-----+------------------+
only showing top 5 rows
  1. Evaluation

from pyspark.ml.evaluation import RegressionEvaluator
# Select (prediction, true label) and compute test error
evaluator = RegressionEvaluator(labelCol="label",
                                predictionCol="prediction",
                                metricName="rmse")

rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

The final Root Mean Squared Error (RMSE) is as follows:

Root Mean Squared Error (RMSE) on test data = 1.63114
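
For reference, RMSE is the square root of the mean squared difference between labels and predictions. The pure-Python sketch below computes it by hand; the function name `rmse` and the toy values are assumptions for illustration, so the result is unrelated to the figure reported for the test data.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error: sqrt(mean((y_true - y_pred)^2))."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)


# Toy values (illustrative only, not the test set used in the demo).
y_true = [3.0, 5.0, 7.0, 9.0]
y_pred = [2.0, 5.0, 8.0, 11.0]
value = rmse(y_true, y_pred)
```

Because the errors are squared before averaging, a single large miss (the last pair here) dominates the result, and the value is expressed in the same units as the label.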

You can also check the $R^2$ value for the test data:

y_true = predictions.select("label").toPandas()
y_pred = predictions.select("prediction").toPandas()

import sklearn.metrics
r2_score = sklearn.metrics.r2_score(y_true, y_pred)
print('r2_score: {0}'.format(r2_score))

Then you will get

r2_score: 0.854486655585

Warning

You should know that most software uses a different formula to calculate the $R^2$ value when no intercept is included in the model. You can get more information from the discussion at StackExchange.
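
The difference lies in the total sum of squares. The usual definition, which is also what `sklearn.metrics.r2_score` computes, measures it around the mean of the labels, while the variant commonly reported for models without an intercept measures it around zero. The sketch below compares the two on toy values; the function names and data are assumptions for illustration.

```python
def r2_centered(y_true, y_pred):
    """R^2 = 1 - SS_res / SS_tot, with SS_tot measured around the mean of y_true."""
    mean_y = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_y) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def r2_uncentered(y_true, y_pred):
    """Variant often reported for models without an intercept: SS_tot = sum(y^2)."""
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum(t ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Toy values (illustrative only, not the test set used in the demo).
y_true = [3.0, 5.0, 7.0, 9.0]
y_pred = [2.0, 5.0, 8.0, 11.0]
centered = r2_centered(y_true, y_pred)
uncentered = r2_uncentered(y_true, y_pred)
```

For the same predictions the uncentered value is larger whenever the labels have a nonzero mean, so $R^2$ values from models with and without an intercept should not be compared directly.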

8.2. Generalized Linear Regression

8.2.1. Introduction

8.2.2. How to solve it?

8.2.3. Demo

  1. Set up the Spark context and SparkSession

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark regression example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()
  1. Load the dataset

df = spark.read.format('com.databricks.spark.csv').\
                       options(header='true', \
                       inferschema='true').\
            load("../data/Advertising.csv",header=True);

Check the dataset

df.show(5,True)
df.printSchema()

Then you will get

+-----+-----+---------+-----+
|   TV|Radio|Newspaper|Sales|
+-----+-----+---------+-----+
|230.1| 37.8|     69.2| 22.1|
| 44.5| 39.3|     45.1| 10.4|
| 17.2| 45.9|     69.3|  9.3|
|151.5| 41.3|     58.5| 18.5|
|180.8| 10.8|     58.4| 12.9|
+-----+-----+---------+-----+
only showing top 5 rows

root
 |-- TV: double (nullable = true)
 |-- Radio: double (nullable = true)
 |-- Newspaper: double (nullable = true)
 |-- Sales: double (nullable = true)

You can also get summary statistics from the DataFrame (unfortunately, this only works for numerical columns).

df.describe().show()

Then you will get

+-------+-----------------+------------------+------------------+------------------+
|summary|               TV|             Radio|         Newspaper|             Sales|
+-------+-----------------+------------------+------------------+------------------+
|  count|              200|               200|               200|               200|
|   mean|         147.0425|23.264000000000024|30.553999999999995|14.022500000000003|
| stddev|85.85423631490805|14.846809176168728| 21.77862083852283| 5.217456565710477|
|    min|              0.7|               0.0|               0.3|               1.6|
|    max|            296.4|              49.6|             114.0|              27.0|
+-------+-----------------+------------------+------------------+------------------+
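  1. Convert the data to dense vectors (features and label)

Note

You are strongly encouraged to try my get_dummy function for dealing with the categorical data in complex datasets.

Supervised learning version: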

def get_dummy(df,indexCol,categoricalCols,continuousCols,labelCol):

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    from pyspark.sql.functions import col

    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    data = data.withColumn('label',col(labelCol))

    return data.select(indexCol,'features','label')

Unsupervised learning version:

def get_dummy(df,indexCol,categoricalCols,continuousCols):
    '''
    Get dummy variables and concat with continuous variables for unsupervised learning.
    :param df: the dataframe
    :param categoricalCols: the name list of the categorical data
    :param continuousCols:  the name list of the numerical data
    :return k: feature matrix

    :author: Wenqiang Feng
    :email:  von198@gmail.com
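    '''
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler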

    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    return data.select(indexCol,'features')
from pyspark.sql import Row
from pyspark.ml.linalg import Vectors

# I provide two ways to build the features and labels

# method 1 (good for small feature):
#def transData(row):
#    return Row(label=row["Sales"],
#               features=Vectors.dense([row["TV"],
#                                       row["Radio"],
#                                       row["Newspaper"]]))

# Method 2 (good for large features):
def transData(data):
    # The last column is the label; all preceding columns form the feature vector
    return data.rdd.map(lambda r: [Vectors.dense(r[:-1]),r[-1]]).toDF(['features','label'])
transformed= transData(df)
transformed.show(5)
+-----------------+-----+
|         features|label|
+-----------------+-----+
|[230.1,37.8,69.2]| 22.1|
| [44.5,39.3,45.1]| 10.4|
| [17.2,45.9,69.3]|  9.3|
|[151.5,41.3,58.5]| 18.5|
|[180.8,10.8,58.4]| 12.9|
+-----------------+-----+
only showing top 5 rows

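Note

You will find that all of the machine learning algorithms in Spark are based on the features and label. That is to say, once you have the features and label ready, you can try any of these algorithms.

  1. Convert the data to dense vectors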

# convert the data to dense vector
def transData(data):
    return data.rdd.map(lambda r: [r[-1], Vectors.dense(r[:-1])]).\
           toDF(['label','features'])

from pyspark.sql import Row
from pyspark.ml.linalg import Vectors

data= transData(df)
data.show()
  1. Deal with categorical variables

from pyspark.ml import Pipeline
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator

# Automatically identify categorical features, and index them.
# We specify maxCategories so features with > 4
# distinct values are treated as continuous.

featureIndexer = VectorIndexer(inputCol="features", \
                               outputCol="indexedFeatures",\
                               maxCategories=4).fit(transformed)

data = featureIndexer.transform(transformed)

When you check your data at this point, you will get

+-----------------+-----+-----------------+
|         features|label|  indexedFeatures|
+-----------------+-----+-----------------+
|[230.1,37.8,69.2]| 22.1|[230.1,37.8,69.2]|
| [44.5,39.3,45.1]| 10.4| [44.5,39.3,45.1]|
| [17.2,45.9,69.3]|  9.3| [17.2,45.9,69.3]|
|[151.5,41.3,58.5]| 18.5|[151.5,41.3,58.5]|
|[180.8,10.8,58.4]| 12.9|[180.8,10.8,58.4]|
+-----------------+-----+-----------------+
only showing top 5 rows
  1. Split the data into training and test sets (40% held out for testing)

# Split the data into training and test sets (40% held out for testing)
(trainingData, testData) = transformed.randomSplit([0.6, 0.4])

You can check your training and test data as follows (in my opinion, it is always good to keep track of your data during the prototype phase):

trainingData.show(5)
testData.show(5)

Then you will get

+----------------+-----+----------------+
|        features|label| indexedFeatures|
+----------------+-----+----------------+
|  [5.4,29.9,9.4]|  5.3|  [5.4,29.9,9.4]|
| [7.8,38.9,50.6]|  6.6| [7.8,38.9,50.6]|
|  [8.4,27.2,2.1]|  5.7|  [8.4,27.2,2.1]|
| [8.7,48.9,75.0]|  7.2| [8.7,48.9,75.0]|
|[11.7,36.9,45.2]|  7.3|[11.7,36.9,45.2]|
+----------------+-----+----------------+
only showing top 5 rows

+---------------+-----+---------------+
|       features|label|indexedFeatures|
+---------------+-----+---------------+
| [0.7,39.6,8.7]|  1.6| [0.7,39.6,8.7]|
| [4.1,11.6,5.7]|  3.2| [4.1,11.6,5.7]|
|[7.3,28.1,41.4]|  5.5|[7.3,28.1,41.4]|
|  [8.6,2.1,1.0]|  4.8|  [8.6,2.1,1.0]|
|[17.2,4.1,31.6]|  5.9|[17.2,4.1,31.6]|
+---------------+-----+---------------+
only showing top 5 rows
  1. Fit a Generalized Linear Regression model

# Import GeneralizedLinearRegression class
from pyspark.ml.regression import GeneralizedLinearRegression

# Define GeneralizedLinearRegression algorithm
glr = GeneralizedLinearRegression(family="gaussian", link="identity",\
                                  maxIter=10, regParam=0.3)
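
The family and link arguments control how the linear predictor is turned into a predicted mean. With the identity link used here the predicted mean is the linear predictor itself, which is the same form as the linear regression hypothesis. The pure-Python sketch below contrasts it with a log link; the function name `glm_predict` and the coefficients are hypothetical, and this is not Spark's implementation.

```python
import math

# Inverse link functions: they map the linear predictor eta to the mean mu.
INVERSE_LINKS = {
    "identity": lambda eta: eta,          # mu = eta (the link chosen in the demo)
    "log": lambda eta: math.exp(eta),     # mu = exp(eta), shown only for contrast
}


def glm_predict(beta, x, link="identity"):
    """Predict the mean response: mu = g^{-1}(beta[0] + sum_j beta[j] * x[j-1])."""
    eta = beta[0] + sum(b * xj for b, xj in zip(beta[1:], x))
    return INVERSE_LINKS[link](eta)


beta = [1.0, 2.0]          # hypothetical coefficients, intercept first
mu_identity = glm_predict(beta, [0.5], link="identity")
mu_log = glm_predict(beta, [0.5], link="log")
```

The other two arguments in the demo are solver settings: maxIter caps the number of iterations and regParam sets the strength of L2 regularization.
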
  1. Pipeline architecture

# Chain the indexer and the generalized linear regression model in a Pipeline
pipeline = Pipeline(stages=[featureIndexer, glr])

model = pipeline.fit(trainingData)
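  1. Summary of the model

Spark has poor summary functions for data and models. I wrote a summary function for linear regression in PySpark whose output is similar in format to R's.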

def modelsummary(model):
    import numpy as np
    print ("Note: the last rows are the information for Intercept")
    print ("##","-------------------------------------------------")
    print ("##","  Estimate   |   Std.Error | t Values  |  P-value")
    coef = np.append(list(model.coefficients),model.intercept)
    Summary=model.summary

    for i in range(len(Summary.pValues)):
        print ("##",'{:10.6f}'.format(coef[i]),\
        '{:10.6f}'.format(Summary.coefficientStandardErrors[i]),\
        '{:8.3f}'.format(Summary.tValues[i]),\
        '{:10.6f}'.format(Summary.pValues[i]))

    print ("##",'---')
#     print ("##","Mean squared error: % .6f" \
#            % Summary.meanSquaredError, ", RMSE: % .6f" \
#            % Summary.rootMeanSquaredError )
#     print ("##","Multiple R-squared: %f" % Summary.r2, ", \
#             Total iterations: %i"% Summary.totalIterations)
modelsummary(model.stages[-1])

You will get the following summary results:

Note: the last rows are the information for Intercept
('##', '-------------------------------------------------')
('##', '  Estimate   |   Std.Error | t Values  |  P-value')
('##', '  0.042857', '  0.001668', '  25.692', '  0.000000')
('##', '  0.199922', '  0.009881', '  20.232', '  0.000000')
('##', ' -0.001957', '  0.006917', '  -0.283', '  0.777757')
('##', '  3.007515', '  0.406389', '   7.401', '  0.000000')
('##', '---')
  1. Make predictions

# Make predictions.
predictions = model.transform(testData)
# Select example rows to display.
predictions.select("features","label","prediction").show(5)
+---------------+-----+------------------+
|       features|label|        prediction|
+---------------+-----+------------------+
| [0.7,39.6,8.7]|  1.6|10.937383732327625|
| [4.1,11.6,5.7]|  3.2| 5.491166258750164|
|[7.3,28.1,41.4]|  5.5|   8.8571603947873|
|  [8.6,2.1,1.0]|  4.8| 3.793966281660073|
|[17.2,4.1,31.6]|  5.9| 4.502507124763654|
+---------------+-----+------------------+
only showing top 5 rows
  1. Evaluation

from pyspark.ml.evaluation import RegressionEvaluator
# Select (prediction, true label) and compute test error
evaluator = RegressionEvaluator(labelCol="label",
                                predictionCol="prediction",
                                metricName="rmse")

rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

The final Root Mean Squared Error (RMSE) is as follows:

Root Mean Squared Error (RMSE) on test data = 1.89857
y_true = predictions.select("label").toPandas()
y_pred = predictions.select("prediction").toPandas()

import sklearn.metrics
r2_score = sklearn.metrics.r2_score(y_true, y_pred)
print('r2_score: {0}'.format(r2_score))

Then you will get the R^2 value:

r2_score: 0.87707391843
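For reference, both metrics can be written out in a few lines of plain Python. This is a minimal sketch with made-up toy values, not the test data used above; the function names rmse and r2 are my own.

```python
import math

def rmse(y_true, y_pred):
    """Root mean squared error: sqrt of the mean squared residual."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)

def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

# Toy values, invented for illustration only.
toy_true = [3.0, 5.0, 7.0, 9.0]
toy_pred = [2.5, 5.5, 7.0, 8.0]

rmse_value = rmse(toy_true, toy_pred)
r2_value = r2(toy_true, toy_pred)
print("rmse = {:.4f}, r2 = {:.4f}".format(rmse_value, r2_value))
```

RMSE is in the units of the label, while R^2 is unitless, so the two numbers answer slightly different questions about the same predictions.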

8.3. Decision Tree Regression

8.3.1. Introduction

8.3.2. How to solve it?

8.3.3. Demo

  1. Set up SparkContext and SparkSession

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark regression example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()
  1. Load the dataset

df = spark.read.format('com.databricks.spark.csv').\
                       options(header='true', \
                       inferschema='true').\
            load("../data/Advertising.csv",header=True);

Check the dataset:

df.show(5,True)
df.printSchema()

Then you will get

+-----+-----+---------+-----+
|   TV|Radio|Newspaper|Sales|
+-----+-----+---------+-----+
|230.1| 37.8|     69.2| 22.1|
| 44.5| 39.3|     45.1| 10.4|
| 17.2| 45.9|     69.3|  9.3|
|151.5| 41.3|     58.5| 18.5|
|180.8| 10.8|     58.4| 12.9|
+-----+-----+---------+-----+
only showing top 5 rows

root
 |-- TV: double (nullable = true)
 |-- Radio: double (nullable = true)
 |-- Newspaper: double (nullable = true)
 |-- Sales: double (nullable = true)

You can also get summary statistics from the DataFrame (unfortunately, this only works for numerical columns).

df.describe().show()

Then you will get

+-------+-----------------+------------------+------------------+------------------+
|summary|               TV|             Radio|         Newspaper|             Sales|
+-------+-----------------+------------------+------------------+------------------+
|  count|              200|               200|               200|               200|
|   mean|         147.0425|23.264000000000024|30.553999999999995|14.022500000000003|
| stddev|85.85423631490805|14.846809176168728| 21.77862083852283| 5.217456565710477|
|    min|              0.7|               0.0|               0.3|               1.6|
|    max|            296.4|              49.6|             114.0|              27.0|
+-------+-----------------+------------------+------------------+------------------+
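  1. Convert the data to dense vectors (features and label)

Note

You are strongly encouraged to try my get_dummy function for dealing with the categorical data in complex datasets.

Supervised learning version: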

def get_dummy(df,indexCol,categoricalCols,continuousCols,labelCol):

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    from pyspark.sql.functions import col

    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    data = data.withColumn('label',col(labelCol))

    return data.select(indexCol,'features','label')

Unsupervised learning version:

def get_dummy(df,indexCol,categoricalCols,continuousCols):
    '''
    Get dummy variables and concat with continuous variables for unsupervised learning.
    :param df: the dataframe
    :param categoricalCols: the name list of the categorical data
    :param continuousCols:  the name list of the numerical data
    :return k: feature matrix

    :author: Wenqiang Feng
    :email:  von198@gmail.com
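    '''
    # imports needed when this version is used on its own
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
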
    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    return data.select(indexCol,'features')
from pyspark.sql import Row
from pyspark.ml.linalg import Vectors

# I provide two ways to build the features and labels

# method 1 (good for small feature):
#def transData(row):
#    return Row(label=row["Sales"],
#               features=Vectors.dense([row["TV"],
#                                       row["Radio"],
#                                       row["Newspaper"]]))

# Method 2 (good for large features):
def transData(data):
    return data.rdd.map(lambda r: [Vectors.dense(r[:-1]),r[-1]]).toDF(['features','label'])
transformed= transData(df)
transformed.show(5)
+-----------------+-----+
|         features|label|
+-----------------+-----+
|[230.1,37.8,69.2]| 22.1|
| [44.5,39.3,45.1]| 10.4|
| [17.2,45.9,69.3]|  9.3|
|[151.5,41.3,58.5]| 18.5|
|[180.8,10.8,58.4]| 12.9|
+-----------------+-----+
only showing top 5 rows

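Note

You will find that all of the machine learning algorithms in Spark are based on the features and label columns. That is to say, once the features and label are ready, you can try any of them.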
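Method 2 above relies on the label being the last column of each row. The sketch below uses a plain tuple in place of a pyspark Row (a Row is a tuple subclass, so slicing behaves the same way). The helper split_row is hypothetical and only illustrates what to do when the label is not the last column.

```python
# A plain tuple standing in for one row of the Advertising data:
# (TV, Radio, Newspaper, Sales).
row = (230.1, 37.8, 69.2, 22.1)

features = list(row[:-1])  # every column except the last
label = row[-1]            # the last column

def split_row(row, columns, label_col):
    """Hypothetical helper: pick the label by name instead of by position."""
    idx = columns.index(label_col)
    feats = [v for i, v in enumerate(row) if i != idx]
    return feats, row[idx]

# Same data, but with the label stored first.
reordered = (22.1, 230.1, 37.8, 69.2)
feats2, label2 = split_row(reordered, ["Sales", "TV", "Radio", "Newspaper"], "Sales")
print(features, label)
print(feats2, label2)
```

Selecting by name is the safer choice when the column order of the input file is not under your control.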

  1. Convert the data to dense vectors

# convert the data to dense vector
def transData(data):
    return data.rdd.map(lambda r: [r[-1], Vectors.dense(r[:-1])]).\
           toDF(['label','features'])

transformed = transData(df)
transformed.show(5)
  1. Deal with the categorical variables

from pyspark.ml import Pipeline
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator

# Automatically identify categorical features, and index them.
# We specify maxCategories so features with > 4
# distinct values are treated as continuous.

featureIndexer = VectorIndexer(inputCol="features", \
                               outputCol="indexedFeatures",\
                               maxCategories=4).fit(transformed)

data = featureIndexer.transform(transformed)

When you check your data at this point, you will get

+-----------------+-----+-----------------+
|         features|label|  indexedFeatures|
+-----------------+-----+-----------------+
|[230.1,37.8,69.2]| 22.1|[230.1,37.8,69.2]|
| [44.5,39.3,45.1]| 10.4| [44.5,39.3,45.1]|
| [17.2,45.9,69.3]|  9.3| [17.2,45.9,69.3]|
|[151.5,41.3,58.5]| 18.5|[151.5,41.3,58.5]|
|[180.8,10.8,58.4]| 12.9|[180.8,10.8,58.4]|
+-----------------+-----+-----------------+
only showing top 5 rows
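In the output above, indexedFeatures is identical to features because every column has more than maxCategories=4 distinct values, so all three are treated as continuous. The sketch below mimics only that decision rule in plain Python. It is not VectorIndexer itself, and the function name is my own.

```python
def categorical_feature_indices(rows, max_categories):
    """Return the indices of columns with at most max_categories distinct values."""
    n_features = len(rows[0])
    result = []
    for j in range(n_features):
        distinct = {row[j] for row in rows}
        if len(distinct) <= max_categories:
            result.append(j)
    return result

# The five rows shown above.
rows = [
    [230.1, 37.8, 69.2],
    [44.5, 39.3, 45.1],
    [17.2, 45.9, 69.3],
    [151.5, 41.3, 58.5],
    [180.8, 10.8, 58.4],
]
continuous_only = categorical_feature_indices(rows, max_categories=4)

# Add a made-up 0/1 flag column to see a categorical feature detected.
rows_with_flag = [r + [float(i % 2)] for i, r in enumerate(rows)]
with_flag = categorical_feature_indices(rows_with_flag, max_categories=4)
print(continuous_only, with_flag)
```

With the Advertising data no feature qualifies as categorical, which is why the indexing step leaves the vectors unchanged.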
  1. Split the data into training and test sets (40% held out for testing)

# Split the data into training and test sets (40% held out for testing)
(trainingData, testData) = transformed.randomSplit([0.6, 0.4])

You can check your training and test data as follows (in my opinion, it is always good to keep track of your data during the prototype phase):

trainingData.show(5)
testData.show(5)

Then you will get

+---------------+-----+---------------+
|       features|label|indexedFeatures|
+---------------+-----+---------------+
| [4.1,11.6,5.7]|  3.2| [4.1,11.6,5.7]|
|[7.3,28.1,41.4]|  5.5|[7.3,28.1,41.4]|
| [8.4,27.2,2.1]|  5.7| [8.4,27.2,2.1]|
|  [8.6,2.1,1.0]|  4.8|  [8.6,2.1,1.0]|
|[8.7,48.9,75.0]|  7.2|[8.7,48.9,75.0]|
+---------------+-----+---------------+
only showing top 5 rows

+----------------+-----+----------------+
|        features|label| indexedFeatures|
+----------------+-----+----------------+
|  [0.7,39.6,8.7]|  1.6|  [0.7,39.6,8.7]|
|  [5.4,29.9,9.4]|  5.3|  [5.4,29.9,9.4]|
| [7.8,38.9,50.6]|  6.6| [7.8,38.9,50.6]|
|[17.2,45.9,69.3]|  9.3|[17.2,45.9,69.3]|
|[18.7,12.1,23.4]|  6.7|[18.7,12.1,23.4]|
+----------------+-----+----------------+
only showing top 5 rows
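A weighted random split of this kind is usually done by sampling each row independently, so the 60/40 proportions hold only approximately and the result changes from run to run unless a seed is fixed. The plain-Python sketch below illustrates that idea under this assumption. It is not Spark's implementation, and random_split is a name I made up.

```python
import random

def random_split(rows, weights, seed):
    """Assign each row to one part by drawing a uniform number per row."""
    total = float(sum(weights))
    bounds = []
    acc = 0.0
    for w in weights:
        acc += w / total          # normalise weights into cumulative bounds
        bounds.append(acc)
    bounds[-1] = 1.0              # guard against floating-point rounding
    rng = random.Random(seed)
    parts = [[] for _ in weights]
    for row in rows:
        u = rng.random()          # u is in [0, 1)
        for k, bound in enumerate(bounds):
            if u < bound:
                parts[k].append(row)
                break
    return parts

rows = list(range(200))           # stand-in for the 200 data rows
train, test = random_split(rows, [0.6, 0.4], seed=42)
print(len(train), len(test))
```

In PySpark, randomSplit also accepts a seed argument, which is worth setting when you want the same training and test sets on every run.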
  1. Fit the Decision Tree Regression model

from pyspark.ml.regression import DecisionTreeRegressor

# Train a DecisionTree model.
dt = DecisionTreeRegressor(featuresCol="indexedFeatures")
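To give some intuition for what a regression tree does at each node, the sketch below searches one feature for the threshold that minimizes the summed squared error of the two resulting groups. It is a simplified illustration under my own assumptions (exhaustive search over midpoints, a single feature), not Spark's algorithm, which works on binned candidate thresholds. The data are the five TV and Sales values shown earlier.

```python
def sse(values):
    """Sum of squared deviations from the mean (0 for an empty list)."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values)

def best_split(xs, ys):
    """Return (threshold, cost) of the split with the smallest total SSE."""
    pairs = sorted(zip(xs, ys))
    best = None
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold fits between equal feature values
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2.0
        left = [y for _, y in pairs[:i]]
        right = [y for _, y in pairs[i:]]
        cost = sse(left) + sse(right)
        if best is None or cost < best[1]:
            best = (threshold, cost)
    return best

tv = [230.1, 44.5, 17.2, 151.5, 180.8]
sales = [22.1, 10.4, 9.3, 18.5, 12.9]
threshold, cost = best_split(tv, sales)
print("split TV at {:.1f}, total SSE {:.2f}".format(threshold, cost))
```

A full tree repeats this search over all features at every node and predicts the mean label of the leaf that a row falls into.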
  1. Pipeline architecture

# Chain indexer and tree in a Pipeline
pipeline = Pipeline(stages=[featureIndexer, dt])

model = pipeline.fit(trainingData)
  1. Make predictions

# Make predictions.
predictions = model.transform(testData)
# Select example rows to display.
predictions.select("prediction","label","features").show(5)
+----------+-----+----------------+
|prediction|label|        features|
+----------+-----+----------------+
|       7.2|  1.6|  [0.7,39.6,8.7]|
|       7.3|  5.3|  [5.4,29.9,9.4]|
|       7.2|  6.6| [7.8,38.9,50.6]|
|      8.64|  9.3|[17.2,45.9,69.3]|
|      6.45|  6.7|[18.7,12.1,23.4]|
+----------+-----+----------------+
only showing top 5 rows
  1. Evaluation

from pyspark.ml.evaluation import RegressionEvaluator
# Select (prediction, true label) and compute test error
evaluator = RegressionEvaluator(labelCol="label",
                                predictionCol="prediction",
                                metricName="rmse")

rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

The final Root Mean Squared Error (RMSE) is as follows:

Root Mean Squared Error (RMSE) on test data = 1.50999
y_true = predictions.select("label").toPandas()
y_pred = predictions.select("prediction").toPandas()

import sklearn.metrics
r2_score = sklearn.metrics.r2_score(y_true, y_pred)
print('r2_score: {0}'.format(r2_score))

Then you will get the R^2 value:

r2_score: 0.911024318967

You can also check the importance of the features:

model.stages[1].featureImportances

Then you will get the weight of each feature:

SparseVector(3, {0: 0.6811, 1: 0.3187, 2: 0.0002})
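The indices in this vector follow the order of the columns in the features vector, which here is TV, Radio, Newspaper. The sketch below pairs the printed weights with those names and ranks them. The numbers are copied by hand from the output above, and the variable names are my own.

```python
# Feature order used when the dense vectors were built.
feature_names = ["TV", "Radio", "Newspaper"]

# Weights copied from the SparseVector printed above.
importances = {0: 0.6811, 1: 0.3187, 2: 0.0002}

ranked = sorted(
    ((feature_names[i], w) for i, w in importances.items()),
    key=lambda pair: pair[1],
    reverse=True,
)
for name, weight in ranked:
    print("{:<10s} {:.4f}".format(name, weight))
```

The weights sum to one, so each can be read as a share of the total importance assigned by this particular tree.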

8.4. Random Forest Regression

8.4.1. Introduction

8.4.2. How to solve it?

8.4.3. Demo

  1. Set up SparkContext and SparkSession

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark RandomForest Regression example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()
  1. Load the dataset

df = spark.read.format('com.databricks.spark.csv').\
                               options(header='true', \
                               inferschema='true').\
                               load("../data/Advertising.csv",header=True);

df.show(5,True)
df.printSchema()
+-----+-----+---------+-----+
|   TV|Radio|Newspaper|Sales|
+-----+-----+---------+-----+
|230.1| 37.8|     69.2| 22.1|
| 44.5| 39.3|     45.1| 10.4|
| 17.2| 45.9|     69.3|  9.3|
|151.5| 41.3|     58.5| 18.5|
|180.8| 10.8|     58.4| 12.9|
+-----+-----+---------+-----+
only showing top 5 rows

root
 |-- TV: double (nullable = true)
 |-- Radio: double (nullable = true)
 |-- Newspaper: double (nullable = true)
 |-- Sales: double (nullable = true)
df.describe().show()

+-------+-----------------+------------------+------------------+------------------+
|summary|               TV|             Radio|         Newspaper|             Sales|
+-------+-----------------+------------------+------------------+------------------+
|  count|              200|               200|               200|               200|
|   mean|         147.0425|23.264000000000024|30.553999999999995|14.022500000000003|
| stddev|85.85423631490805|14.846809176168728| 21.77862083852283| 5.217456565710477|
|    min|              0.7|               0.0|               0.3|               1.6|
|    max|            296.4|              49.6|             114.0|              27.0|
+-------+-----------------+------------------+------------------+------------------+
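  1. Convert the data to dense vectors (features and label)

Note

You are strongly encouraged to try my get_dummy function for dealing with the categorical data in complex datasets.

Supervised learning version: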

def get_dummy(df,indexCol,categoricalCols,continuousCols,labelCol):

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    from pyspark.sql.functions import col

    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    data = data.withColumn('label',col(labelCol))

    return data.select(indexCol,'features','label')

Unsupervised learning version:

def get_dummy(df,indexCol,categoricalCols,continuousCols):
    '''
    Get dummy variables and concat with continuous variables for unsupervised learning.
    :param df: the dataframe
    :param categoricalCols: the name list of the categorical data
    :param continuousCols:  the name list of the numerical data
    :return k: feature matrix

    :author: Wenqiang Feng
    :email:  von198@gmail.com
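    '''
    # imports needed when this version is used on its own
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
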
    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    return data.select(indexCol,'features')
from pyspark.sql import Row
from pyspark.ml.linalg import Vectors

# convert the data to dense vector
#def transData(row):
#    return Row(label=row["Sales"],
#               features=Vectors.dense([row["TV"],
#                                       row["Radio"],
#                                       row["Newspaper"]]))
def transData(data):
    return data.rdd.map(lambda r: [Vectors.dense(r[:-1]),r[-1]]).toDF(['features','label'])
  1. Convert the data to dense vectors

transformed= transData(df)
transformed.show(5)
+-----------------+-----+
|         features|label|
+-----------------+-----+
|[230.1,37.8,69.2]| 22.1|
| [44.5,39.3,45.1]| 10.4|
| [17.2,45.9,69.3]|  9.3|
|[151.5,41.3,58.5]| 18.5|
|[180.8,10.8,58.4]| 12.9|
+-----------------+-----+
only showing top 5 rows
  1. Deal with the categorical variables

from pyspark.ml import Pipeline
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator

featureIndexer = VectorIndexer(inputCol="features", \
                               outputCol="indexedFeatures",\
                               maxCategories=4).fit(transformed)

data = featureIndexer.transform(transformed)
data.show(5,True)
+-----------------+-----+-----------------+
|         features|label|  indexedFeatures|
+-----------------+-----+-----------------+
|[230.1,37.8,69.2]| 22.1|[230.1,37.8,69.2]|
| [44.5,39.3,45.1]| 10.4| [44.5,39.3,45.1]|
| [17.2,45.9,69.3]|  9.3| [17.2,45.9,69.3]|
|[151.5,41.3,58.5]| 18.5|[151.5,41.3,58.5]|
|[180.8,10.8,58.4]| 12.9|[180.8,10.8,58.4]|
+-----------------+-----+-----------------+
only showing top 5 rows
  1. Split the data into training and test sets (40% held out for testing)

# Split the data into training and test sets (40% held out for testing)
(trainingData, testData) = data.randomSplit([0.6, 0.4])

trainingData.show(5)
testData.show(5)
+----------------+-----+----------------+
|        features|label| indexedFeatures|
+----------------+-----+----------------+
|  [0.7,39.6,8.7]|  1.6|  [0.7,39.6,8.7]|
|   [8.6,2.1,1.0]|  4.8|   [8.6,2.1,1.0]|
| [8.7,48.9,75.0]|  7.2| [8.7,48.9,75.0]|
|[11.7,36.9,45.2]|  7.3|[11.7,36.9,45.2]|
|[13.2,15.9,49.6]|  5.6|[13.2,15.9,49.6]|
+----------------+-----+----------------+
only showing top 5 rows

+---------------+-----+---------------+
|       features|label|indexedFeatures|
+---------------+-----+---------------+
| [4.1,11.6,5.7]|  3.2| [4.1,11.6,5.7]|
| [5.4,29.9,9.4]|  5.3| [5.4,29.9,9.4]|
|[7.3,28.1,41.4]|  5.5|[7.3,28.1,41.4]|
|[7.8,38.9,50.6]|  6.6|[7.8,38.9,50.6]|
| [8.4,27.2,2.1]|  5.7| [8.4,27.2,2.1]|
+---------------+-----+---------------+
only showing top 5 rows
  1. Fit the Random Forest Regression model

# Import RandomForestRegressor class
from pyspark.ml.regression import RandomForestRegressor

# Define RandomForestRegressor algorithm
rf = RandomForestRegressor() # featuresCol="indexedFeatures",numTrees=2, maxDepth=2, seed=42

Note

If you decide to use the indexedFeatures column, you need to add the parameter featuresCol="indexedFeatures".
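A random forest for regression trains each tree on a resampled copy of the data and averages the trees' predictions. The sketch below shows only those two ingredients in plain Python. To stay short, each "tree" is a deliberately degenerate predictor that returns the mean label of its bootstrap sample; this is my simplification, not how Spark builds its trees. All function names are hypothetical.

```python
import random

def bootstrap_sample(rows, rng):
    """Draw len(rows) rows with replacement."""
    return [rng.choice(rows) for _ in rows]

def fit_mean_predictor(sample):
    """Degenerate 'tree' of depth 0: always predicts the sample's mean label."""
    mean = sum(label for _, label in sample) / len(sample)
    return lambda features: mean

def forest_predict(trees, features):
    """The forest prediction is the average of the individual predictions."""
    return sum(tree(features) for tree in trees) / len(trees)

# (features, label) pairs taken from the first rows of the Advertising data.
rows = [
    ([230.1, 37.8, 69.2], 22.1),
    ([44.5, 39.3, 45.1], 10.4),
    ([17.2, 45.9, 69.3], 9.3),
    ([151.5, 41.3, 58.5], 18.5),
    ([180.8, 10.8, 58.4], 12.9),
]

rng = random.Random(42)
trees = [fit_mean_predictor(bootstrap_sample(rows, rng)) for _ in range(20)]
prediction = forest_predict(trees, [100.0, 20.0, 30.0])
print("forest prediction: {:.3f}".format(prediction))
```

Averaging many trees trained on different resamples is what reduces the variance of a single deep tree.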

  1. Pipeline architecture

# Chain indexer and tree in a Pipeline
pipeline = Pipeline(stages=[featureIndexer, rf])
model = pipeline.fit(trainingData)
  1. Make predictions

predictions = model.transform(testData)

# Select example rows to display.
predictions.select("features","label", "prediction").show(5)
+---------------+-----+------------------+
|       features|label|        prediction|
+---------------+-----+------------------+
| [4.1,11.6,5.7]|  3.2| 8.155439814814816|
| [5.4,29.9,9.4]|  5.3|10.412769901394899|
|[7.3,28.1,41.4]|  5.5| 12.13735648148148|
|[7.8,38.9,50.6]|  6.6|11.321796703296704|
| [8.4,27.2,2.1]|  5.7|12.071421957671957|
+---------------+-----+------------------+
only showing top 5 rows
  1. Evaluation

# Select (prediction, true label) and compute test error
evaluator = RegressionEvaluator(
    labelCol="label", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)
Root Mean Squared Error (RMSE) on test data = 2.35912
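y_true = predictions.select("label").toPandas()
y_pred = predictions.select("prediction").toPandas()

import sklearn.metrics
r2_score = sklearn.metrics.r2_score(y_true, y_pred)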
print('r2_score: {:4.3f}'.format(r2_score))
r2_score: 0.831
  1. Feature importances

model.stages[-1].featureImportances
SparseVector(3, {0: 0.4994, 1: 0.3196, 2: 0.181})
model.stages[-1].trees
[DecisionTreeRegressionModel (uid=dtr_c75f1c75442c) of depth 5 with 43 nodes,
 DecisionTreeRegressionModel (uid=dtr_70fc2d441581) of depth 5 with 45 nodes,
 DecisionTreeRegressionModel (uid=dtr_bc8464f545a7) of depth 5 with 31 nodes,
 DecisionTreeRegressionModel (uid=dtr_a8a7e5367154) of depth 5 with 59 nodes,
 DecisionTreeRegressionModel (uid=dtr_3ea01314fcbc) of depth 5 with 47 nodes,
 DecisionTreeRegressionModel (uid=dtr_be9a04ac22a6) of depth 5 with 45 nodes,
 DecisionTreeRegressionModel (uid=dtr_38610d47328a) of depth 5 with 51 nodes,
 DecisionTreeRegressionModel (uid=dtr_bf14aea0ad3b) of depth 5 with 49 nodes,
 DecisionTreeRegressionModel (uid=dtr_cde24ebd6bb6) of depth 5 with 39 nodes,
 DecisionTreeRegressionModel (uid=dtr_a1fc9bd4fbeb) of depth 5 with 57 nodes,
 DecisionTreeRegressionModel (uid=dtr_37798d6db1ba) of depth 5 with 41 nodes,
 DecisionTreeRegressionModel (uid=dtr_c078b73ada63) of depth 5 with 41 nodes,
 DecisionTreeRegressionModel (uid=dtr_fd00e3a070ad) of depth 5 with 55 nodes,
 DecisionTreeRegressionModel (uid=dtr_9d01d5fb8604) of depth 5 with 45 nodes,
 DecisionTreeRegressionModel (uid=dtr_8bd8bdddf642) of depth 5 with 41 nodes,
 DecisionTreeRegressionModel (uid=dtr_e53b7bae30f8) of depth 5 with 49 nodes,
 DecisionTreeRegressionModel (uid=dtr_808a869db21c) of depth 5 with 47 nodes,
 DecisionTreeRegressionModel (uid=dtr_64d0916bceb0) of depth 5 with 33 nodes,
 DecisionTreeRegressionModel (uid=dtr_0891055fff94) of depth 5 with 55 nodes,
 DecisionTreeRegressionModel (uid=dtr_19c8bbad26c2) of depth 5 with 51 nodes]

8.5. Gradient-boosted Tree Regression

8.5.1. Introduction

8.5.2. How to solve it?

8.5.3. Demo

  1. Set up SparkContext and SparkSession

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark GBTRegressor example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()
  1. Load the dataset

df = spark.read.format('com.databricks.spark.csv').\
                               options(header='true', \
                               inferschema='true').\
                               load("../data/Advertising.csv",header=True);

df.show(5,True)
df.printSchema()
+-----+-----+---------+-----+
|   TV|Radio|Newspaper|Sales|
+-----+-----+---------+-----+
|230.1| 37.8|     69.2| 22.1|
| 44.5| 39.3|     45.1| 10.4|
| 17.2| 45.9|     69.3|  9.3|
|151.5| 41.3|     58.5| 18.5|
|180.8| 10.8|     58.4| 12.9|
+-----+-----+---------+-----+
only showing top 5 rows

root
 |-- TV: double (nullable = true)
 |-- Radio: double (nullable = true)
 |-- Newspaper: double (nullable = true)
 |-- Sales: double (nullable = true)
df.describe().show()

+-------+-----------------+------------------+------------------+------------------+
|summary|               TV|             Radio|         Newspaper|             Sales|
+-------+-----------------+------------------+------------------+------------------+
|  count|              200|               200|               200|               200|
|   mean|         147.0425|23.264000000000024|30.553999999999995|14.022500000000003|
| stddev|85.85423631490805|14.846809176168728| 21.77862083852283| 5.217456565710477|
|    min|              0.7|               0.0|               0.3|               1.6|
|    max|            296.4|              49.6|             114.0|              27.0|
+-------+-----------------+------------------+------------------+------------------+
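  1. Convert the data to dense vectors (features and label)

Note

You are strongly encouraged to try my get_dummy function for dealing with the categorical data in complex datasets.

Supervised learning version: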

def get_dummy(df,indexCol,categoricalCols,continuousCols,labelCol):

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    from pyspark.sql.functions import col

    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    data = data.withColumn('label',col(labelCol))

    return data.select(indexCol,'features','label')

Unsupervised learning version:

def get_dummy(df,indexCol,categoricalCols,continuousCols):
    '''
    Get dummy variables and concat with continuous variables for unsupervised learning.
    :param df: the dataframe
    :param categoricalCols: the name list of the categorical data
    :param continuousCols:  the name list of the numerical data
    :return k: feature matrix

    :author: Wenqiang Feng
    :email:  von198@gmail.com
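    '''
    # imports needed when this version is used on its own
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
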
    indexers = [ StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c))
                 for c in categoricalCols ]

    # default setting: dropLast=True
    encoders = [ OneHotEncoder(inputCol=indexer.getOutputCol(),
                 outputCol="{0}_encoded".format(indexer.getOutputCol()))
                 for indexer in indexers ]

    assembler = VectorAssembler(inputCols=[encoder.getOutputCol() for encoder in encoders]
                                + continuousCols, outputCol="features")

    pipeline = Pipeline(stages=indexers + encoders + [assembler])

    model=pipeline.fit(df)
    data = model.transform(df)

    return data.select(indexCol,'features')
from pyspark.sql import Row
from pyspark.ml.linalg import Vectors

# convert the data to dense vector
#def transData(row):
#    return Row(label=row["Sales"],
#               features=Vectors.dense([row["TV"],
#                                       row["Radio"],
#                                       row["Newspaper"]]))
def transData(data):
    return data.rdd.map(lambda r: [Vectors.dense(r[:-1]),r[-1]]).toDF(['features','label'])
  1. Convert the data to dense vectors

transformed= transData(df)
transformed.show(5)
+-----------------+-----+
|         features|label|
+-----------------+-----+
|[230.1,37.8,69.2]| 22.1|
| [44.5,39.3,45.1]| 10.4|
| [17.2,45.9,69.3]|  9.3|
|[151.5,41.3,58.5]| 18.5|
|[180.8,10.8,58.4]| 12.9|
+-----------------+-----+
only showing top 5 rows
  1. Deal with the categorical variables

from pyspark.ml import Pipeline
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator

featureIndexer = VectorIndexer(inputCol="features", \
                               outputCol="indexedFeatures",\
                               maxCategories=4).fit(transformed)

data = featureIndexer.transform(transformed)
data.show(5,True)
+-----------------+-----+-----------------+
|         features|label|  indexedFeatures|
+-----------------+-----+-----------------+
|[230.1,37.8,69.2]| 22.1|[230.1,37.8,69.2]|
| [44.5,39.3,45.1]| 10.4| [44.5,39.3,45.1]|
| [17.2,45.9,69.3]|  9.3| [17.2,45.9,69.3]|
|[151.5,41.3,58.5]| 18.5|[151.5,41.3,58.5]|
|[180.8,10.8,58.4]| 12.9|[180.8,10.8,58.4]|
+-----------------+-----+-----------------+
only showing top 5 rows
  1. Split the data into training and test sets (40% held out for testing)

# Split the data into training and test sets (40% held out for testing)
(trainingData, testData) = data.randomSplit([0.6, 0.4])

trainingData.show(5)
testData.show(5)
+----------------+-----+----------------+
|        features|label| indexedFeatures|
+----------------+-----+----------------+
|  [0.7,39.6,8.7]|  1.6|  [0.7,39.6,8.7]|
|   [8.6,2.1,1.0]|  4.8|   [8.6,2.1,1.0]|
| [8.7,48.9,75.0]|  7.2| [8.7,48.9,75.0]|
|[11.7,36.9,45.2]|  7.3|[11.7,36.9,45.2]|
|[13.2,15.9,49.6]|  5.6|[13.2,15.9,49.6]|
+----------------+-----+----------------+
only showing top 5 rows

+---------------+-----+---------------+
|       features|label|indexedFeatures|
+---------------+-----+---------------+
| [4.1,11.6,5.7]|  3.2| [4.1,11.6,5.7]|
| [5.4,29.9,9.4]|  5.3| [5.4,29.9,9.4]|
|[7.3,28.1,41.4]|  5.5|[7.3,28.1,41.4]|
|[7.8,38.9,50.6]|  6.6|[7.8,38.9,50.6]|
| [8.4,27.2,2.1]|  5.7| [8.4,27.2,2.1]|
+---------------+-----+---------------+
only showing top 5 rows
  1. Fit the Gradient-boosted Tree Regression model

# Import GBTRegressor class
from pyspark.ml.regression import GBTRegressor

# Define GBTRegressor algorithm
rf = GBTRegressor() # maxIter=2, maxDepth=2, seed=42

Note

If you decide to use the indexedFeatures column, you need to add the parameter featuresCol="indexedFeatures".
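Gradient boosting for squared error builds the model in rounds: each new tree is fitted to the residuals of the current prediction, and a fraction of its output is added to the ensemble. The sketch below is a generic plain-Python illustration with one feature and depth-1 trees (stumps). It is not Spark's exact procedure. The function names and the choices n_rounds=20 and step_size=0.1 are assumptions made for illustration.

```python
def fit_stump(xs, residuals):
    """Depth-1 tree: best single threshold by squared error on the residuals."""
    pairs = sorted(zip(xs, residuals))
    best = None
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue
        thr = (pairs[i - 1][0] + pairs[i][0]) / 2.0
        left = [r for _, r in pairs[:i]]
        right = [r for _, r in pairs[i:]]
        lm = sum(left) / len(left)
        rm = sum(right) / len(right)
        cost = sum((r - lm) ** 2 for r in left) + sum((r - rm) ** 2 for r in right)
        if best is None or cost < best[0]:
            best = (cost, thr, lm, rm)
    _, thr, lm, rm = best
    return lambda x: lm if x <= thr else rm

def fit_gbt(xs, ys, n_rounds=20, step_size=0.1):
    """Start from the mean label, then add shrunken stumps fitted to residuals."""
    base = sum(ys) / len(ys)
    stumps = []
    preds = [base] * len(ys)
    for _ in range(n_rounds):
        residuals = [y - p for y, p in zip(ys, preds)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        preds = [p + step_size * stump(x) for p, x in zip(preds, xs)]
    return lambda x: base + step_size * sum(s(x) for s in stumps)

def mse(ys, preds):
    return sum((y - p) ** 2 for y, p in zip(ys, preds)) / len(ys)

tv = [230.1, 44.5, 17.2, 151.5, 180.8]
sales = [22.1, 10.4, 9.3, 18.5, 12.9]

predict = fit_gbt(tv, sales)
mse_baseline = mse(sales, [sum(sales) / len(sales)] * len(sales))
mse_boosted = mse(sales, [predict(x) for x in tv])
print("baseline MSE {:.3f}, boosted MSE {:.3f}".format(mse_baseline, mse_boosted))
```

A smaller step size needs more rounds to reach the same training error, but usually generalizes better. That trade-off is the main tuning decision for this kind of model.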

  1. Pipeline architecture

# Chain indexer and tree in a Pipeline
pipeline = Pipeline(stages=[featureIndexer, rf])
model = pipeline.fit(trainingData)
  1. Make predictions

predictions = model.transform(testData)

# Select example rows to display.
predictions.select("features","label", "prediction").show(5)
+----------------+-----+------------------+
|        features|label|        prediction|
+----------------+-----+------------------+
| [7.8,38.9,50.6]|  6.6| 6.836040343319862|
|   [8.6,2.1,1.0]|  4.8| 5.652202764688849|
| [8.7,48.9,75.0]|  7.2| 6.908750296855572|
| [13.1,0.4,25.6]|  5.3| 5.784020210692574|
|[19.6,20.1,17.0]|  7.6|6.8678921062629295|
+----------------+-----+------------------+
only showing top 5 rows
  1. Evaluation

# Select (prediction, true label) and compute test error
evaluator = RegressionEvaluator(
    labelCol="label", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)
Root Mean Squared Error (RMSE) on test data = 1.36939
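y_true = predictions.select("label").toPandas()
y_pred = predictions.select("prediction").toPandas()

import sklearn.metrics
r2_score = sklearn.metrics.r2_score(y_true, y_pred)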
print('r2_score: {:4.3f}'.format(r2_score))
r2_score: 0.932
  1. Feature importances

model.stages[-1].featureImportances
SparseVector(3, {0: 0.3716, 1: 0.3525, 2: 0.2759})
model.stages[-1].trees
[DecisionTreeRegressionModel (uid=dtr_7f5cd2ef7cb6) of depth 5 with 61 nodes,
 DecisionTreeRegressionModel (uid=dtr_ef3ab6baeac9) of depth 5 with 39 nodes,
 DecisionTreeRegressionModel (uid=dtr_07c6e3cf3819) of depth 5 with 45 nodes,
 DecisionTreeRegressionModel (uid=dtr_ce724af79a2b) of depth 5 with 47 nodes,
 DecisionTreeRegressionModel (uid=dtr_d149ecc71658) of depth 5 with 55 nodes,
 DecisionTreeRegressionModel (uid=dtr_d3a79bdea516) of depth 5 with 43 nodes,
 DecisionTreeRegressionModel (uid=dtr_7abc1a337844) of depth 5 with 51 nodes,
 DecisionTreeRegressionModel (uid=dtr_480834b46d8f) of depth 5 with 33 nodes,
 DecisionTreeRegressionModel (uid=dtr_0cbd1eaa3874) of depth 5 with 39 nodes,
 DecisionTreeRegressionModel (uid=dtr_8088ac71a204) of depth 5 with 57 nodes,
 DecisionTreeRegressionModel (uid=dtr_2ceb9e8deb45) of depth 5 with 47 nodes,
 DecisionTreeRegressionModel (uid=dtr_cc334e84e9a2) of depth 5 with 57 nodes,
 DecisionTreeRegressionModel (uid=dtr_a665c562929e) of depth 5 with 41 nodes,
 DecisionTreeRegressionModel (uid=dtr_2999b1ffd2dc) of depth 5 with 45 nodes,
 DecisionTreeRegressionModel (uid=dtr_29965cbe8cfc) of depth 5 with 55 nodes,
 DecisionTreeRegressionModel (uid=dtr_731df51bf0ad) of depth 5 with 41 nodes,
 DecisionTreeRegressionModel (uid=dtr_354cf33424da) of depth 5 with 51 nodes,
 DecisionTreeRegressionModel (uid=dtr_4230f200b1c0) of depth 5 with 41 nodes,
 DecisionTreeRegressionModel (uid=dtr_3279cdc1ce1d) of depth 5 with 45 nodes,
 DecisionTreeRegressionModel (uid=dtr_f474a99ff06e) of depth 5 with 55 nodes]