IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / clustering / kmeans / KMeansTrainer.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.clustering.kmeans;
19
20 import java.util.ArrayList;
21 import java.util.List;
22 import java.util.Random;
23 import java.util.concurrent.ConcurrentHashMap;
24 import java.util.stream.Collectors;
25 import java.util.stream.Stream;
26 import org.apache.ignite.lang.IgniteBiTuple;
27 import org.apache.ignite.ml.dataset.Dataset;
28 import org.apache.ignite.ml.dataset.DatasetBuilder;
29 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
30 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
31 import org.apache.ignite.ml.math.Vector;
32 import org.apache.ignite.ml.math.VectorUtils;
33 import org.apache.ignite.ml.math.distances.DistanceMeasure;
34 import org.apache.ignite.ml.math.distances.EuclideanDistance;
35 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
36 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
37 import org.apache.ignite.ml.math.util.MapUtil;
38 import org.apache.ignite.ml.structures.LabeledDataset;
39 import org.apache.ignite.ml.structures.LabeledVector;
40 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
41 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
42
43 /**
44 * The trainer for KMeans algorithm.
45 */
46 public class KMeansTrainer implements SingleLabelDatasetTrainer<KMeansModel> {
47 /** Amount of clusters. */
48 private int k = 2;
49
50 /** Amount of iterations. */
51 private int maxIterations = 10;
52
53 /** Delta of convergence. */
54 private double epsilon = 1e-4;
55
56 /** Distance measure. */
57 private DistanceMeasure distance = new EuclideanDistance();
58
59 /** KMeans initializer. */
60 private long seed;
61
62 /**
63 * Trains model based on the specified data.
64 *
65 * @param datasetBuilder Dataset builder.
66 * @param featureExtractor Feature extractor.
67 * @param lbExtractor Label extractor.
68 * @return Model.
69 */
70 @Override public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder,
71 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
72 assert datasetBuilder != null;
73
74 PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
75 featureExtractor,
76 lbExtractor
77 );
78
79 Vector[] centers;
80
81 try (Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build(
82 (upstream, upstreamSize) -> new EmptyContext(),
83 partDataBuilder
84 )) {
85 final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> a == null ? b : a);
86 centers = initClusterCentersRandomly(dataset, k);
87
88 boolean converged = false;
89 int iteration = 0;
90
91 while (iteration < maxIterations && !converged) {
92 Vector[] newCentroids = new DenseLocalOnHeapVector[k];
93
94 TotalCostAndCounts totalRes = calcDataForNewCentroids(centers, dataset, cols);
95
96 converged = true;
97
98 for (Integer ind : totalRes.sums.keySet()) {
99 Vector massCenter = totalRes.sums.get(ind).times(1.0 / totalRes.counts.get(ind));
100
101 if (converged && distance.compute(massCenter, centers[ind]) > epsilon * epsilon)
102 converged = false;
103
104 newCentroids[ind] = massCenter;
105 }
106
107 iteration++;
108 centers = newCentroids;
109 }
110 }
111 catch (Exception e) {
112 throw new RuntimeException(e);
113 }
114 return new KMeansModel(centers, distance);
115 }
116
117 /**
118 * Prepares the data to define new centroids on current iteration.
119 *
120 * @param centers Current centers on the current iteration.
121 * @param dataset Dataset.
122 * @param cols Amount of columns.
123 * @return Helper data to calculate the new centroids.
124 */
125 private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers,
126 Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset, int cols) {
127 final Vector[] finalCenters = centers;
128
129 return dataset.compute(data -> {
130
131 TotalCostAndCounts res = new TotalCostAndCounts();
132
133 for (int i = 0; i < data.rowSize(); i++) {
134 final IgniteBiTuple<Integer, Double> closestCentroid = findClosestCentroid(finalCenters, data.getRow(i));
135
136 int centroidIdx = closestCentroid.get1();
137
138 data.setLabel(i, centroidIdx);
139
140 res.totalCost += closestCentroid.get2();
141 res.sums.putIfAbsent(centroidIdx, VectorUtils.zeroes(cols));
142
143 int finalI = i;
144 res.sums.compute(centroidIdx,
145 (IgniteBiFunction<Integer, Vector, Vector>)(ind, v) -> v.plus(data.getRow(finalI).features()));
146
147 res.counts.merge(centroidIdx, 1,
148 (IgniteBiFunction<Integer, Integer, Integer>)(i1, i2) -> i1 + i2);
149 }
150 return res;
151 }, (a, b) -> a == null ? b : a.merge(b));
152 }
153
154 /**
155 * Find the closest cluster center index and distance to it from a given point.
156 *
157 * @param centers Centers to look in.
158 * @param pnt Point.
159 */
160 private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) {
161 double bestDistance = Double.POSITIVE_INFINITY;
162 int bestInd = 0;
163
164 for (int i = 0; i < centers.length; i++) {
165 double dist = distance.compute(centers[i], pnt.features());
166 if (dist < bestDistance) {
167 bestDistance = dist;
168 bestInd = i;
169 }
170 }
171 return new IgniteBiTuple<>(bestInd, bestDistance);
172 }
173
174 /**
175 * K cluster centers are initialized randomly.
176 *
177 * @param dataset The dataset to pick up random centers.
178 * @param k Amount of clusters.
179 * @return K cluster centers.
180 */
181 private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset,
182 int k) {
183
184 Vector[] initCenters = new DenseLocalOnHeapVector[k];
185
186 List<LabeledVector> rndPnts = dataset.compute(data -> {
187 List<LabeledVector> rndPnt = new ArrayList<>();
188 rndPnt.add(data.getRow(new Random(seed).nextInt(data.rowSize())));
189 return rndPnt;
190 }, (a, b) -> a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()));
191
192 for (int i = 0; i < k; i++) {
193 final LabeledVector rndPnt = rndPnts.get(new Random(seed).nextInt(rndPnts.size()));
194 rndPnts.remove(rndPnt);
195 initCenters[i] = rndPnt.features();
196 }
197
198 return initCenters;
199 }
200
201 /** Service class used for statistics. */
202 private static class TotalCostAndCounts {
203 /** */
204 double totalCost;
205
206 /** */
207 ConcurrentHashMap<Integer, Vector> sums = new ConcurrentHashMap<>();
208
209 /** Count of points closest to the center with a given index. */
210 ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();
211
212 /** Merge current */
213 TotalCostAndCounts merge(TotalCostAndCounts other) {
214 this.totalCost += totalCost;
215 this.sums = MapUtil.mergeMaps(sums, other.sums, Vector::plus, ConcurrentHashMap::new);
216 this.counts = MapUtil.mergeMaps(counts, other.counts, (i1, i2) -> i1 + i2, ConcurrentHashMap::new);
217 return this;
218 }
219 }
220
221 /**
222 * Gets the amount of clusters.
223 *
224 * @return The parameter value.
225 */
226 public int getK() {
227 return k;
228 }
229
230 /**
231 * Set up the amount of clusters.
232 *
233 * @param k The parameter value.
234 * @return Model with new amount of clusters parameter value.
235 */
236 public KMeansTrainer withK(int k) {
237 this.k = k;
238 return this;
239 }
240
241 /**
242 * Gets the max number of iterations before convergence.
243 *
244 * @return The parameter value.
245 */
246 public int getMaxIterations() {
247 return maxIterations;
248 }
249
250 /**
251 * Set up the max number of iterations before convergence.
252 *
253 * @param maxIterations The parameter value.
254 * @return Model with new max number of iterations before convergence parameter value.
255 */
256 public KMeansTrainer withMaxIterations(int maxIterations) {
257 this.maxIterations = maxIterations;
258 return this;
259 }
260
261 /**
262 * Gets the epsilon.
263 *
264 * @return The parameter value.
265 */
266 public double getEpsilon() {
267 return epsilon;
268 }
269
270 /**
271 * Set up the epsilon.
272 *
273 * @param epsilon The parameter value.
274 * @return Model with new epsilon parameter value.
275 */
276 public KMeansTrainer withEpsilon(double epsilon) {
277 this.epsilon = epsilon;
278 return this;
279 }
280
281 /**
282 * Gets the distance.
283 *
284 * @return The parameter value.
285 */
286 public DistanceMeasure getDistance() {
287 return distance;
288 }
289
290 /**
291 * Set up the distance.
292 *
293 * @param distance The parameter value.
294 * @return Model with new distance parameter value.
295 */
296 public KMeansTrainer withDistance(DistanceMeasure distance) {
297 this.distance = distance;
298 return this;
299 }
300
301 /**
302 * Gets the seed number.
303 *
304 * @return The parameter value.
305 */
306 public long getSeed() {
307 return seed;
308 }
309
310 /**
311 * Set up the seed.
312 *
313 * @param seed The parameter value.
314 * @return Model with new seed parameter value.
315 */
316 public KMeansTrainer withSeed(long seed) {
317 this.seed = seed;
318 return this;
319 }
320 }