For Your ISHIO Blog

データ分析や機械学習やスクラムや組織とか、色々つぶやくブログです。

SparkでLasso回帰のハイパーパラメータλをグリッドサーチして特徴量選択する

はじめに

Lasso回帰は、正則化された線形回帰手法の1つで、線形回帰にL1正則化項を追加したモデルです。正則化により過学習を防ぐとともに、不要と判断される説明変数の偏回帰係数がゼロになる性質があります。この性質を利用して、目的変数により影響が高い説明変数のみを選択する特徴量選択を自動で行う手法でもあります。

特徴量選択の方法は、より精度が高いものから、より複雑なものまで様々な方法が用意されています。ビジネス現場では、数百万の中から迅速に変数選択を行う必要がある場合に、Lasso回帰のようなシンプルな方法論が利用されるケースがあります。

この記事では、Sparkを利用してLasso回帰を行い特徴量選択を行います。また、最適なハイパーパラメータλを探索するためにグリッドサーチを行います。

正則化とは

正則化とは、過学習を防いで汎化性能(未知のデータに対する予測力)を得るためのに追加情報を導入するプロセスです。通常、データの複雑さに対するペナルティまたは制約の形で行われます。モデルに正則化項を追加することで、データの外れ値が原因で係数が任意に引き伸ばされる可能性を減らし、モデルの過学習を防ぐことができます。

L1正則化

L2正則化項は、線形回帰の損失関数(最小二乗損失関数)に重みの絶対値和を足し合わせた項です。係数はゼロに向かって収束し、変数の選択に役立つように一部の係数は正確にゼロになります。 f:id:ishitonton:20201010160531p:plain

L2正則化

L2正則化項は、線形回帰の損失関数(最小二乗損失関数)に重みの平方和を足し合わせた項です。ハイパーパラメータλの値が増えるとL2正則化項が大きくなります。正則化を強くし、モデルの重みを小さくするように調整されます。つまり説明変数の影響が大きくなりにくいように抑えてくれます。

Elastic Net

リッジ回帰とLasso回帰の両方のペナルティ制約の効果を組み合わせたモデルです。下記の損失関数が与えられます。 f:id:ishitonton:20201010160624p:plain

Elasticパラメータのα = 1の場合、損失関数はL1正則化(Lasso)になります。α= 0の場合、損失関数はL2正則化(Ridge)となります。αが0〜1の間にある場合、損失関数は、係数にL1(ラッソ)制約とL2(リッジ)制約の両方を組み合わせて実装します。 ハイパーパラメータλが大きくなると、損失関数は切片以外の係数にペナルティを課します。

一般に、過学習を回避するために正則化を使用するため、λが異なる複数のモデルをトレーニングし、テストエラーが最小となるモデルを選択する必要があります。そこで今回はグリッドサーチでλを変動させて、最適なλを探索し、そのモデルの係数を利用して特徴量選択を行います。

Sparkでの実装

  • ハイパーパラメータα = 1とし、lasso回帰を実行します。
  • ハイパーパラメータλを0~1の間でグリッドサーチし、探索します。
  • クロスバリデーション(Folds=5)で実行します。

Lasso回帰に実行にはSparkMLを利用してデータフレーム形式で実行します。label(目的変数)のカラムとfeatures(説明変数)のカラムを用意し、featuresはVectorに変換します。

from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator 
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression

df_raw = spark.read.parquet("/user/test")

featureCols = df_raw.columns
featureCols.remove("label")

assembled_feature = VectorAssembler(inputCols=featureCols, outputCol='features')
df = assembled_feature.transform(df_raw).select(["label", "features"])
df.show()

データフレームの出力結果は以下の通りです。

+--------------------+--------------------+
|                 label|            features|
+--------------------+--------------------+
| 0.15088482464765735|[-0.10498780858,-...|
| 0.15088482464765735|[-0.10498780858,-...|
| 0.44954344663064794|(15,[0,1,2,3,4,5,...|
| 0.21725340731054388|[-0.10498780858,0...|
|-0.23073452566394168|[1.01488214956,1....|
| -0.3302873996582715|[-0.10498780858,-...|
| -0.9607889349556958|[1.01488214956,1....|
|  0.7150177772821946|[-1.22485776671,-...|
| 0.44954344663064794|[-1.22485776671,-...|
|-0.14777379733533322|[-0.10498780858,-...|
| -0.3800638366554367|[-0.10498780858,-...|
| -0.6123538759755407|[1.01488214956,0....|
|-0.46302456498404515|[1.01488214956,0....|
| -0.8114596239642009|[1.01488214956,0....|
|  -1.607882615918842|[1.01488214956,1....|
|  -1.607882615918842|[1.01488214956,1....|
| -0.8944203522928094|[1.01488214956,1....|
|    2.04238943053993|(15,[0,1,2,3,4,5,...|
|   1.710546517225496|(15,[0,1,2,3,4,5,...|
|  2.2912716155257553|(15,[0,1,2,3,4,5,...|
+--------------------+--------------------+
only showing top 20 rows

次に、lasso回帰を実行します。ハイパーパラメータλはグリッドサーチで探索を行い、Folds=5でクロスバリデーションを行います。また、メトリックスにはRMSEを選択します。

# grodsearch parameter scope
regParam_list = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

lr = LinearRegression(featuresCol="features",
                      labelCol="label",
                      predictionCol="predict",
                      maxIter=20,
                      elasticNetParam=1,
                      tol=1e-06)
paramGrid = ParamGridBuilder().addGrid(lr.regParam, regParam_list).build()
crossval = CrossValidator(estimator=lr,
                          estimatorParamMaps=paramGrid,
                          evaluator=RegressionEvaluator(predictionCol='predict',
                                                        labelCol='label',
                                                        metricName='rmse'),
                          numFolds=5)

cvModel = crossval.fit(df)

最後にベストモデルを選択し、ハイパーパラメータλを確認します。今回は0.1が採用されました。

model = cvModel.bestModel
# Results of regParam
print(model._java_obj.getRegParam()) # 0.1

偏回帰係数と切片を取得します。いくつかの係数はゼロとなっており、変数選択されていることが確認できます。

print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
Coefficients: [-0.1997601362638407,0.0,-0.20024756083461823,0.03626482379996047,-0.43624005401234084,0.0,0.0,-0.05147738465330625,0.0,0.0,0.0,0.0,0.0,-0.05600529685189293,0.0]
Intercept: 0.048066352404