IGNITE-8169:
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / clustering / KMeansModelTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.clustering;
19
20 import org.apache.ignite.ml.TestUtils;
21 import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
22 import org.apache.ignite.ml.math.Vector;
23 import org.apache.ignite.ml.math.distances.DistanceMeasure;
24 import org.apache.ignite.ml.math.distances.EuclideanDistance;
25 import org.apache.ignite.ml.math.exceptions.CardinalityException;
26 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
27 import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
28 import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
29 import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
30 import org.junit.Assert;
31 import org.junit.Test;
32
33 /**
34 * Tests for {@link KMeansModel}.
35 */
36 public class KMeansModelTest {
37 /** Precision in test checks. */
38 private static final double PRECISION = 1e-6;
39
40 /** */
41 @Test
42 public void predictClusters() {
43 DistanceMeasure distanceMeasure = new EuclideanDistance();
44
45 Vector[] centers = new DenseLocalOnHeapVector[4];
46
47 centers[0] = new DenseLocalOnHeapVector(new double[]{1.0, 1.0});
48 centers[1] = new DenseLocalOnHeapVector(new double[]{-1.0, 1.0});
49 centers[2] = new DenseLocalOnHeapVector(new double[]{1.0, -1.0});
50 centers[3] = new DenseLocalOnHeapVector(new double[]{-1.0, -1.0});
51
52 KMeansModel mdl = new KMeansModel(centers, distanceMeasure);
53
54 Assert.assertEquals(mdl.apply(new DenseLocalOnHeapVector(new double[]{1.1, 1.1})), 0.0, PRECISION);
55 Assert.assertEquals(mdl.apply(new DenseLocalOnHeapVector(new double[]{-1.1, 1.1})), 1.0, PRECISION);
56 Assert.assertEquals(mdl.apply(new DenseLocalOnHeapVector(new double[]{1.1, -1.1})), 2.0, PRECISION);
57 Assert.assertEquals(mdl.apply(new DenseLocalOnHeapVector(new double[]{-1.1, -1.1})), 3.0, PRECISION);
58
59 Assert.assertEquals(mdl.distanceMeasure(), distanceMeasure);
60 Assert.assertEquals(mdl.amountOfClusters(), 4);
61 Assert.assertArrayEquals(mdl.centers(), centers);
62 }
63 }