IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / tutorial / Step_9_Go_to_LogReg.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.tutorial;
19
20 import java.io.FileNotFoundException;
21 import java.util.Arrays;
22 import org.apache.ignite.Ignite;
23 import org.apache.ignite.IgniteCache;
24 import org.apache.ignite.Ignition;
25 import org.apache.ignite.ml.math.Vector;
26 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
27 import org.apache.ignite.ml.nn.UpdatesStrategy;
28 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
29 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
30 import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
31 import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
32 import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
33 import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
34 import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
35 import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
36 import org.apache.ignite.ml.selection.cv.CrossValidation;
37 import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
38 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
39 import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
40 import org.apache.ignite.ml.selection.split.TrainTestSplit;
41 import org.apache.ignite.thread.IgniteThread;
42
43 /**
44 * Maybe the another algorithm can give us the higher accuracy?
45 *
46 * Let's win with the LogisticRegressionSGDTrainer!
47 */
48 public class Step_9_Go_to_LogReg {
49 /** Run example. */
50 public static void main(String[] args) throws InterruptedException {
51 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
52 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
53 Step_9_Go_to_LogReg.class.getSimpleName(), () -> {
54 try {
55 IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
56
57 // Defines first preprocessor that extracts features from an upstream data.
58 // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare"
59 IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
60 = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
61
62 IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1];
63
64 TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>()
65 .split(0.75);
66
67 IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
68 .encodeFeature(1)
69 .encodeFeature(6) // <--- Changed index here
70 .fit(ignite,
71 dataCache,
72 featureExtractor
73 );
74
75 IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
76 .fit(ignite,
77 dataCache,
78 strEncoderPreprocessor
79 );
80
81 IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
82 .fit(
83 ignite,
84 dataCache,
85 imputingPreprocessor
86 );
87
88 // Tune hyperparams with K-fold Cross-Validation on the splitted training set.
89 int[] pSet = new int[]{1, 2};
90 int[] maxIterationsSet = new int[]{ 100, 1000};
91 int[] batchSizeSet = new int[]{100, 10};
92 int[] locIterationsSet = new int[]{10, 100};
93 double[] learningRateSet = new double[]{0.1, 0.2, 0.5};
94
95
96 int bestP = 1;
97 int bestMaxIterations = 100;
98 int bestBatchSize = 10;
99 int bestLocIterations = 10;
100 double bestLearningRate = 0.0;
101 double avg = Double.MIN_VALUE;
102
103 for(int p: pSet){
104 for(int maxIterations: maxIterationsSet) {
105 for (int batchSize : batchSizeSet) {
106 for (int locIterations : locIterationsSet) {
107 for (double learningRate : learningRateSet) {
108
109 IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
110 .withP(p)
111 .fit(
112 ignite,
113 dataCache,
114 minMaxScalerPreprocessor
115 );
116
117 LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
118 new SimpleGDUpdateCalculator(learningRate),
119 SimpleGDParameterUpdate::sumLocal,
120 SimpleGDParameterUpdate::avg
121 ), maxIterations, batchSize, locIterations, 123L);
122
123 CrossValidation<LogisticRegressionModel, Double, Integer, Object[]> scoreCalculator
124 = new CrossValidation<>();
125
126 double[] scores = scoreCalculator.score(
127 trainer,
128 new Accuracy<>(),
129 ignite,
130 dataCache,
131 split.getTrainFilter(),
132 normalizationPreprocessor,
133 lbExtractor,
134 3
135 );
136
137 System.out.println("Scores are: " + Arrays.toString(scores));
138
139 final double currAvg = Arrays.stream(scores).average().orElse(Double.MIN_VALUE);
140
141 if (currAvg > avg) {
142 avg = currAvg;
143 bestP = p;
144 bestMaxIterations = maxIterations;
145 bestBatchSize = batchSize;
146 bestLearningRate = learningRate;
147 bestLocIterations = locIterations;
148 }
149
150 System.out.println("Avg is: " + currAvg
151 + " with p: " + p
152 + " with maxIterations: " + maxIterations
153 + " with batchSize: " + batchSize
154 + " with learningRate: " + learningRate
155 + " with locIterations: " + locIterations
156 );
157 }
158 }
159 }
160 }
161 }
162
163 System.out.println("Train "
164 + " with p: " + bestP
165 + " with maxIterations: " + bestMaxIterations
166 + " with batchSize: " + bestBatchSize
167 + " with learningRate: " + bestLearningRate
168 + " with locIterations: " + bestLocIterations
169 );
170
171 IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
172 .withP(bestP)
173 .fit(
174 ignite,
175 dataCache,
176 minMaxScalerPreprocessor
177 );
178
179 LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
180 new SimpleGDUpdateCalculator(bestLearningRate),
181 SimpleGDParameterUpdate::sumLocal,
182 SimpleGDParameterUpdate::avg
183 ), bestMaxIterations, bestBatchSize, bestLocIterations, 123L);
184
185 System.out.println(">>> Perform the training to get the model.");
186 LogisticRegressionModel bestMdl = trainer.fit(
187 ignite,
188 dataCache,
189 split.getTrainFilter(),
190 normalizationPreprocessor,
191 lbExtractor
192 );
193
194 double accuracy = Evaluator.evaluate(
195 dataCache,
196 split.getTestFilter(),
197 bestMdl,
198 normalizationPreprocessor,
199 lbExtractor,
200 new Accuracy<>()
201 );
202
203 System.out.println("\n>>> Accuracy " + accuracy);
204 System.out.println("\n>>> Test Error " + (1 - accuracy));
205 }
206 catch (FileNotFoundException e) {
207 e.printStackTrace();
208 }
209 });
210
211 igniteThread.start();
212
213 igniteThread.join();
214 }
215 }
216 }