Use bytes instead of Any in RunnerApi.FunctionSpec
beam.git: sdks/python/apache_beam/transforms/core.py
1 #
2 # Licensed to the Apache Software Foundation (ASF) under one or more
3 # contributor license agreements. See the NOTICE file distributed with
4 # this work for additional information regarding copyright ownership.
5 # The ASF licenses this file to You under the Apache License, Version 2.0
6 # (the "License"); you may not use this file except in compliance with
7 # the License. You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16 #
17
18 """Core PTransform subclasses, such as FlatMap, GroupByKey, and Map."""
19
20 from __future__ import absolute_import
21
22 import copy
23 import inspect
24 import types
25
26 from google.protobuf import wrappers_pb2
27
28 from apache_beam import pvalue
29 from apache_beam import typehints
30 from apache_beam import coders
31 from apache_beam.coders import typecoders
32 from apache_beam.internal import pickler
33 from apache_beam.internal import util
34 from apache_beam.portability.api import beam_runner_api_pb2
35 from apache_beam.transforms import ptransform
36 from apache_beam.transforms.display import DisplayDataItem
37 from apache_beam.transforms.display import HasDisplayData
38 from apache_beam.transforms.ptransform import PTransform
39 from apache_beam.transforms.ptransform import PTransformWithSideInputs
40 from apache_beam.transforms.window import MIN_TIMESTAMP
41 from apache_beam.transforms.window import TimestampCombiner
42 from apache_beam.transforms.window import WindowedValue
43 from apache_beam.transforms.window import TimestampedValue
44 from apache_beam.transforms.window import GlobalWindows
45 from apache_beam.transforms.window import WindowFn
46 from apache_beam.typehints import Any
47 from apache_beam.typehints import Iterable
48 from apache_beam.typehints import KV
49 from apache_beam.typehints import trivial_inference
50 from apache_beam.typehints import Union
51 from apache_beam.typehints.decorators import get_type_hints
52 from apache_beam.typehints.decorators import TypeCheckError
53 from apache_beam.typehints.decorators import WithTypeHints
54 from apache_beam.typehints.trivial_inference import element_type
55 from apache_beam.typehints.typehints import is_consistent_with
56 from apache_beam.utils import proto_utils
57 from apache_beam.utils import urns
58 from apache_beam.options.pipeline_options import TypeOptions
59
60
61 __all__ = [
62 'DoFn',
63 'CombineFn',
64 'PartitionFn',
65 'ParDo',
66 'FlatMap',
67 'Map',
68 'Filter',
69 'CombineGlobally',
70 'CombinePerKey',
71 'CombineValues',
72 'GroupByKey',
73 'Partition',
74 'Windowing',
75 'WindowInto',
76 'Flatten',
77 'Create',
78 ]
79
80
81 # Type variables
82 T = typehints.TypeVariable('T')
83 K = typehints.TypeVariable('K')
84 V = typehints.TypeVariable('V')
85
86
87 class DoFnContext(object):
88 """A context available to all methods of DoFn instance."""
89 pass
90
91
92 class DoFnProcessContext(DoFnContext):
93 """A processing context passed to DoFn process() during execution.
94
95 Most importantly, a DoFn.process method will access context.element
96 to get the element it is supposed to process.
97
98 Attributes:
99 label: label of the ParDo whose element is being processed.
100 element: element being processed
101 (in process method only; always None in start_bundle and finish_bundle)
102 timestamp: timestamp of the element
103 (in process method only; always None in start_bundle and finish_bundle)
104 windows: windows of the element
105 (in process method only; always None in start_bundle and finish_bundle)
106 state: a DoFnState object, which holds the runner's internal state
107 for this element.
108 Not used by the pipeline code.
109 """
110
111 def __init__(self, label, element=None, state=None):
112 """Initialize a processing context object with an element and state.
113
114 The element represents one value from a PCollection that will be accessed
115 by a DoFn object during pipeline execution, and state is an arbitrary object
116 where counters and other pipeline state information can be passed in.
117
118 DoFnProcessContext objects are also used as inputs to PartitionFn instances.
119
120 Args:
121 label: label of the PCollection whose element is being processed.
122 element: element of a PCollection being processed using this context.
123 state: a DoFnState object with state to be passed in to the DoFn object.
124 """
125 self.label = label
126 self.state = state
127 if element is not None:
128 self.set_element(element)
129
130 def set_element(self, windowed_value):
131 if windowed_value is None:
132 # Not currently processing an element.
133 if hasattr(self, 'element'):
134 del self.element
135 del self.timestamp
136 del self.windows
137 else:
138 self.element = windowed_value.value
139 self.timestamp = windowed_value.timestamp
140 self.windows = windowed_value.windows
141
142
143 class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
144 """A function object used by a transform with custom processing.
145
146 The ParDo transform is such a transform. The ParDo.apply
147 method will take an object of type DoFn and apply it to all elements of a
148 PCollection object.
149
150 In order to have concrete DoFn objects one has to subclass from DoFn and
151 define the desired behavior (start_bundle/finish_bundle and process) or wrap a
152 callable object using the CallableWrapperDoFn class.
153 """
154
155 ElementParam = 'ElementParam'
156 SideInputParam = 'SideInputParam'
157 TimestampParam = 'TimestampParam'
158 WindowParam = 'WindowParam'
159
160 DoFnParams = [ElementParam, SideInputParam, TimestampParam, WindowParam]
161
162 @staticmethod
163 def from_callable(fn):
164 return CallableWrapperDoFn(fn)
165
166 def default_label(self):
167 return self.__class__.__name__
168
169 def process(self, element, *args, **kwargs):
170 """Called for each element of a pipeline. The default arguments are needed
171 for the DoFnRunner to be able to pass the parameters correctly.
172
173 Args:
174 element: The element to be processed
175 *args: side inputs
176 **kwargs: keyword side inputs
177 """
178 raise NotImplementedError
179
180 def start_bundle(self):
181 """Called before a bundle of elements is processed on a worker.
182
183 Elements to be processed are split into bundles and distributed
184 to workers. Before a worker calls process() on the first element
185 of its bundle, it calls this method.
186 """
187 pass
188
189 def finish_bundle(self):
190 """Called after a bundle of elements is processed on a worker.
191 """
192 pass
193
194 def get_function_arguments(self, func):
195 """Return the function arguments based on the name provided. If they have
196 a _inspect_function attached to the class then use that otherwise default
197 to the python inspect library.
198 """
199 func_name = '_inspect_%s' % func
200 if hasattr(self, func_name):
201 f = getattr(self, func_name)
202 return f()
203 f = getattr(self, func)
204 return inspect.getargspec(f)
205
206 # TODO(sourabhbajaj): Do we want to remove the responsibility of these from
207 # the DoFn or maybe the runner?
208 def infer_output_type(self, input_type):
209 # TODO(robertwb): Side inputs types.
210 # TODO(robertwb): Assert compatibility with input type hint?
211 return self._strip_output_annotations(
212 trivial_inference.infer_return_type(self.process, [input_type]))
213
214 def _strip_output_annotations(self, type_hint):
215 annotations = (TimestampedValue, WindowedValue, pvalue.TaggedOutput)
216 # TODO(robertwb): These should be parameterized types that the
217 # type inferencer understands.
218 if (type_hint in annotations
219 or trivial_inference.element_type(type_hint) in annotations):
220 return Any
221 return type_hint
222
223 def _process_argspec_fn(self):
224 """Returns the Python callable that will eventually be invoked.
225
226 This should ideally be the user-level function that is called with
227 the main and (if any) side inputs, and is used to relate the type
228 hint parameters with the input parameters (e.g., by argument name).
229 """
230 return self.process
231
232 def is_process_bounded(self):
233 """Checks if an object is a bound method on an instance."""
234 if not isinstance(self.process, types.MethodType):
235 return False # Not a method
236 if self.process.im_self is None:
237 return False # Method is not bound
238 if issubclass(self.process.im_class, type) or \
239 self.process.im_class is types.ClassType:
240 return False # Method is a classmethod
241 return True
242
243 urns.RunnerApiFn.register_pickle_urn(urns.PICKLED_DO_FN)
244
245
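# Example sketch for the DoFn/ParDo API above (illustrative only; assumes the
# package is importable as `apache_beam` and the names below are placeholders):
# a DoFn subclass may yield any number of outputs per input element.
import apache_beam as beam

class SplitWordsFn(beam.DoFn):
  def process(self, element):
    for word in element.split():
      yield word

p = beam.Pipeline()
words = p | beam.Create(['to be', 'or not']) | beam.ParDo(SplitWordsFn())
p.run()
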
246 def _fn_takes_side_inputs(fn):
247 try:
248 argspec = inspect.getargspec(fn)
249 except TypeError:
250 # We can't tell; maybe it does.
251 return True
252 is_bound = isinstance(fn, types.MethodType) and fn.im_self is not None
253 return len(argspec.args) > 1 + is_bound or argspec.varargs or argspec.keywords
254
255
256 class CallableWrapperDoFn(DoFn):
257 """For internal use only; no backwards-compatibility guarantees.
258
259 A DoFn (function) object wrapping a callable object.
260
261 The purpose of this class is to conveniently wrap simple functions and use
262 them in transforms.
263 """
264
265 def __init__(self, fn):
266 """Initializes a CallableWrapperDoFn object wrapping a callable.
267
268 Args:
269 fn: A callable object.
270
271 Raises:
272 TypeError: if fn parameter is not a callable type.
273 """
274 if not callable(fn):
275 raise TypeError('Expected a callable object instead of: %r' % fn)
276
277 self._fn = fn
278 if isinstance(fn, (
279 types.BuiltinFunctionType, types.MethodType, types.FunctionType)):
280 self.process = fn
281 else:
282 # For cases such as set / list where fn is callable but not a function
283 self.process = lambda element: fn(element)
284
285 super(CallableWrapperDoFn, self).__init__()
286
287 def display_data(self):
288 # If the callable has a name, then it's likely a function, and
289 # we show its name.
290 # Otherwise, it might be an instance of a callable class. We
291 # show its class.
292 display_data_value = (self._fn.__name__ if hasattr(self._fn, '__name__')
293 else self._fn.__class__)
294 return {'fn': DisplayDataItem(display_data_value,
295 label='Transform Function')}
296
297 def __repr__(self):
298 return 'CallableWrapperDoFn(%s)' % self._fn
299
300 def default_type_hints(self):
301 type_hints = get_type_hints(self._fn)
302 # If the fn was a DoFn annotated with a type-hint that hinted a return
303 # type compatible with Iterable[Any], then we strip off the outer
304 # container type due to the 'flatten' portion of FlatMap.
305 # TODO(robertwb): Should we require an iterable specification for FlatMap?
306 if type_hints.output_types:
307 args, kwargs = type_hints.output_types
308 if len(args) == 1 and is_consistent_with(args[0], Iterable[Any]):
309 type_hints = type_hints.copy()
310 type_hints.set_output_types(element_type(args[0]), **kwargs)
311 return type_hints
312
313 def infer_output_type(self, input_type):
314 return self._strip_output_annotations(
315 trivial_inference.infer_return_type(self._fn, [input_type]))
316
317 def _process_argspec_fn(self):
318 return getattr(self._fn, '_argspec_fn', self._fn)
319
320
321 class CombineFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
322 """A function object used by a Combine transform with custom processing.
323
324 A CombineFn specifies how multiple values in all or part of a PCollection can
325 be merged into a single value---essentially providing the same kind of
326 information as the arguments to the Python "reduce" builtin (except for the
327 input argument, which is an instance of CombineFnProcessContext). The
328 combining process proceeds as follows:
329
330 1. Input values are partitioned into one or more batches.
331 2. For each batch, the create_accumulator method is invoked to create a fresh
332 initial "accumulator" value representing the combination of zero values.
333 3. For each input value in the batch, the add_input method is invoked to
334 combine more values with the accumulator for that batch.
335 4. The merge_accumulators method is invoked to combine accumulators from
336 separate batches into a single combined output accumulator value, once all
337 of the accumulators have had all the input values in their batches added to
338 them. This operation is invoked repeatedly, until there is only one
339 accumulator value left.
340 5. The extract_output operation is invoked on the final accumulator to get
341 the output value.
342 """
343
344 def default_label(self):
345 return self.__class__.__name__
346
347 def create_accumulator(self, *args, **kwargs):
348 """Return a fresh, empty accumulator for the combine operation.
349
350 Args:
351 *args: Additional arguments and side inputs.
352 **kwargs: Additional arguments and side inputs.
353 """
354 raise NotImplementedError(str(self))
355
356 def add_input(self, accumulator, element, *args, **kwargs):
357 """Return result of folding element into accumulator.
358
359 CombineFn implementors must override add_input.
360
361 Args:
362 accumulator: the current accumulator
363 element: the element to add
364 *args: Additional arguments and side inputs.
365 **kwargs: Additional arguments and side inputs.
366 """
367 raise NotImplementedError(str(self))
368
369 def add_inputs(self, accumulator, elements, *args, **kwargs):
370 """Returns the result of folding each element in elements into accumulator.
371
372 This is provided in case the implementation affords more efficient
373 bulk addition of elements. The default implementation simply loops
374 over the inputs invoking add_input for each one.
375
376 Args:
377 accumulator: the current accumulator
378 elements: the elements to add
379 *args: Additional arguments and side inputs.
380 **kwargs: Additional arguments and side inputs.
381 """
382 for element in elements:
383 accumulator = self.add_input(accumulator, element, *args, **kwargs)
384 return accumulator
385
386 def merge_accumulators(self, accumulators, *args, **kwargs):
387 """Returns the result of merging several accumulators
388 into a single accumulator value.
389
390 Args:
391 accumulators: the accumulators to merge
392 *args: Additional arguments and side inputs.
393 **kwargs: Additional arguments and side inputs.
394 """
395 raise NotImplementedError(str(self))
396
397 def extract_output(self, accumulator, *args, **kwargs):
398 """Return result of converting accumulator into the output value.
399
400 Args:
401 accumulator: the final accumulator value computed by this CombineFn
402 for the entire input key or PCollection.
403 *args: Additional arguments and side inputs.
404 **kwargs: Additional arguments and side inputs.
405 """
406 raise NotImplementedError(str(self))
407
408 def apply(self, elements, *args, **kwargs):
409 """Returns result of applying this CombineFn to the input values.
410
411 Args:
412 elements: the set of values to combine.
413 *args: Additional arguments and side inputs.
414 **kwargs: Additional arguments and side inputs.
415 """
416 return self.extract_output(
417 self.add_inputs(
418 self.create_accumulator(*args, **kwargs), elements,
419 *args, **kwargs),
420 *args, **kwargs)
421
422 def for_input_type(self, input_type):
423 """Returns a specialized implementation of self, if it exists.
424
425 Otherwise, returns self.
426
427 Args:
428 input_type: the type of input elements.
429 """
430 return self
431
432 @staticmethod
433 def from_callable(fn):
434 return CallableWrapperCombineFn(fn)
435
436 @staticmethod
437 def maybe_from_callable(fn):
438 return fn if isinstance(fn, CombineFn) else CallableWrapperCombineFn(fn)
439
440 def get_accumulator_coder(self):
441 return coders.registry.get_coder(object)
442
443 urns.RunnerApiFn.register_pickle_urn(urns.PICKLED_COMBINE_FN)
444
445
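# Example sketch for the CombineFn lifecycle documented above (illustrative
# only; MeanFn is a placeholder name): create_accumulator, add_input,
# merge_accumulators, and extract_output together compute an average.
import apache_beam as beam

class MeanFn(beam.CombineFn):
  def create_accumulator(self):
    return (0.0, 0)  # (running sum, element count)

  def add_input(self, accumulator, element):
    total, count = accumulator
    return total + element, count + 1

  def merge_accumulators(self, accumulators):
    totals, counts = zip(*accumulators)
    return sum(totals), sum(counts)

  def extract_output(self, accumulator):
    total, count = accumulator
    return total / count if count else float('NaN')

# Typically applied as: pcoll | beam.CombineGlobally(MeanFn())
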
446 class CallableWrapperCombineFn(CombineFn):
447 """For internal use only; no backwards-compatibility guarantees.
448
449 A CombineFn (function) object wrapping a callable object.
450
451 The purpose of this class is to conveniently wrap simple functions and use
452 them in Combine transforms.
453 """
454 _EMPTY = object()
455
456 def __init__(self, fn):
457 """Initializes a CallableFn object wrapping a callable.
458
459 Args:
460 fn: A callable object that reduces elements of an iterable to a single
461 value (like the builtins sum and max). This callable must be capable of
462 receiving the kind of values it generates as output in its input, and
463 for best results, its operation must be commutative and associative.
464
465 Raises:
466 TypeError: if fn parameter is not a callable type.
467 """
468 if not callable(fn):
469 raise TypeError('Expected a callable object instead of: %r' % fn)
470
471 super(CallableWrapperCombineFn, self).__init__()
472 self._fn = fn
473
474 def display_data(self):
475 return {'fn_dd': self._fn}
476
477 def __repr__(self):
478 return "CallableWrapperCombineFn(%s)" % self._fn
479
480 def create_accumulator(self, *args, **kwargs):
481 return self._EMPTY
482
483 def add_input(self, accumulator, element, *args, **kwargs):
484 if accumulator is self._EMPTY:
485 return element
486 return self._fn([accumulator, element], *args, **kwargs)
487
488 def add_inputs(self, accumulator, elements, *args, **kwargs):
489 if accumulator is self._EMPTY:
490 return self._fn(elements, *args, **kwargs)
491 elif isinstance(elements, (list, tuple)):
492 return self._fn([accumulator] + list(elements), *args, **kwargs)
493
494 def union():
495 yield accumulator
496 for e in elements:
497 yield e
498 return self._fn(union(), *args, **kwargs)
499
500 def merge_accumulators(self, accumulators, *args, **kwargs):
501 # It's (weakly) assumed that self._fn is associative.
502 return self._fn(accumulators, *args, **kwargs)
503
504 def extract_output(self, accumulator, *args, **kwargs):
505 return self._fn(()) if accumulator is self._EMPTY else accumulator
506
507 def default_type_hints(self):
508 fn_hints = get_type_hints(self._fn)
509 if fn_hints.input_types is None:
510 return fn_hints
511 else:
512 # fn(Iterable[V]) -> V becomes CombineFn(V) -> V
513 input_args, input_kwargs = fn_hints.input_types
514 if not input_args:
515 if len(input_kwargs) == 1:
516 input_args, input_kwargs = tuple(input_kwargs.values()), {}
517 else:
518 raise TypeError('Combiner input type must be specified positionally.')
519 if not is_consistent_with(input_args[0], Iterable[Any]):
520 raise TypeCheckError(
521 'All functions for a Combine PTransform must accept a '
522 'single argument compatible with: Iterable[Any]. '
523 'Instead a function with input type: %s was received.'
524 % input_args[0])
525 input_args = (element_type(input_args[0]),) + input_args[1:]
526 # TODO(robertwb): Assert output type is consistent with input type?
527 hints = fn_hints.copy()
528 hints.set_input_types(*input_args, **input_kwargs)
529 return hints
530
531 def for_input_type(self, input_type):
532 # Avoid circular imports.
533 from apache_beam.transforms import cy_combiners
534 if self._fn is any:
535 return cy_combiners.AnyCombineFn()
536 elif self._fn is all:
537 return cy_combiners.AllCombineFn()
538 else:
539 known_types = {
540 (sum, int): cy_combiners.SumInt64Fn(),
541 (min, int): cy_combiners.MinInt64Fn(),
542 (max, int): cy_combiners.MaxInt64Fn(),
543 (sum, float): cy_combiners.SumFloatFn(),
544 (min, float): cy_combiners.MinFloatFn(),
545 (max, float): cy_combiners.MaxFloatFn(),
546 }
547 return known_types.get((self._fn, input_type), self)
548
549
550 class PartitionFn(WithTypeHints):
551 """A function object used by a Partition transform.
552
553 A PartitionFn specifies how individual values in a PCollection will be placed
554 into separate partitions, indexed by an integer.
555 """
556
557 def default_label(self):
558 return self.__class__.__name__
559
560 def partition_for(self, element, num_partitions, *args, **kwargs):
561 """Specify which partition will receive this element.
562
563 Args:
564 element: An element of the input PCollection.
565 num_partitions: Number of partitions, i.e., output PCollections.
566 *args: optional parameters and side inputs.
567 **kwargs: optional parameters and side inputs.
568
569 Returns:
570 An integer in [0, num_partitions).
571 """
572 pass
573
574
575 class CallableWrapperPartitionFn(PartitionFn):
576 """For internal use only; no backwards-compatibility guarantees.
577
578 A PartitionFn object wrapping a callable object.
579
580 Instances of this class wrap simple functions for use in Partition operations.
581 """
582
583 def __init__(self, fn):
584 """Initializes a PartitionFn object wrapping a callable.
585
586 Args:
587 fn: A callable object, which should accept the following arguments:
588 element - element to assign to a partition.
589 num_partitions - number of output partitions.
590 and may accept additional arguments and side inputs.
591
592 Raises:
593 TypeError: if fn is not a callable type.
594 """
595 if not callable(fn):
596 raise TypeError('Expected a callable object instead of: %r' % fn)
597 self._fn = fn
598
599 def partition_for(self, element, num_partitions, *args, **kwargs):
600 return self._fn(element, num_partitions, *args, **kwargs)
601
602
603 class ParDo(PTransformWithSideInputs):
604 """A ParDo transform.
605
606 Processes an input PCollection by applying a DoFn to each element and
607 returning the accumulated results into an output PCollection. The type of the
608 elements is not fixed as long as the DoFn can deal with it. In reality
609 the type is restrained to some extent because the elements sometimes must be
610 persisted to external storage. See the expand() method comments for a detailed
611 description of all possible arguments.
612
613 Note that the DoFn must return an iterable for each element of the input
614 PCollection. An easy way to do this is to use the yield keyword in the
615 process method.
616
617 Args:
618 pcoll: a PCollection to be processed.
619 fn: a DoFn object to be applied to each element of pcoll argument.
620 *args: positional arguments passed to the dofn object.
621 **kwargs: keyword arguments passed to the dofn object.
622
623 Note that the positional and keyword arguments will be processed in order
624 to detect PCollections that will be computed as side inputs to the
625 transform. During pipeline execution whenever the DoFn object gets executed
626 (its process() method gets called) the PCollection arguments will be replaced
627 by values from the PCollection in the exact positions where they appear in
628 the argument lists.
629 """
630
631 def __init__(self, fn, *args, **kwargs):
632 super(ParDo, self).__init__(fn, *args, **kwargs)
633 # TODO(robertwb): Change all uses of the dofn attribute to use fn instead.
634 self.dofn = self.fn
635 self.output_tags = set()
636
637 if not isinstance(self.fn, DoFn):
638 raise TypeError('ParDo must be called with a DoFn instance.')
639
640 # Validate the DoFn by creating a DoFnSignature
641 from apache_beam.runners.common import DoFnSignature
642 DoFnSignature(self.fn)
643
644 def default_type_hints(self):
645 return self.fn.get_type_hints()
646
647 def infer_output_type(self, input_type):
648 return trivial_inference.element_type(
649 self.fn.infer_output_type(input_type))
650
651 def make_fn(self, fn):
652 if isinstance(fn, DoFn):
653 return fn
654 return CallableWrapperDoFn(fn)
655
656 def _process_argspec_fn(self):
657 return self.fn._process_argspec_fn()
658
659 def display_data(self):
660 return {'fn': DisplayDataItem(self.fn.__class__,
661 label='Transform Function'),
662 'fn_dd': self.fn}
663
664 def expand(self, pcoll):
665 return pvalue.PCollection(pcoll.pipeline)
666
667 def with_outputs(self, *tags, **main_kw):
668 """Returns a tagged tuple allowing access to the outputs of a ParDo.
669
670 The resulting object supports access to the
671 PCollection associated with a tag (e.g., o.tag, o[tag]) and iterating over
672 the available tags (e.g., for tag in o: ...).
673
674 Args:
675 *tags: if non-empty, list of valid tags. If a list of valid tags is given,
676 it will be an error to use an undeclared tag later in the pipeline.
677 **main_kw: dictionary empty or with one key 'main' defining the tag to be
678 used for the main output (which will not have a tag associated with it).
679
680 Returns:
681 An object of type DoOutputsTuple that bundles together all the outputs
682 of a ParDo transform and allows accessing the individual
683 PCollections for each output using an object.tag syntax.
684
685 Raises:
686 TypeError: if the self object is not a PCollection that is the result of
687 a ParDo transform.
688 ValueError: if main_kw contains any key other than 'main'.
689 """
690 main_tag = main_kw.pop('main', None)
691 if main_kw:
692 raise ValueError('Unexpected keyword arguments: %s' % main_kw.keys())
693 return _MultiParDo(self, tags, main_tag)
694
695 def _pardo_fn_data(self):
696 si_tags_and_types = []
697 windowing = None
698 return self.fn, self.args, self.kwargs, si_tags_and_types, windowing
699
700 def to_runner_api_parameter(self, context):
701 assert self.__class__ is ParDo
702 pickled_pardo_fn_data = pickler.dumps(self._pardo_fn_data())
703 return (
704 urns.PARDO_TRANSFORM,
705 beam_runner_api_pb2.ParDoPayload(
706 do_fn=beam_runner_api_pb2.SdkFunctionSpec(
707 spec=beam_runner_api_pb2.FunctionSpec(
708 urn=urns.PICKLED_DO_FN_INFO,
709 any_param=proto_utils.pack_Any(
710 wrappers_pb2.BytesValue(
711 value=pickled_pardo_fn_data)),
712 payload=pickled_pardo_fn_data))))
713
714 @PTransform.register_urn(
715 urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload)
716 def from_runner_api_parameter(pardo_payload, context):
717 assert pardo_payload.do_fn.spec.urn == urns.PICKLED_DO_FN_INFO
718 fn, args, kwargs, si_tags_and_types, windowing = pickler.loads(
719 pardo_payload.do_fn.spec.payload)
720 if si_tags_and_types:
721 raise NotImplementedError('deferred side inputs')
722 elif windowing:
723 raise NotImplementedError('explicit windowing')
724 return ParDo(fn, *args, **kwargs)
725
726
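# Example sketch for with_outputs() above (illustrative only): a ParDo emitting
# both a tagged output and a main output, accessed by attribute on the result.
import apache_beam as beam
from apache_beam import pvalue

class SplitParityFn(beam.DoFn):
  def process(self, element):
    if element % 2 == 0:
      yield pvalue.TaggedOutput('even', element)
    else:
      yield element  # goes to the main (untagged) output

p = beam.Pipeline()
results = (p
           | beam.Create([1, 2, 3, 4])
           | beam.ParDo(SplitParityFn()).with_outputs('even', main='odd'))
evens, odds = results.even, results.odd
p.run()
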
727 class _MultiParDo(PTransform):
728
729 def __init__(self, do_transform, tags, main_tag):
730 super(_MultiParDo, self).__init__(do_transform.label)
731 self._do_transform = do_transform
732 self._tags = tags
733 self._main_tag = main_tag
734
735 def expand(self, pcoll):
736 _ = pcoll | self._do_transform
737 return pvalue.DoOutputsTuple(
738 pcoll.pipeline, self._do_transform, self._tags, self._main_tag)
739
740
741 def FlatMap(fn, *args, **kwargs): # pylint: disable=invalid-name
742 """FlatMap is like ParDo except it takes a callable to specify the
743 transformation.
744
745 The callable must return an iterable for each element of the input
746 PCollection. The elements of these iterables will be flattened into
747 the output PCollection.
748
749 Args:
750 fn: a callable object.
751 *args: positional arguments passed to the transform callable.
752 **kwargs: keyword arguments passed to the transform callable.
753
754 Returns:
755 A PCollection containing the FlatMap outputs.
756
757 Raises:
758 TypeError: If the fn passed as argument is not a callable. Typical error
759 is to pass a DoFn instance which is supported only for ParDo.
760 """
761 label = 'FlatMap(%s)' % ptransform.label_from_callable(fn)
762 if not callable(fn):
763 raise TypeError(
764 'FlatMap can be used only with callable objects. '
765 'Received %r instead.' % (fn))
766
767 pardo = ParDo(CallableWrapperDoFn(fn), *args, **kwargs)
768 pardo.label = label
769 return pardo
770
771
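# Example sketch for FlatMap above (illustrative only): the callable returns an
# iterable per element, and the iterables are flattened into the output.
import apache_beam as beam

p = beam.Pipeline()
words = (p
         | beam.Create(['a b c', 'd e'])
         | beam.FlatMap(lambda line: line.split()))
p.run()
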
772 def Map(fn, *args, **kwargs): # pylint: disable=invalid-name
773 """Map is like FlatMap except its callable returns only a single element.
774
775 Args:
776 fn: a callable object.
777 *args: positional arguments passed to the transform callable.
778 **kwargs: keyword arguments passed to the transform callable.
779
780 Returns:
781 A PCollection containing the Map outputs.
782
783 Raises:
784 TypeError: If the fn passed as argument is not a callable. Typical error
785 is to pass a DoFn instance which is supported only for ParDo.
786 """
787 if not callable(fn):
788 raise TypeError(
789 'Map can be used only with callable objects. '
790 'Received %r instead.' % (fn))
791 if _fn_takes_side_inputs(fn):
792 wrapper = lambda x, *args, **kwargs: [fn(x, *args, **kwargs)]
793 else:
794 wrapper = lambda x: [fn(x)]
795
796 label = 'Map(%s)' % ptransform.label_from_callable(fn)
797
798 # TODO. What about callable classes?
799 if hasattr(fn, '__name__'):
800 wrapper.__name__ = fn.__name__
801
802 # Proxy the type-hint information from the original function to this new
803 # wrapped function.
804 get_type_hints(wrapper).input_types = get_type_hints(fn).input_types
805 output_hint = get_type_hints(fn).simple_output_type(label)
806 if output_hint:
807 get_type_hints(wrapper).set_output_types(typehints.Iterable[output_hint])
808 # pylint: disable=protected-access
809 wrapper._argspec_fn = fn
810 # pylint: enable=protected-access
811
812 pardo = FlatMap(wrapper, *args, **kwargs)
813 pardo.label = label
814 return pardo
815
816
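# Example sketch for Map above (illustrative only): extra positional arguments
# are forwarded to the callable, with the element staying the first parameter.
import apache_beam as beam

p = beam.Pipeline()
scaled = (p
          | beam.Create([1, 2, 3])
          | beam.Map(lambda x, factor: x * factor, 10))
p.run()
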
817 def Filter(fn, *args, **kwargs): # pylint: disable=invalid-name
818 """Filter is a FlatMap with its callable filtering out elements.
819
820 Args:
821 fn: a callable object.
822 *args: positional arguments passed to the transform callable.
823 **kwargs: keyword arguments passed to the transform callable.
824
825 Returns:
826 A PCollection containing the Filter outputs.
827
828 Raises:
829 TypeError: If the fn passed as argument is not a callable. Typical error
830 is to pass a DoFn instance which is supported only for ParDo.
831 """
832 if not callable(fn):
833 raise TypeError(
834 'Filter can be used only with callable objects. '
835 'Received %r instead.' % (fn))
836 wrapper = lambda x, *args, **kwargs: [x] if fn(x, *args, **kwargs) else []
837
838 label = 'Filter(%s)' % ptransform.label_from_callable(fn)
839
840 # TODO: What about callable classes?
841 if hasattr(fn, '__name__'):
842 wrapper.__name__ = fn.__name__
843 # Proxy the type-hint information from the function being wrapped, setting the
844 # output type to be the same as the input type.
845 get_type_hints(wrapper).input_types = get_type_hints(fn).input_types
846 output_hint = get_type_hints(fn).simple_output_type(label)
847 if (output_hint is None
848 and get_type_hints(wrapper).input_types
849 and get_type_hints(wrapper).input_types[0]):
850 output_hint = get_type_hints(wrapper).input_types[0]
851 if output_hint:
852 get_type_hints(wrapper).set_output_types(typehints.Iterable[output_hint])
853 # pylint: disable=protected-access
854 wrapper._argspec_fn = fn
855 # pylint: enable=protected-access
856
857 pardo = FlatMap(wrapper, *args, **kwargs)
858 pardo.label = label
859 return pardo
860
861
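# Example sketch for Filter above (illustrative only): elements are kept only
# when the predicate returns True.
import apache_beam as beam

p = beam.Pipeline()
evens = (p
         | beam.Create([1, 2, 3, 4])
         | beam.Filter(lambda x: x % 2 == 0))
p.run()
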
862 def _combine_payload(combine_fn, context):
863 return beam_runner_api_pb2.CombinePayload(
864 combine_fn=combine_fn.to_runner_api(context),
865 accumulator_coder_id=context.coders.get_id(
866 combine_fn.get_accumulator_coder()))
867
868
869 class CombineGlobally(PTransform):
870 """A CombineGlobally transform.
871
872 Reduces a PCollection to a single value by progressively applying a CombineFn
873 to portions of the PCollection (and to intermediate values created thereby).
874 See documentation in CombineFn for details on the specifics on how CombineFns
875 are applied.
876
877 Args:
878 pcoll: a PCollection to be reduced into a single value.
879 fn: a CombineFn object that will be called to progressively reduce the
880 PCollection into single values, or a callable suitable for wrapping
881 by CallableWrapperCombineFn.
882 *args: positional arguments passed to the CombineFn object.
883 **kwargs: keyword arguments passed to the CombineFn object.
884
885 Raises:
886 TypeError: If the output type of the input PCollection is not compatible
887 with Iterable[A].
888
889 Returns:
890 A single-element PCollection containing the main output of the Combine
891 transform.
892
893 Note that the positional and keyword arguments will be processed in order
894 to detect PObjects that will be computed as side inputs to the transform.
895 During pipeline execution whenever the CombineFn object gets executed (i.e.,
896 any of the CombineFn methods get called), the PObject arguments will be
897 replaced by their actual value in the exact position where they appear in
898 the argument lists.
899 """
900 has_defaults = True
901 as_view = False
902
903 def __init__(self, fn, *args, **kwargs):
904 if not (isinstance(fn, CombineFn) or callable(fn)):
905 raise TypeError(
906 'CombineGlobally can be used only with combineFn objects. '
907 'Received %r instead.' % (fn))
908
909 super(CombineGlobally, self).__init__()
910 self.fn = fn
911 self.args = args
912 self.kwargs = kwargs
913
914 def display_data(self):
915 return {'combine_fn':
916 DisplayDataItem(self.fn.__class__, label='Combine Function'),
917 'combine_fn_dd':
918 self.fn}
919
920 def default_label(self):
921 return 'CombineGlobally(%s)' % ptransform.label_from_callable(self.fn)
922
923 def _clone(self, **extra_attributes):
924 clone = copy.copy(self)
925 clone.__dict__.update(extra_attributes)
926 return clone
927
928 def with_defaults(self, has_defaults=True):
929 return self._clone(has_defaults=has_defaults)
930
931 def without_defaults(self):
932 return self.with_defaults(False)
933
934 def as_singleton_view(self):
935 return self._clone(as_view=True)
936
937 def expand(self, pcoll):
938 def add_input_types(transform):
939 type_hints = self.get_type_hints()
940 if type_hints.input_types:
941 return transform.with_input_types(type_hints.input_types[0][0])
942 return transform
943
944 combined = (pcoll
945 | 'KeyWithVoid' >> add_input_types(
946 Map(lambda v: (None, v)).with_output_types(
947 KV[None, pcoll.element_type]))
948 | 'CombinePerKey' >> CombinePerKey(
949 self.fn, *self.args, **self.kwargs)
950 | 'UnKey' >> Map(lambda (k, v): v))
951
952 if not self.has_defaults and not self.as_view:
953 return combined
954
955 if self.has_defaults:
956 combine_fn = (
957 self.fn if isinstance(self.fn, CombineFn)
958 else CombineFn.from_callable(self.fn))
959 default_value = combine_fn.apply([], *self.args, **self.kwargs)
960 else:
961 default_value = pvalue.AsSingleton._NO_DEFAULT # pylint: disable=protected-access
962 view = pvalue.AsSingleton(combined, default_value=default_value)
963 if self.as_view:
964 return view
965 else:
966 if pcoll.windowing.windowfn != GlobalWindows():
967 raise ValueError(
968 "Default values are not yet supported in CombineGlobally() if the "
969 "output PCollection is not windowed by GlobalWindows. "
970 "Instead, use CombineGlobally().without_defaults() to output "
971 "an empty PCollection if the input PCollection is empty, "
972 "or CombineGlobally().as_singleton_view() to get the default "
973 "output of the CombineFn if the input PCollection is empty.")
974
975 def typed(transform):
976 # TODO(robertwb): We should infer this.
977 if combined.element_type:
978 return transform.with_output_types(combined.element_type)
979 return transform
980 return (pcoll.pipeline
981 | 'DoOnce' >> Create([None])
982 | 'InjectDefault' >> typed(Map(lambda _, s: s, view)))
983
984
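# Example sketch for CombineGlobally above (illustrative only): a plain callable
# such as sum is wrapped by CallableWrapperCombineFn; without_defaults() makes
# the output empty (rather than the combiner default) when the input is empty.
import apache_beam as beam

p = beam.Pipeline()
total = (p
         | beam.Create([1, 2, 3])
         | beam.CombineGlobally(sum).without_defaults())
p.run()
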
985 class CombinePerKey(PTransformWithSideInputs):
986 """A per-key Combine transform.
987
988 Identifies sets of values associated with the same key in the input
989 PCollection, then applies a CombineFn to condense those sets to single
990 values. See documentation in CombineFn for details on the specifics on how
991 CombineFns are applied.
992
993 Args:
994 pcoll: input pcollection.
995 fn: instance of CombineFn to apply to all values under the same key in
996 pcoll, or a callable whose signature is ``f(iterable, *args, **kwargs)``
997 (e.g., sum, max).
998 *args: arguments and side inputs, passed directly to the CombineFn.
999 **kwargs: arguments and side inputs, passed directly to the CombineFn.
1000
1001 Returns:
1002 A PObject holding the result of the combine operation.
1003 """
1004 def display_data(self):
1005 return {'combine_fn':
1006 DisplayDataItem(self.fn.__class__, label='Combine Function'),
1007 'combine_fn_dd':
1008 self.fn}
1009
1010 def make_fn(self, fn):
1011 self._fn_label = ptransform.label_from_callable(fn)
1012 return fn if isinstance(fn, CombineFn) else CombineFn.from_callable(fn)
1013
1014 def default_label(self):
1015 return '%s(%s)' % (self.__class__.__name__, self._fn_label)
1016
1017 def _process_argspec_fn(self):
1018 return self.fn._fn # pylint: disable=protected-access
1019
1020 def expand(self, pcoll):
1021 args, kwargs = util.insert_values_in_args(
1022 self.args, self.kwargs, self.side_inputs)
1023 return pcoll | GroupByKey() | 'Combine' >> CombineValues(
1024 self.fn, *args, **kwargs)
1025
1026 def to_runner_api_parameter(self, context):
1027 return (
1028 urns.COMBINE_PER_KEY_TRANSFORM,
1029 _combine_payload(self.fn, context))
1030
1031 @PTransform.register_urn(
1032 urns.COMBINE_PER_KEY_TRANSFORM, beam_runner_api_pb2.CombinePayload)
1033 def from_runner_api_parameter(combine_payload, context):
1034 return CombinePerKey(
1035 CombineFn.from_runner_api(combine_payload.combine_fn, context))
1036
1037
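# Example sketch for CombinePerKey above (illustrative only): the CombineFn (or
# a plain callable such as sum) is applied to the values grouped under each key.
import apache_beam as beam

p = beam.Pipeline()
totals = (p
          | beam.Create([('a', 1), ('a', 2), ('b', 3)])
          | beam.CombinePerKey(sum))
p.run()
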
1038 # TODO(robertwb): Rename to CombineGroupedValues?
1039 class CombineValues(PTransformWithSideInputs):
1040
1041 def make_fn(self, fn):
1042 return fn if isinstance(fn, CombineFn) else CombineFn.from_callable(fn)
1043
1044 def expand(self, pcoll):
1045 args, kwargs = util.insert_values_in_args(
1046 self.args, self.kwargs, self.side_inputs)
1047
1048 input_type = pcoll.element_type
1049 key_type = None
1050 if input_type is not None:
1051 key_type, _ = input_type.tuple_types
1052
1053 runtime_type_check = (
1054 pcoll.pipeline._options.view_as(TypeOptions).runtime_type_check)
1055 return pcoll | ParDo(
1056 CombineValuesDoFn(key_type, self.fn, runtime_type_check),
1057 *args, **kwargs)
1058
1059 def to_runner_api_parameter(self, context):
1060 return (
1061 urns.COMBINE_GROUPED_VALUES_TRANSFORM,
1062 _combine_payload(self.fn, context))
1063
1064 @PTransform.register_urn(
1065 urns.COMBINE_GROUPED_VALUES_TRANSFORM, beam_runner_api_pb2.CombinePayload)
1066 def from_runner_api_parameter(combine_payload, context):
1067 return CombineValues(
1068 CombineFn.from_runner_api(combine_payload.combine_fn, context))
1069
1070
1071 class CombineValuesDoFn(DoFn):
1072 """DoFn for performing per-key Combine transforms."""
1073
1074 def __init__(self, input_pcoll_type, combinefn, runtime_type_check):
1075 super(CombineValuesDoFn, self).__init__()
1076 self.combinefn = combinefn
1077 self.runtime_type_check = runtime_type_check
1078
1079 def process(self, element, *args, **kwargs):
1080 # Expected elements input to this DoFn are 2-tuples of the form
1081 # (key, iter), with iter an iterable of all the values associated with key
1082 # in the input PCollection.
1083 if self.runtime_type_check:
1084 # Apply the combiner in a single operation rather than artificially
1085 # breaking it up so that output type violations manifest as TypeCheck
1086 # errors rather than type errors.
1087 return [
1088 (element[0],
1089 self.combinefn.apply(element[1], *args, **kwargs))]
1090
1091 # Add the elements into three accumulators (for testing of merge).
1092 elements = list(element[1])
1093 accumulators = []
1094 for k in range(3):
1095 if len(elements) <= k:
1096 break
1097 accumulators.append(
1098 self.combinefn.add_inputs(
1099 self.combinefn.create_accumulator(*args, **kwargs),
1100 elements[k::3],
1101 *args, **kwargs))
1102 # Merge the accumulators.
1103 accumulator = self.combinefn.merge_accumulators(
1104 accumulators, *args, **kwargs)
1105 # Convert accumulator to the final result.
1106 return [(element[0],
1107 self.combinefn.extract_output(accumulator, *args, **kwargs))]
1108
1109 def default_type_hints(self):
1110 hints = self.combinefn.get_type_hints().copy()
1111 if hints.input_types:
1112 K = typehints.TypeVariable('K')
1113 args, kwargs = hints.input_types
1114 args = (typehints.Tuple[K, typehints.Iterable[args[0]]],) + args[1:]
1115 hints.set_input_types(*args, **kwargs)
1116 else:
1117 K = typehints.Any
1118 if hints.output_types:
1119 main_output_type = hints.simple_output_type('')
1120 hints.set_output_types(typehints.Tuple[K, main_output_type])
1121 return hints
1122
1123
1124 @typehints.with_input_types(typehints.KV[K, V])
1125 @typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]])
1126 class GroupByKey(PTransform):
1127 """A group by key transform.
1128
1129 Processes an input PCollection consisting of key/value pairs represented as a
1130 tuple pair. The result is a PCollection where values having a common key are
1131 grouped together. For example (a, 1), (b, 2), (a, 3) will result in
1132 (a, [1, 3]), (b, [2]).
1133
1134 The implementation here is used only when run on the local direct runner.
1135 """
1136
1137 class ReifyWindows(DoFn):
1138
1139 def process(self, element, window=DoFn.WindowParam,
1140 timestamp=DoFn.TimestampParam):
1141 try:
1142 k, v = element
1143 except TypeError:
1144 raise TypeCheckError('Input to GroupByKey must be a PCollection with '
1145 'elements compatible with KV[A, B]')
1146
1147 return [(k, WindowedValue(v, timestamp, [window]))]
1148
1149 def infer_output_type(self, input_type):
1150 key_type, value_type = trivial_inference.key_value_types(input_type)
1151 return Iterable[KV[key_type, typehints.WindowedValue[value_type]]]
1152
1153 def expand(self, pcoll):
1154 # This code path is only used in the local direct runner. For Dataflow
1155 # runner execution, the GroupByKey transform is expanded on the service.
1156 input_type = pcoll.element_type
1157 if input_type is not None:
1158 # Initialize type-hints used below to enforce type-checking and to pass
1159 # downstream to further PTransforms.
1160 key_type, value_type = trivial_inference.key_value_types(input_type)
1161 typecoders.registry.verify_deterministic(
1162 typecoders.registry.get_coder(key_type),
1163 'GroupByKey operation "%s"' % self.label)
1164
1165 reify_output_type = KV[key_type, typehints.WindowedValue[value_type]]
1166 gbk_input_type = (
1167 KV[key_type, Iterable[typehints.WindowedValue[value_type]]])
1168 gbk_output_type = KV[key_type, Iterable[value_type]]
1169
1170 # pylint: disable=bad-continuation
1171 return (pcoll
1172 | 'ReifyWindows' >> (ParDo(self.ReifyWindows())
1173 .with_output_types(reify_output_type))
1174 | 'GroupByKey' >> (_GroupByKeyOnly()
1175 .with_input_types(reify_output_type)
1176 .with_output_types(gbk_input_type))
1177 | ('GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing)
1178 .with_input_types(gbk_input_type)
1179 .with_output_types(gbk_output_type)))
1180 else:
1181 # The input_type is None, run the default
1182 return (pcoll
1183 | 'ReifyWindows' >> ParDo(self.ReifyWindows())
1184 | 'GroupByKey' >> _GroupByKeyOnly()
1185 | 'GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing))
1186
1187 def to_runner_api_parameter(self, unused_context):
1188 return urns.GROUP_BY_KEY_TRANSFORM, None
1189
1190 @PTransform.register_urn(urns.GROUP_BY_KEY_TRANSFORM, None)
1191 def from_runner_api_parameter(unused_payload, unused_context):
1192 return GroupByKey()
1193
1194
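# Example sketch for GroupByKey above (illustrative only): the docstring's
# (a, 1), (b, 2), (a, 3) example becomes (a, [1, 3]) and (b, [2]).
import apache_beam as beam

p = beam.Pipeline()
grouped = (p
           | beam.Create([('a', 1), ('b', 2), ('a', 3)])
           | beam.GroupByKey())
p.run()
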
1195 @typehints.with_input_types(typehints.KV[K, V])
1196 @typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]])
1197 class _GroupByKeyOnly(PTransform):
1198 """A group by key transform, ignoring windows."""
1199 def infer_output_type(self, input_type):
1200 key_type, value_type = trivial_inference.key_value_types(input_type)
1201 return KV[key_type, Iterable[value_type]]
1202
1203 def expand(self, pcoll):
1204 self._check_pcollection(pcoll)
1205 return pvalue.PCollection(pcoll.pipeline)
1206
1207 def to_runner_api_parameter(self, unused_context):
1208 return urns.GROUP_BY_KEY_ONLY_TRANSFORM, None
1209
1210 @PTransform.register_urn(urns.GROUP_BY_KEY_ONLY_TRANSFORM, None)
1211 def from_runner_api_parameter(unused_payload, unused_context):
1212 return _GroupByKeyOnly()
1213
1214
1215 @typehints.with_input_types(typehints.KV[K, typehints.Iterable[V]])
1216 @typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]])
1217 class _GroupAlsoByWindow(ParDo):
1218 """The GroupAlsoByWindow transform."""
1219 def __init__(self, windowing):
1220 super(_GroupAlsoByWindow, self).__init__(
1221 _GroupAlsoByWindowDoFn(windowing))
1222 self.windowing = windowing
1223
1224 def expand(self, pcoll):
1225 self._check_pcollection(pcoll)
1226 return pvalue.PCollection(pcoll.pipeline)
1227
1228 def to_runner_api_parameter(self, context):
1229 return (
1230 urns.GROUP_ALSO_BY_WINDOW_TRANSFORM,
1231 wrappers_pb2.BytesValue(value=context.windowing_strategies.get_id(
1232 self.windowing)))
1233
1234 @PTransform.register_urn(
1235 urns.GROUP_ALSO_BY_WINDOW_TRANSFORM, wrappers_pb2.BytesValue)
1236 def from_runner_api_parameter(payload, context):
1237 return _GroupAlsoByWindow(
1238 context.windowing_strategies.get_by_id(payload.value))
1239
1240
1241 class _GroupAlsoByWindowDoFn(DoFn):
1242 # TODO(robertwb): Support combiner lifting.
1243
1244 def __init__(self, windowing):
1245 super(_GroupAlsoByWindowDoFn, self).__init__()
1246 self.windowing = windowing
1247
1248 def infer_output_type(self, input_type):
1249 key_type, windowed_value_iter_type = trivial_inference.key_value_types(
1250 input_type)
1251 value_type = windowed_value_iter_type.inner_type.inner_type
1252 return Iterable[KV[key_type, Iterable[value_type]]]
1253
1254 def start_bundle(self):
1255 # pylint: disable=wrong-import-order, wrong-import-position
1256 from apache_beam.transforms.trigger import InMemoryUnmergedState
1257 from apache_beam.transforms.trigger import create_trigger_driver
1258 # pylint: enable=wrong-import-order, wrong-import-position
1259 self.driver = create_trigger_driver(self.windowing, True)
1260 self.state_type = InMemoryUnmergedState
1261
1262 def process(self, element):
1263 k, vs = element
1264 state = self.state_type()
1265 # TODO(robertwb): Conditionally process in smaller chunks.
1266 for wvalue in self.driver.process_elements(state, vs, MIN_TIMESTAMP):
1267 yield wvalue.with_value((k, wvalue.value))
1268 while state.timers:
1269 fired = state.get_and_clear_timers()
1270 for timer_window, (name, time_domain, fire_time) in fired:
1271 for wvalue in self.driver.process_timer(
1272 timer_window, name, time_domain, fire_time, state):
1273 yield wvalue.with_value((k, wvalue.value))
1274
1275
1276 class Partition(PTransformWithSideInputs):
1277 """Split a PCollection into several partitions.
1278
1279 Uses the specified PartitionFn to separate an input PCollection into the
1280 specified number of sub-PCollections.
1281
1282 When apply()d, a Partition() PTransform requires the following:
1283
1284 Args:
1285 partitionfn: a PartitionFn, or a callable with the signature described in
1286 CallableWrapperPartitionFn.
1287 n: number of output partitions.
1288
1289 The result of this PTransform is a simple list of the output PCollections
1290 representing each of n partitions, in order.
1291 """
1292
1293 class ApplyPartitionFnFn(DoFn):
1294 """A DoFn that applies a PartitionFn."""
1295
1296 def process(self, element, partitionfn, n, *args, **kwargs):
1297 partition = partitionfn.partition_for(element, n, *args, **kwargs)
1298 if not 0 <= partition < n:
1299 raise ValueError(
1300 'PartitionFn specified out-of-bounds partition index: '
1301 '%d not in [0, %d)' % (partition, n))
1302 # Each input is directed into the output that corresponds to the
1303 # selected partition.
1304 yield pvalue.TaggedOutput(str(partition), element)
1305
1306 def make_fn(self, fn):
1307 return fn if isinstance(fn, PartitionFn) else CallableWrapperPartitionFn(fn)
1308
1309 def expand(self, pcoll):
1310 n = int(self.args[0])
1311 return pcoll | ParDo(
1312 self.ApplyPartitionFnFn(), self.fn, *self.args,
1313 **self.kwargs).with_outputs(*[str(t) for t in range(n)])
1314
1315
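# Example sketch for Partition above (illustrative only): the callable receives
# (element, num_partitions) and returns the index of the target partition.
import apache_beam as beam

p = beam.Pipeline()
partitions = (p
              | beam.Create([1, 2, 3, 4])
              | beam.Partition(lambda element, num_partitions: element % 2, 2))
evens, odds = partitions[0], partitions[1]
p.run()
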
1316 class Windowing(object):
1317
1318 def __init__(self, windowfn, triggerfn=None, accumulation_mode=None,
1319 timestamp_combiner=None):
1320 global AccumulationMode, DefaultTrigger # pylint: disable=global-variable-not-assigned
1321 # pylint: disable=wrong-import-order, wrong-import-position
1322 from apache_beam.transforms.trigger import AccumulationMode, DefaultTrigger
1323 # pylint: enable=wrong-import-order, wrong-import-position
1324 if triggerfn is None:
1325 triggerfn = DefaultTrigger()
1326 if accumulation_mode is None:
1327 if triggerfn == DefaultTrigger():
1328 accumulation_mode = AccumulationMode.DISCARDING
1329 else:
1330 raise ValueError(
1331 'accumulation_mode must be provided for non-trivial triggers')
1332 if not windowfn.get_window_coder().is_deterministic():
1333 raise ValueError(
1334 'window fn (%s) does not have a deterministic coder (%s)' % (
1335 windowfn, windowfn.get_window_coder()))
1336 self.windowfn = windowfn
1337 self.triggerfn = triggerfn
1338 self.accumulation_mode = accumulation_mode
1339 self.timestamp_combiner = (
1340 timestamp_combiner or TimestampCombiner.OUTPUT_AT_EOW)
1341 self._is_default = (
1342 self.windowfn == GlobalWindows() and
1343 self.triggerfn == DefaultTrigger() and
1344 self.accumulation_mode == AccumulationMode.DISCARDING and
1345 self.timestamp_combiner == TimestampCombiner.OUTPUT_AT_EOW)
1346
1347 def __repr__(self):
1348 return "Windowing(%s, %s, %s, %s)" % (self.windowfn, self.triggerfn,
1349 self.accumulation_mode,
1350 self.timestamp_combiner)
1351
1352 def __eq__(self, other):
1353 if type(self) == type(other):
1354 if self._is_default and other._is_default:
1355 return True
1356 return (
1357 self.windowfn == other.windowfn
1358 and self.triggerfn == other.triggerfn
1359 and self.accumulation_mode == other.accumulation_mode
1360 and self.timestamp_combiner == other.timestamp_combiner)
1361 return False
1362
1363 def is_default(self):
1364 return self._is_default
1365
1366 def to_runner_api(self, context):
1367 return beam_runner_api_pb2.WindowingStrategy(
1368 window_fn=self.windowfn.to_runner_api(context),
1369 # TODO(robertwb): Prohibit implicit multi-level merging.
1370 merge_status=(beam_runner_api_pb2.NEEDS_MERGE
1371 if self.windowfn.is_merging()
1372 else beam_runner_api_pb2.NON_MERGING),
1373 window_coder_id=context.coders.get_id(
1374 self.windowfn.get_window_coder()),
1375 trigger=self.triggerfn.to_runner_api(context),
1376 accumulation_mode=self.accumulation_mode,
1377 output_time=self.timestamp_combiner,
1378 # TODO(robertwb): Support EMIT_IF_NONEMPTY
1379 closing_behavior=beam_runner_api_pb2.EMIT_ALWAYS,
1380 allowed_lateness=0)
1381
1382 @staticmethod
1383 def from_runner_api(proto, context):
1384 # pylint: disable=wrong-import-order, wrong-import-position
1385 from apache_beam.transforms.trigger import TriggerFn
1386 return Windowing(
1387 windowfn=WindowFn.from_runner_api(proto.window_fn, context),
1388 triggerfn=TriggerFn.from_runner_api(proto.trigger, context),
1389 accumulation_mode=proto.accumulation_mode,
1390 timestamp_combiner=proto.output_time)
1391
1392
1393 @typehints.with_input_types(T)
1394 @typehints.with_output_types(T)
1395 class WindowInto(ParDo):
1396 """A window transform assigning windows to each element of a PCollection.
1397
1398 Transforms an input PCollection by applying a windowing function to each
1399 element. Each transformed element in the result will be a WindowedValue
1400 element with the same input value and timestamp, with its new set of windows
1401 determined by the windowing function.
1402 """
1403
1404 class WindowIntoFn(DoFn):
1405 """A DoFn that applies a WindowInto operation."""
1406
1407 def __init__(self, windowing):
1408 self.windowing = windowing
1409
1410 def process(self, element, timestamp=DoFn.TimestampParam):
1411 context = WindowFn.AssignContext(timestamp, element=element)
1412 new_windows = self.windowing.windowfn.assign(context)
1413 yield WindowedValue(element, context.timestamp, new_windows)
1414
1415 def __init__(self, windowfn, **kwargs):
1416 """Initializes a WindowInto transform.
1417
1418 Args:
1419 windowfn: Function to be used for windowing
1420 """
1421 triggerfn = kwargs.pop('trigger', None)
1422 accumulation_mode = kwargs.pop('accumulation_mode', None)
1423 timestamp_combiner = kwargs.pop('timestamp_combiner', None)
1424 self.windowing = Windowing(windowfn, triggerfn, accumulation_mode,
1425 timestamp_combiner)
1426 super(WindowInto, self).__init__(self.WindowIntoFn(self.windowing))
1427
1428 def get_windowing(self, unused_inputs):
1429 return self.windowing
1430
1431 def infer_output_type(self, input_type):
1432 return input_type
1433
1434 def expand(self, pcoll):
1435 input_type = pcoll.element_type
1436
1437 if input_type is not None:
1438 output_type = input_type
1439 self.with_input_types(input_type)
1440 self.with_output_types(output_type)
1441 return super(WindowInto, self).expand(pcoll)
1442
1443 def to_runner_api_parameter(self, context):
1444 return (
1445 urns.WINDOW_INTO_TRANSFORM,
1446 self.windowing.to_runner_api(context))
1447
1448 @staticmethod
1449 def from_runner_api_parameter(proto, context):
1450 windowing = Windowing.from_runner_api(proto, context)
1451 return WindowInto(
1452 windowing.windowfn,
1453 trigger=windowing.triggerfn,
1454 accumulation_mode=windowing.accumulation_mode,
1455 timestamp_combiner=windowing.timestamp_combiner)
1456
1457
1458 PTransform.register_urn(
1459 urns.WINDOW_INTO_TRANSFORM,
1460 # TODO(robertwb): Update WindowIntoPayload to include the full strategy.
1461 # (Right now only WindowFn is used, but we need this to reconstitute the
1462 # WindowInto transform, and in the future will need it at runtime to
1463 # support meta-data driven triggers.)
1464 # TODO(robertwb): Use a reference rather than embedding?
1465 beam_runner_api_pb2.WindowingStrategy,
1466 WindowInto.from_runner_api_parameter)
1467
1468
1469 # Python's pickling is broken for nested classes.
1470 WindowIntoFn = WindowInto.WindowIntoFn
1471
1472
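# Example sketch for WindowInto above (illustrative only; assumes FixedWindows
# and TimestampedValue from apache_beam.transforms.window): elements are
# timestamped and then assigned to 60-second fixed windows.
import apache_beam as beam
from apache_beam.transforms import window

p = beam.Pipeline()
windowed = (p
            | beam.Create([1, 2, 3])
            | beam.Map(lambda i: window.TimestampedValue(i, i * 10))
            | beam.WindowInto(window.FixedWindows(60)))
p.run()
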
1473 class Flatten(PTransform):
1474 """Merges several PCollections into a single PCollection.
1475
1476 Copies all elements in 0 or more PCollections into a single output
1477 PCollection. If there are no input PCollections, the resulting PCollection
1478 will be empty (but see also kwargs below).
1479
1480 Args:
1481 **kwargs: Accepts a single named argument "pipeline", which specifies the
1482 pipeline that "owns" this PTransform. Ordinarily Flatten can obtain this
1483 information from one of the input PCollections, but if there are none (or
1484 if there's a chance there may be none), this argument is the only way to
1485 provide pipeline information and should be considered mandatory.
1486 """
1487
1488 def __init__(self, **kwargs):
1489 super(Flatten, self).__init__()
1490 self.pipeline = kwargs.pop('pipeline', None)
1491 if kwargs:
1492 raise ValueError('Unexpected keyword arguments: %s' % kwargs.keys())
1493
1494 def _extract_input_pvalues(self, pvalueish):
1495 try:
1496 pvalueish = tuple(pvalueish)
1497 except TypeError:
1498 raise ValueError('Input to Flatten must be an iterable.')
1499 return pvalueish, pvalueish
1500
1501 def expand(self, pcolls):
1502 for pcoll in pcolls:
1503 self._check_pcollection(pcoll)
1504 result = pvalue.PCollection(self.pipeline)
1505 result.element_type = typehints.Union[
1506 tuple(pcoll.element_type for pcoll in pcolls)]
1507 return result
1508
1509 def get_windowing(self, inputs):
1510 if not inputs:
1511 # TODO(robertwb): Return something compatible with every windowing?
1512 return Windowing(GlobalWindows())
1513 return super(Flatten, self).get_windowing(inputs)
1514
1515 def to_runner_api_parameter(self, context):
1516 return urns.FLATTEN_TRANSFORM, None
1517
1518 @staticmethod
1519 def from_runner_api_parameter(unused_parameter, unused_context):
1520 return Flatten()
1521
1522
1523 PTransform.register_urn(
1524 urns.FLATTEN_TRANSFORM, None, Flatten.from_runner_api_parameter)
1525
1526
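# Example sketch for Flatten above (illustrative only): the input PCollections
# are passed together as an iterable (here a tuple).
import apache_beam as beam

p = beam.Pipeline()
first = p | 'CreateFirst' >> beam.Create([1, 2])
second = p | 'CreateSecond' >> beam.Create([3, 4])
merged = (first, second) | beam.Flatten()
p.run()
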
1527 class Create(PTransform):
1528 """A transform that creates a PCollection from an iterable."""
1529
1530 def __init__(self, value):
1531 """Initializes a Create transform.
1532
1533 Args:
1534 value: An iterable of values for the PCollection
1535 """
1536 super(Create, self).__init__()
1537 if isinstance(value, basestring):
1538 raise TypeError('PTransform Create: Refusing to treat string as '
1539 'an iterable. (string=%r)' % value)
1540 elif isinstance(value, dict):
1541 value = value.items()
1542 self.value = tuple(value)
1543
1544 def infer_output_type(self, unused_input_type):
1545 if not self.value:
1546 return Any
1547 return Union[[trivial_inference.instance_to_type(v) for v in self.value]]
1548
1549 def get_output_type(self):
1550 return (self.get_type_hints().simple_output_type(self.label) or
1551 self.infer_output_type(None))
1552
1553 def expand(self, pbegin):
1554 from apache_beam.io import iobase
1555 assert isinstance(pbegin, pvalue.PBegin)
1556 self.pipeline = pbegin.pipeline
1557 coder = typecoders.registry.get_coder(self.get_output_type())
1558 source = self._create_source_from_iterable(self.value, coder)
1559 return (pbegin.pipeline
1560 | iobase.Read(source).with_output_types(self.get_output_type()))
1561
1562 def get_windowing(self, unused_inputs):
1563 return Windowing(GlobalWindows())
1564
1565 @staticmethod
1566 def _create_source_from_iterable(values, coder):
1567 return Create._create_source(map(coder.encode, values), coder)
1568
1569 @staticmethod
1570 def _create_source(serialized_values, coder):
1571 from apache_beam.io import iobase
1572
1573 class _CreateSource(iobase.BoundedSource):
1574 def __init__(self, serialized_values, coder):
1575 self._coder = coder
1576 self._serialized_values = []
1577 self._total_size = 0
1578 self._serialized_values = serialized_values
1579 self._total_size = sum(map(len, self._serialized_values))
1580
1581 def read(self, range_tracker):
1582 start_position = range_tracker.start_position()
1583 current_position = start_position
1584
1585 def split_points_unclaimed(stop_position):
1586 if current_position >= stop_position:
1587 return 0
1588 return stop_position - current_position - 1
1589
1590 range_tracker.set_split_points_unclaimed_callback(
1591 split_points_unclaimed)
1592 element_iter = iter(self._serialized_values[start_position:])
1593 for i in range(start_position, range_tracker.stop_position()):
1594 if not range_tracker.try_claim(i):
1595 return
1596 current_position = i
1597 yield self._coder.decode(next(element_iter))
1598
1599 def split(self, desired_bundle_size, start_position=None,
1600 stop_position=None):
1601 from apache_beam.io import iobase
1602
1603 if len(self._serialized_values) < 2:
1604 yield iobase.SourceBundle(
1605 weight=0, source=self, start_position=0,
1606 stop_position=len(self._serialized_values))
1607 else:
1608 if start_position is None:
1609 start_position = 0
1610 if stop_position is None:
1611 stop_position = len(self._serialized_values)
1612
1613 avg_size_per_value = self._total_size / len(self._serialized_values)
1614 num_values_per_split = max(
1615 int(desired_bundle_size / avg_size_per_value), 1)
1616
1617 start = start_position
1618 while start < stop_position:
1619 end = min(start + num_values_per_split, stop_position)
1620 remaining = stop_position - end
1621 # Avoid having a too small bundle at the end.
1622 if remaining < (num_values_per_split / 4):
1623 end = stop_position
1624
1625 sub_source = Create._create_source(
1626 self._serialized_values[start:end], self._coder)
1627
1628 yield iobase.SourceBundle(weight=(end - start),
1629 source=sub_source,
1630 start_position=0,
1631 stop_position=(end - start))
1632
1633 start = end
1634
1635 def get_range_tracker(self, start_position, stop_position):
1636 if start_position is None:
1637 start_position = 0
1638 if stop_position is None:
1639 stop_position = len(self._serialized_values)
1640
1641 from apache_beam import io
1642 return io.OffsetRangeTracker(start_position, stop_position)
1643
1644 def estimate_size(self):
1645 return self._total_size
1646
1647 return _CreateSource(serialized_values, coder)
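

# Example sketch for Create above (illustrative only): an in-memory iterable
# becomes a PCollection; a dict contributes its (key, value) items, while a
# bare string is rejected rather than treated as an iterable of characters.
import apache_beam as beam

p = beam.Pipeline()
pairs = p | beam.Create({'a': 1, 'b': 2})  # elements ('a', 1) and ('b', 2)
p.run()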