IGNITE-8907: [ML] Using vectors in featureExtractor
ignite.git: modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainer.java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.preprocessing.minmaxscaling;

import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;

/**
 * Trainer of the minmaxscaling preprocessor.
 *
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 */
public class MinMaxScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> {
    /** {@inheritDoc} */
    @Override public MinMaxScalerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> basePreprocessor) {
        try (Dataset<EmptyContext, MinMaxScalerPartitionData> dataset = datasetBuilder.build(
            (upstream, upstreamSize) -> new EmptyContext(),
            (upstream, upstreamSize, ctx) -> {
                // Per-partition minimum and maximum of every feature.
                double[] min = null;
                double[] max = null;

                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entity = upstream.next();
                    Vector row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                    if (min == null) {
                        min = new double[row.size()];
                        for (int i = 0; i < min.length; i++)
                            min[i] = Double.MAX_VALUE;
                    }
                    else
                        assert min.length == row.size() : "Base preprocessor must return exactly " + min.length
                            + " features";

                    if (max == null) {
                        max = new double[row.size()];
                        for (int i = 0; i < max.length; i++)
                            max[i] = -Double.MAX_VALUE;
                    }
                    else
                        assert max.length == row.size() : "Base preprocessor must return exactly " + max.length
                            + " features";

                    for (int i = 0; i < row.size(); i++) {
                        if (row.get(i) < min[i])
                            min[i] = row.get(i);
                        if (row.get(i) > max[i])
                            max[i] = row.get(i);
                    }
                }

                return new MinMaxScalerPartitionData(min, max);
            }
        )) {
            // Reduce per-partition [min, max] pairs into the global per-feature minimum and maximum.
            double[][] minMax = dataset.compute(
                data -> data.getMin() != null ? new double[][]{ data.getMin(), data.getMax() } : null,
                (a, b) -> {
                    if (a == null)
                        return b;

                    if (b == null)
                        return a;

                    double[][] res = new double[2][];

                    res[0] = new double[a[0].length];
                    for (int i = 0; i < res[0].length; i++)
                        res[0][i] = Math.min(a[0][i], b[0][i]);

                    res[1] = new double[a[1].length];
                    for (int i = 0; i < res[1].length; i++)
                        res[1][i] = Math.max(a[1][i], b[1][i]);

                    return res;
                }
            );

            return new MinMaxScalerPreprocessor<>(minMax[0], minMax[1], basePreprocessor);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
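
For context, a minimal usage sketch of this trainer, assuming a cache-backed `DatasetBuilder` and a feature extractor that returns a `Vector` (in line with IGNITE-8907). The cache name, the sample rows, and the use of `CacheBasedDatasetBuilder` and `VectorUtils.of` are illustrative assumptions, not part of this file:

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerPreprocessor;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;

public class MinMaxScalerTrainerExample {
    public static void main(String[] args) {
        try (Ignite ignite = Ignition.start()) {
            // Hypothetical upstream cache of raw double[] rows keyed by an integer id.
            IgniteCache<Integer, double[]> data = ignite.getOrCreateCache("data");
            data.put(1, new double[] {2., 4., 1.});
            data.put(2, new double[] {1., 8., 22.});
            data.put(3, new double[] {4., 10., 100.});

            // Feature extractor that returns a Vector.
            IgniteBiFunction<Integer, double[], Vector> featureExtractor = (k, v) -> VectorUtils.of(v);

            // Fit the preprocessor; the result is itself an IgniteBiFunction<K, V, Vector>
            // and can be passed as the feature extractor of a downstream trainer.
            MinMaxScalerPreprocessor<Integer, double[]> preprocessor = new MinMaxScalerTrainer<Integer, double[]>()
                .fit(new CacheBasedDatasetBuilder<>(ignite, data), featureExtractor);

            // Each feature of the returned vector is scaled into [0, 1] using the learned min/max.
            Vector scaled = preprocessor.apply(1, data.get(1));
            System.out.println(scaled);
        }
    }
}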