具有异类数据源的列Transformer#

数据集通常可能包含需要不同特征提取和处理管道的组件。这种情况可能会发生在以下情况:

  1. 您的数据集由异类数据类型(例如,格栅图像和文本标题)组成,

  2. 您的数据集存储在 pandas.DataFrame 并且不同的柱需要不同的处理管道。

此示例演示如何使用 ColumnTransformer 在包含不同类型特征的数据集上。功能的选择并不是特别有帮助,但可以说明该技术。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np

from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import PCA
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.svm import LinearSVC

20个新闻组数据集#

我们将使用 20 newsgroups dataset ,其中包括来自新闻组关于20个主题的帖子。此数据集根据特定日期之前和之后发布的消息分为训练子集和测试子集。我们将仅使用2个类别的帖子来加快运行时间。

categories = ["sci.med", "sci.space"]
X_train, y_train = fetch_20newsgroups(
    random_state=1,
    subset="train",
    categories=categories,
    remove=("footers", "quotes"),
    return_X_y=True,
)
X_test, y_test = fetch_20newsgroups(
    random_state=1,
    subset="test",
    categories=categories,
    remove=("footers", "quotes"),
    return_X_y=True,
)

每个功能都包括有关该帖子的Meta信息,例如主题和新闻帖子的正文。

print(X_train[0])
From: mccall@mksol.dseg.ti.com (fred j mccall 575-3539)
Subject: Re: Metric vs English
Article-I.D.: mksol.1993Apr6.131900.8407
Organization: Texas Instruments Inc
Lines: 31




American, perhaps, but nothing military about it.  I learned (mostly)
slugs when we talked English units in high school physics and while
the teacher was an ex-Navy fighter jock the book certainly wasn't
produced by the military.

[Poundals were just too flinking small and made the math come out
funny; sort of the same reason proponents of SI give for using that.]

--
"Insisting on perfect safety is for people who don't have the balls to live
 in the real world."   -- Mary Shafer, NASA Ames Dryden

创建transformers#

首先,我们想要一个Transformer,可以提取每个帖子的主题和正文。由于这是一个无状态转换(不需要来自训练数据的状态信息),我们可以定义一个执行数据转换的函数,然后使用 FunctionTransformer 创建一个scikit-learn Transformer。

def subject_body_extractor(posts):
    # construct object dtype array with two columns
    # first column = 'subject' and second column = 'body'
    features = np.empty(shape=(len(posts), 2), dtype=object)
    for i, text in enumerate(posts):
        # temporary variable `_` stores '\n\n'
        headers, _, body = text.partition("\n\n")
        # store body text in second column
        features[i, 1] = body

        prefix = "Subject:"
        sub = ""
        # save text after 'Subject:' in first column
        for line in headers.split("\n"):
            if line.startswith(prefix):
                sub = line[len(prefix) :]
                break
        features[i, 0] = sub

    return features


subject_body_transformer = FunctionTransformer(subject_body_extractor)

我们还将创建一个Transformer,用于提取文本长度和句子数量。

def text_stats(posts):
    return [{"length": len(text), "num_sentences": text.count(".")} for text in posts]


text_stats_transformer = FunctionTransformer(text_stats)

分类流水线#

下面的管道使用从每个帖子中提取主题和正文 SubjectBodyExtractor ,产生(n_samples,2)阵列。然后使用此数组计算主题和正文的标准词袋特征以及正文上的文本长度和句子数量,使用 ColumnTransformer .我们将它们与权重结合起来,然后在组合的特征集上训练分类器。

pipeline = Pipeline(
    [
        # Extract subject & body
        ("subjectbody", subject_body_transformer),
        # Use ColumnTransformer to combine the subject and body features
        (
            "union",
            ColumnTransformer(
                [
                    # bag-of-words for subject (col 0)
                    ("subject", TfidfVectorizer(min_df=50), 0),
                    # bag-of-words with decomposition for body (col 1)
                    (
                        "body_bow",
                        Pipeline(
                            [
                                ("tfidf", TfidfVectorizer()),
                                ("best", PCA(n_components=50, svd_solver="arpack")),
                            ]
                        ),
                        1,
                    ),
                    # Pipeline for pulling text stats from post's body
                    (
                        "body_stats",
                        Pipeline(
                            [
                                (
                                    "stats",
                                    text_stats_transformer,
                                ),  # returns a list of dicts
                                (
                                    "vect",
                                    DictVectorizer(),
                                ),  # list of dicts -> feature matrix
                            ]
                        ),
                        1,
                    ),
                ],
                # weight above ColumnTransformer features
                transformer_weights={
                    "subject": 0.8,
                    "body_bow": 0.5,
                    "body_stats": 1.0,
                },
            ),
        ),
        # Use a SVC classifier on the combined features
        ("svc", LinearSVC(dual=False)),
    ],
    verbose=True,
)

最后,我们将我们的管道与训练数据匹配,并使用它来预测主题 X_test .然后打印我们管道的性能指标。

pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)
print("Classification report:\n\n{}".format(classification_report(y_test, y_pred)))
[Pipeline] ....... (step 1 of 3) Processing subjectbody, total=   0.0s
[Pipeline] ............. (step 2 of 3) Processing union, total=   0.4s
[Pipeline] ............... (step 3 of 3) Processing svc, total=   0.0s
Classification report:

              precision    recall  f1-score   support

           0       0.84      0.87      0.86       396
           1       0.87      0.84      0.85       394

    accuracy                           0.86       790
   macro avg       0.86      0.86      0.86       790
weighted avg       0.86      0.86      0.86       790

Total running time of the script: (0分1.954秒)

相关实例

使用稀疏特征对文本文档进行分类

Classification of text documents using sparse features

使用谱协同集群算法对文档进行二集群

Biclustering documents with the Spectral Co-clustering algorithm

混合类型的列Transformer

Column Transformer with Mixed Types

文本数据集的半监督分类

Semi-supervised Classification on a Text Dataset

Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io> _