IGNITE-8907: [ML] Using vectors in featureExtractor
modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.composition;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.util.Utils;
import org.jetbrains.annotations.NotNull;

/**
 * Abstract trainer implementing bagging logic.
 * In each learning iteration the algorithm trains one model on a subset of the learning sample and
 * a subspace of the feature space. Each model is produced from the same model class (e.g. decision trees).
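 *
 * <p>A minimal usage sketch. Here {@code someConcreteBaggingTrainer()}, {@code datasetBuilder} and the
 * extractor lambdas are illustrative placeholders (assuming a cache of {@code double[]} rows with the
 * label at index 0); they are not part of this class:</p>
 * <pre>{@code
 * BaggingModelTrainer trainer = someConcreteBaggingTrainer();       // any subclass of this trainer
 * ModelsComposition ensemble = trainer.fit(
 *     datasetBuilder,                                               // DatasetBuilder<Integer, double[]>
 *     (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), // features: elements 1..n-1
 *     (k, v) -> v[0]);                                              // label: element 0
 * double prediction = ensemble.apply(VectorUtils.of(5.0, 1.5, 0.3));
 * }</pre>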
 */
public abstract class BaggingModelTrainer implements DatasetTrainer<ModelsComposition, Double> {
    /**
     * Predictions aggregator.
     */
    private final PredictionsAggregator predictionsAggregator;
    /**
     * Number of features to draw from the original feature vector to train each model.
     */
    private final int maximumFeaturesCntPerMdl;
    /**
     * Ensemble size.
     */
    private final int ensembleSize;
    /**
     * Size of the sample part, in percent, used to train each model.
     */
    private final double samplePartSizePerMdl;
    /**
     * Feature vector size.
     */
    private final int featureVectorSize;

    /**
     * Constructs a new instance of BaggingModelTrainer.
     *
     * @param predictionsAggregator Predictions aggregator.
     * @param featureVectorSize Feature vector size.
     * @param maximumFeaturesCntPerMdl Number of features to draw from the original feature vector to train each model.
     * @param ensembleSize Ensemble size.
     * @param samplePartSizePerMdl Size of the sample part, in percent, used to train each model.
     */
    public BaggingModelTrainer(PredictionsAggregator predictionsAggregator,
        int featureVectorSize,
        int maximumFeaturesCntPerMdl,
        int ensembleSize,
        double samplePartSizePerMdl) {

        this.predictionsAggregator = predictionsAggregator;
        this.maximumFeaturesCntPerMdl = maximumFeaturesCntPerMdl;
        this.ensembleSize = ensembleSize;
        this.samplePartSizePerMdl = samplePartSizePerMdl;
        this.featureVectorSize = featureVectorSize;
    }

    /** {@inheritDoc} */
    @Override public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {

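        // Train ensembleSize models independently; each learnModel() call draws its own random
        // part of the sample and its own features subspace, and the resulting models are combined
        // into a single composition using the configured predictions aggregator.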
        List<ModelOnFeaturesSubspace> learnedModels = new ArrayList<>();
        for (int i = 0; i < ensembleSize; i++)
            learnedModels.add(learnModel(datasetBuilder, featureExtractor, lbExtractor));

        return new ModelsComposition(learnedModels, predictionsAggregator);
    }

    /**
     * Trains one model on a part of the sample and a features subspace.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return Model trained on a features subspace, together with its feature mapping.
     */
    @NotNull private <K, V> ModelOnFeaturesSubspace learnModel(
        DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {

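        // Per-model randomization: the SHA256-based sample filter selects the part of the sample used
        // for this model, and the seeded features mapping draws maximumFeaturesCntPerMdl features
        // out of the original featureVectorSize.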
        Random rnd = new Random();
        SHA256UniformMapper<K, V> sampleFilter = new SHA256UniformMapper<>(rnd);
        long featureExtractorSeed = rnd.nextLong();
        Map<Integer, Integer> featuresMapping = createFeaturesMapping(featureExtractorSeed, featureVectorSize);

        //TODO: IGNITE-8867 Need to implement bootstrapping algorithm
        Model<Vector, Double> mdl = buildDatasetTrainerForModel().fit(
            datasetBuilder.withFilter((features, answer) -> sampleFilter.map(features, answer) < samplePartSizePerMdl),
            wrapFeatureExtractor(featureExtractor, featuresMapping),
            lbExtractor);

        return new ModelOnFeaturesSubspace(featuresMapping, mdl);
    }

    /**
     * Constructs a mapping from the original feature vector to the features subspace.
     *
     * @param seed Seed.
     * @param featuresVectorSize Features vector size.
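     * @return Mapping from local (subspace) feature index to original feature index; e.g. with
     *      {@code featuresVectorSize = 5} and {@code maximumFeaturesCntPerMdl = 3} the result might be
     *      {@code {0 -> 4, 1 -> 1, 2 -> 3}} (illustrative values only).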
     */
    private Map<Integer, Integer> createFeaturesMapping(long seed, int featuresVectorSize) {
        int[] featureIdxs = Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed));
        Map<Integer, Integer> locFeaturesMapping = new HashMap<>();

        IntStream.range(0, maximumFeaturesCntPerMdl)
            .forEach(localId -> locFeaturesMapping.put(localId, featureIdxs[localId]));

        return locFeaturesMapping;
    }

    /**
     * Creates the trainer used to build each single model of the ensemble.
     *
     * @return Dataset trainer for one ensemble member.
     */
    protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildDatasetTrainerForModel();

    /**
     * Wraps the original feature extractor so that the features subspace mapping is applied to its output.
     *
     * @param featureExtractor Feature extractor.
     * @param featureMapping Feature mapping.
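     * @return Feature extractor that projects the original vector onto the subspace; e.g. with mapping
     *      {@code {0 -> 7, 1 -> 2}} an input vector {@code v} becomes the two-dimensional vector
     *      {@code [v.get(7), v.get(2)]} (illustrative values only).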
     */
    private <K, V> IgniteBiFunction<K, V, Vector> wrapFeatureExtractor(
        IgniteBiFunction<K, V, Vector> featureExtractor,
        Map<Integer, Integer> featureMapping) {

        return featureExtractor.andThen((IgniteFunction<Vector, Vector>)featureValues -> {
            double[] newFeaturesValues = new double[featureMapping.size()];
            featureMapping.forEach((localId, featureValueId) -> newFeaturesValues[localId] = featureValues.get(featureValueId));
            return VectorUtils.of(newFeaturesValues);
        });
    }
}