IGNITE-8169:
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / clustering / KMeansTrainerTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.clustering;
19
20 import java.util.Arrays;
21 import java.util.HashMap;
22 import java.util.Map;
23 import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
24 import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
25 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
26 import org.apache.ignite.ml.math.Vector;
27 import org.apache.ignite.ml.math.distances.EuclideanDistance;
28 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
29 import org.junit.Test;
30
31 import static org.junit.Assert.assertEquals;
32
33 /**
34 * Tests for {@link KMeansTrainer}.
35 */
36 public class KMeansTrainerTest {
37 /** Precision in test checks. */
38 private static final double PRECISION = 1e-2;
39
40 /**
41 * A few points, one cluster, one iteration
42 */
43 @Test
44 public void findOneClusters() {
45
46 Map<Integer, double[]> data = new HashMap<>();
47 data.put(0, new double[] {1.0, 1.0, 1.0});
48 data.put(1, new double[] {1.0, 2.0, 1.0});
49 data.put(2, new double[] {2.0, 1.0, 1.0});
50 data.put(3, new double[] {-1.0, -1.0, 2.0});
51 data.put(4, new double[] {-1.0, -2.0, 2.0});
52 data.put(5, new double[] {-2.0, -1.0, 2.0});
53
54 KMeansTrainer trainer = new KMeansTrainer()
55 .withDistance(new EuclideanDistance())
56 .withK(1)
57 .withMaxIterations(1)
58 .withEpsilon(PRECISION);
59
60 KMeansModel knnMdl = trainer.fit(
61 new LocalDatasetBuilder<>(data, 2),
62 (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
63 (k, v) -> v[2]
64 );
65
66 Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0});
67 assertEquals(knnMdl.apply(firstVector), 0.0, PRECISION);
68 Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0});
69 assertEquals(knnMdl.apply(secondVector), 0.0, PRECISION);
70 assertEquals(trainer.getMaxIterations(), 1);
71 assertEquals(trainer.getEpsilon(), PRECISION, PRECISION);
72 }
73 }