Use bytes instead of Any in RunnerApi.FunctionSpec
[beam.git] / sdks / python / apache_beam / transforms / ptransform.py
1 #
2 # Licensed to the Apache Software Foundation (ASF) under one or more
3 # contributor license agreements. See the NOTICE file distributed with
4 # this work for additional information regarding copyright ownership.
5 # The ASF licenses this file to You under the Apache License, Version 2.0
6 # (the "License"); you may not use this file except in compliance with
7 # the License. You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16 #
17
18 """PTransform and descendants.
19
20 A PTransform is an object describing (not executing) a computation. The actual
21 execution semantics for a transform is captured by a runner object. A transform
22 object always belongs to a pipeline object.
23
24 A PTransform derived class needs to define the expand() method that describes
25 how one or more PValues are created by the transform.
26
27 The module defines a few standard transforms: FlatMap (parallel do),
28 GroupByKey (group by key), etc. Note that the expand() methods for these
29 classes contain code that will add nodes to the processing graph associated
30 with a pipeline.
31
32 As support for the FlatMap transform, the module also defines a DoFn
33 class and wrapper class that allows lambda functions to be used as
34 FlatMap processing functions.
35 """
36
37 from __future__ import absolute_import
38
39 import copy
40 import inspect
41 import operator
42 import os
43 import sys
44
45 from google.protobuf import wrappers_pb2
46
47 from apache_beam import error
48 from apache_beam import pvalue
49 from apache_beam.internal import pickler
50 from apache_beam.internal import util
51 from apache_beam.transforms.display import HasDisplayData
52 from apache_beam.transforms.display import DisplayDataItem
53 from apache_beam.typehints import typehints
54 from apache_beam.typehints.decorators import getcallargs_forhints
55 from apache_beam.typehints.decorators import TypeCheckError
56 from apache_beam.typehints.decorators import WithTypeHints
57 from apache_beam.typehints.trivial_inference import instance_to_type
58 from apache_beam.typehints.typehints import validate_composite_type_param
59 from apache_beam.utils import proto_utils
60 from apache_beam.utils import urns
61
62
# Public API of this module.
__all__ = [
    'PTransform',
    'ptransform_fn',
    'label_from_callable',
    ]
68
69
70 class _PValueishTransform(object):
71 """Visitor for PValueish objects.
72
73 A PValueish is a PValue, or list, tuple, dict of PValuesish objects.
74
75 This visits a PValueish, contstructing a (possibly mutated) copy.
76 """
77 def visit(self, node, *args):
78 return getattr(
79 self,
80 'visit_' + node.__class__.__name__,
81 lambda x, *args: x)(node, *args)
82
83 def visit_list(self, node, *args):
84 return [self.visit(x, *args) for x in node]
85
86 def visit_tuple(self, node, *args):
87 return tuple(self.visit(x, *args) for x in node)
88
89 def visit_dict(self, node, *args):
90 return {key: self.visit(value, *args) for (key, value) in node.items()}
91
92
class _SetInputPValues(_PValueishTransform):
  """Replaces nodes, matched by identity, with supplied replacements."""

  def visit(self, node, replacements):
    # Replacements are keyed by id() so unhashable nodes can be matched.
    node_id = id(node)
    if node_id in replacements:
      return replacements[node_id]
    return super(_SetInputPValues, self).visit(node, replacements)
98
99
class _MaterializedDoOutputsTuple(pvalue.DoOutputsTuple):
  """Wraps a deferred DoOutputsTuple, resolving tags to materialized values.

  Used after an eager pipeline run: indexing by tag looks the deferred
  PValue up in the runner's cache and returns the unwindowed result.
  """

  def __init__(self, deferred, pvalue_cache):
    # No pipeline or transform is passed: this tuple only proxies tag
    # lookups into the runner's cache.
    super(_MaterializedDoOutputsTuple, self).__init__(
        None, None, deferred._tags, deferred._main_tag)
    self._deferred = deferred
    self._pvalue_cache = pvalue_cache

  def __getitem__(self, tag):
    # Resolve the deferred output for `tag` to its computed value.
    return self._pvalue_cache.get_unwindowed_pvalue(self._deferred[tag])
109
110
class _MaterializePValues(_PValueishTransform):
  """Replaces deferred PValues in a pvalueish with their computed values."""

  def __init__(self, pvalue_cache):
    self._pvalue_cache = pvalue_cache

  def visit(self, node):
    # Leaves are looked up in the runner's cache; DoOutputsTuples are
    # wrapped so that per-tag lookups are also resolved lazily.
    if isinstance(node, pvalue.PValue):
      return self._pvalue_cache.get_unwindowed_pvalue(node)
    if isinstance(node, pvalue.DoOutputsTuple):
      return _MaterializedDoOutputsTuple(node, self._pvalue_cache)
    return super(_MaterializePValues, self).visit(node)
121
122
class GetPValues(_PValueishTransform):
  """Collects, in order, every PValue contained in a pvalueish."""

  def visit(self, node, pvalues=None):
    # Top-level call: allocate the accumulator, recurse, and return it.
    # Recursive calls append into `pvalues` and return None.
    if pvalues is None:
      collected = []
      self.visit(node, collected)
      return collected
    if isinstance(node, (pvalue.PValue, pvalue.DoOutputsTuple)):
      pvalues.append(node)
    else:
      super(GetPValues, self).visit(node, pvalues)
133
134
class _ZipPValues(_PValueishTransform):
  """Pairs each PValue in a pvalueish with a value in a parallel sibling.

  The sibling should mirror the nested structure of the pvalueish. A leaf
  in the sibling is broadcast across nested pvalueish lists, tuples, and
  dicts. For example

      ZipPValues().visit({'a': pc1, 'b': (pc2, pc3)},
                         {'a': 'A', 'b': 'B'})

  returns

      [('a', pc1, 'A'), ('b', pc2, 'B'), ('b', pc3, 'B')]
  """

  def visit(self, pvalueish, sibling, pairs=None, context=None):
    # Top-level call: create the accumulator and return it when done.
    if pairs is None:
      result = []
      self.visit(pvalueish, sibling, result, context)
      return result
    if isinstance(pvalueish, (pvalue.PValue, pvalue.DoOutputsTuple)):
      pairs.append((context, pvalueish, sibling))
    else:
      super(_ZipPValues, self).visit(pvalueish, sibling, pairs, context)

  def visit_list(self, pvalueish, sibling, pairs, context):
    if isinstance(sibling, (list, tuple)):
      # Pair positionally; any missing sibling entries pair with None.
      siblings = list(sibling)
      for ix, p in enumerate(pvalueish):
        s = siblings[ix] if ix < len(siblings) else None
        self.visit(p, s, pairs, 'position %s' % ix)
    else:
      # Broadcast the sibling leaf across every element.
      for p in pvalueish:
        self.visit(p, sibling, pairs, context)

  def visit_tuple(self, pvalueish, sibling, pairs, context):
    self.visit_list(pvalueish, sibling, pairs, context)

  def visit_dict(self, pvalueish, sibling, pairs, context):
    if isinstance(sibling, dict):
      # Pair by key; missing keys pair with None via dict.get().
      for key, p in pvalueish.items():
        self.visit(p, sibling.get(key), pairs, key)
    else:
      # Broadcast the sibling leaf across every value.
      for p in pvalueish.values():
        self.visit(p, sibling, pairs, context)
179
180
class PTransform(WithTypeHints, HasDisplayData):
  """A transform object used to modify one or more PCollections.

  Subclasses must define an expand() method that will be used when the transform
  is applied to some arguments. Typical usage pattern will be:

    input | CustomTransform(...)

  The expand() method of the CustomTransform object passed in will be called
  with input as an argument.
  """
  # By default, transforms don't have any side inputs.
  side_inputs = ()

  # Used for nullary transforms.
  pipeline = None

  # Default is unset.
  _user_label = None

  def __init__(self, label=None):
    super(PTransform, self).__init__()
    self.label = label

  @property
  def label(self):
    # A user-supplied label wins; otherwise one is derived from the class.
    return self._user_label or self.default_label()

  @label.setter
  def label(self, value):
    self._user_label = value

  def default_label(self):
    """Returns the label used when no explicit label was supplied."""
    return self.__class__.__name__

  def with_input_types(self, input_type_hint):
    """Annotates the input type of a PTransform with a type-hint.

    Args:
      input_type_hint: An instance of an allowed built-in type, a custom class,
        or an instance of a typehints.TypeConstraint.

    Raises:
      TypeError: If 'type_hint' is not a valid type-hint. See
        typehints.validate_composite_type_param for further details.

    Returns:
      A reference to the instance of this particular PTransform object. This
      allows chaining type-hinting related methods.
    """
    validate_composite_type_param(input_type_hint,
                                  'Type hints for a PTransform')
    return super(PTransform, self).with_input_types(input_type_hint)

  def with_output_types(self, type_hint):
    """Annotates the output type of a PTransform with a type-hint.

    Args:
      type_hint: An instance of an allowed built-in type, a custom class, or a
        typehints.TypeConstraint.

    Raises:
      TypeError: If 'type_hint' is not a valid type-hint. See
        typehints.validate_composite_type_param for further details.

    Returns:
      A reference to the instance of this particular PTransform object. This
      allows chaining type-hinting related methods.
    """
    validate_composite_type_param(type_hint, 'Type hints for a PTransform')
    return super(PTransform, self).with_output_types(type_hint)

  def type_check_inputs(self, pvalueish):
    """Type-checks this transform's inputs against declared input hints."""
    self.type_check_inputs_or_outputs(pvalueish, 'input')

  def infer_output_type(self, unused_input_type):
    # Fall back to Any when no simple output type hint was declared.
    return self.get_type_hints().simple_output_type(self.label) or typehints.Any

  def type_check_outputs(self, pvalueish):
    """Type-checks this transform's outputs against declared output hints."""
    self.type_check_inputs_or_outputs(pvalueish, 'output')

  def type_check_inputs_or_outputs(self, pvalueish, input_or_output):
    """Type-checks a pvalueish against this transform's declared hints.

    Args:
      pvalueish: The input or output PValue (or nested structure of them).
      input_or_output: Either 'input' or 'output'; selects which hints to
        check against and appears in error messages.

    Raises:
      TypeCheckError: If both positional and keyword hints are present, or
        if an element type is inconsistent with its hint.
    """
    hints = getattr(self.get_type_hints(), input_or_output + '_types')
    if not hints:
      return
    arg_hints, kwarg_hints = hints
    if arg_hints and kwarg_hints:
      raise TypeCheckError(
          'PTransform cannot have both positional and keyword type hints '
          'without overriding %s._type_check_%s()' % (
              self.__class__, input_or_output))
    root_hint = (
        arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints)
    for context, pvalue_, hint in _ZipPValues().visit(pvalueish, root_hint):
      if pvalue_.element_type is None:
        # TODO(robertwb): It's a bug that we ever get here. (typecheck)
        continue
      if hint and not typehints.is_consistent_with(pvalue_.element_type, hint):
        at_context = ' %s %s' % (input_or_output, context) if context else ''
        raise TypeCheckError(
            '%s type hint violation at %s%s: expected %s, got %s' % (
                input_or_output.title(), self.label, at_context, hint,
                pvalue_.element_type))

  def _infer_output_coder(self, input_type=None, input_coder=None):
    """Returns the output coder to use for output of this transform.

    Note: this API is experimental and is subject to change; please do not rely
    on behavior induced by this method.

    The Coder returned here should not be wrapped in a WindowedValueCoder
    wrapper.

    Args:
      input_type: An instance of an allowed built-in type, a custom class, or a
        typehints.TypeConstraint for the input type, or None if not available.
      input_coder: Coder object for encoding input to this PTransform, or None
        if not available.

    Returns:
      Coder object for encoding output of this PTransform or None if unknown.
    """
    # TODO(ccy): further refine this API.
    return None

  def _clone(self, new_label):
    """Clones the current transform instance under a new label."""
    transform = copy.copy(self)
    transform.label = new_label
    return transform

  def expand(self, input_or_inputs):
    """Describes this transform's computation; must be overridden."""
    raise NotImplementedError

  def __str__(self):
    return '<%s>' % self._str_internal()

  def __repr__(self):
    return '<%s at %s>' % (self._str_internal(), hex(id(self)))

  def _str_internal(self):
    # Include label, inputs and side inputs only when they are set.
    return '%s(PTransform)%s%s%s' % (
        self.__class__.__name__,
        ' label=[%s]' % self.label if (hasattr(self, 'label') and
                                       self.label) else '',
        ' inputs=%s' % str(self.inputs) if (hasattr(self, 'inputs') and
                                            self.inputs) else '',
        ' side_inputs=%s' % str(self.side_inputs) if self.side_inputs else '')

  def _check_pcollection(self, pcoll):
    """Raises TransformError unless pcoll is a PCollection in a pipeline."""
    if not isinstance(pcoll, pvalue.PCollection):
      raise error.TransformError('Expecting a PCollection argument.')
    if not pcoll.pipeline:
      raise error.TransformError('PCollection not part of a pipeline.')

  def get_windowing(self, inputs):
    """Returns the window function to be associated with transform's output.

    By default most transforms just return the windowing function associated
    with the input PCollection (or the first input if several).
    """
    # TODO(robertwb): Assert all input WindowFns compatible.
    return inputs[0].windowing

  def __rrshift__(self, label):
    # Supports the 'Label' >> transform naming syntax.
    return _NamedPTransform(self, label)

  def __or__(self, right):
    """Used to compose PTransforms, e.g., ptransform1 | ptransform2."""
    if isinstance(right, PTransform):
      return _ChainedPTransform(self, right)
    return NotImplemented

  def __ror__(self, left, label=None):
    """Used to apply this PTransform to non-PValues, e.g., a tuple."""
    pvalueish, pvalues = self._extract_input_pvalues(left)
    pipelines = [v.pipeline for v in pvalues if isinstance(v, pvalue.PValue)]
    if pvalues and not pipelines:
      # No deferred inputs at all: create a throwaway pipeline so the
      # transform can be run eagerly on the concrete values.
      deferred = False
      # pylint: disable=wrong-import-order, wrong-import-position
      from apache_beam import pipeline
      from apache_beam.options.pipeline_options import PipelineOptions
      # pylint: enable=wrong-import-order, wrong-import-position
      p = pipeline.Pipeline(
          'DirectRunner', PipelineOptions(sys.argv))
    else:
      if not pipelines:
        if self.pipeline is not None:
          p = self.pipeline
        else:
          raise ValueError('"%s" requires a pipeline to be specified '
                           'as there are no deferred inputs.'% self.label)
      else:
        # All deferred inputs must come from the same pipeline.
        p = self.pipeline or pipelines[0]
        for pp in pipelines:
          if p != pp:
            raise ValueError(
                'Mixing value from different pipelines not allowed.')
      deferred = not getattr(p.runner, 'is_eager', False)
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.transforms.core import Create
    # pylint: enable=wrong-import-order, wrong-import-position
    # Lift concrete (non-PValue) inputs into the pipeline via Create,
    # keyed by object identity for later substitution.
    replacements = {id(v): p | 'CreatePInput%s' % ix >> Create(v)
                    for ix, v in enumerate(pvalues)
                    if not isinstance(v, pvalue.PValue) and v is not None}
    pvalueish = _SetInputPValues().visit(pvalueish, replacements)
    self.pipeline = p
    result = p.apply(self, pvalueish, label)
    if deferred:
      return result
    # Get a reference to the runners internal cache, otherwise runner may
    # clean it after run.
    cache = p.runner.cache
    p.run().wait_until_finish()
    return _MaterializePValues(cache).visit(result)

  def _extract_input_pvalues(self, pvalueish):
    """Extract all the pvalues contained in the input pvalueish.

    Returns pvalueish as well as the flat inputs list as the input may have to
    be copied as inspection may be destructive.

    By default, recursively extracts tuple components and dict values.

    Generally only needs to be overriden for multi-input PTransforms.
    """
    # pylint: disable=wrong-import-order
    from apache_beam import pipeline
    # pylint: enable=wrong-import-order
    if isinstance(pvalueish, pipeline.Pipeline):
      pvalueish = pvalue.PBegin(pvalueish)

    def _dict_tuple_leaves(pvalueish):
      # Yields each leaf reachable through nested tuples and dict values.
      if isinstance(pvalueish, tuple):
        for a in pvalueish:
          for p in _dict_tuple_leaves(a):
            yield p
      elif isinstance(pvalueish, dict):
        for a in pvalueish.values():
          for p in _dict_tuple_leaves(a):
            yield p
      else:
        yield pvalueish
    return pvalueish, tuple(_dict_tuple_leaves(pvalueish))

  # Maps urn -> (parameter_type, constructor) for from_runner_api().
  _known_urns = {}

  @classmethod
  def register_urn(cls, urn, parameter_type, constructor=None):
    """Registers a constructor for deserializing transforms with this urn.

    May be called as a statement (constructor supplied) or used as a
    decorator (constructor omitted; the decorated function is registered
    and returned wrapped in staticmethod).
    """
    def register(constructor):
      cls._known_urns[urn] = parameter_type, constructor
      return staticmethod(constructor)
    if constructor:
      # Used as a statement.
      register(constructor)
    else:
      # Used as a decorator.
      return register

  def to_runner_api(self, context):
    """Serializes this transform to a runner API FunctionSpec proto."""
    from apache_beam.portability.api import beam_runner_api_pb2
    urn, typed_param = self.to_runner_api_parameter(context)
    return beam_runner_api_pb2.FunctionSpec(
        urn=urn,
        any_param=proto_utils.pack_Any(typed_param),
        payload=typed_param.SerializeToString()
        if typed_param is not None else None)

  @classmethod
  def from_runner_api(cls, proto, context):
    """Reconstructs a transform from a runner API FunctionSpec proto."""
    if proto is None or not proto.urn:
      return None
    parameter_type, constructor = cls._known_urns[proto.urn]
    return constructor(
        proto_utils.parse_Bytes(proto.payload, parameter_type),
        context)

  def to_runner_api_parameter(self, context):
    # Default serialization: pickle the whole transform.
    return (urns.PICKLED_TRANSFORM,
            wrappers_pb2.BytesValue(value=pickler.dumps(self)))

  @staticmethod
  def from_runner_api_parameter(spec_parameter, unused_context):
    # Inverse of to_runner_api_parameter: unpickle the transform.
    return pickler.loads(spec_parameter.value)
465
466
# Fallback serialization: any PTransform without a more specific URN
# registration round-trips through the runner API as pickled bytes.
PTransform.register_urn(
    urns.PICKLED_TRANSFORM,
    wrappers_pb2.BytesValue,
    PTransform.from_runner_api_parameter)
471
472
class _ChainedPTransform(PTransform):
  """A PTransform formed by composing several transforms with `|`.

  Expanding the chain pipes the input through each part in order. Parts
  are kept as a flat sequence rather than a nested tree.
  """

  def __init__(self, *parts):
    super(_ChainedPTransform, self).__init__(label=self._chain_label(parts))
    self._parts = parts

  def _chain_label(self, parts):
    # The chain's label is its parts' labels joined by the pipe operator.
    return '|'.join(p.label for p in parts)

  def __or__(self, right):
    if isinstance(right, PTransform):
      # Create a flat list rather than a nested tree of composite
      # transforms for better monitoring, etc.
      return _ChainedPTransform(*(self._parts + (right,)))
    return NotImplemented

  def expand(self, pval):
    # The bare `reduce` builtin is Python 2-only (removed in Python 3);
    # functools.reduce works on Python 2.6+ and Python 3 alike.
    from functools import reduce
    return reduce(operator.or_, self._parts, pval)
491
492
class PTransformWithSideInputs(PTransform):
  """A superclass for any PTransform (e.g. FlatMap or Combine)
  invoking user code.

  PTransforms like FlatMap invoke user-supplied code in some kind of
  package (e.g. a DoFn) and optionally provide arguments and side inputs
  to that code. This internal-use-only class contains common functionality
  for PTransforms that fit this model.
  """

  def __init__(self, fn, *args, **kwargs):
    if isinstance(fn, type) and issubclass(fn, WithTypeHints):
      # Don't treat Fn class objects as callables.
      raise ValueError('Use %s() not %s.' % (fn.__name__, fn.__name__))
    self.fn = self.make_fn(fn)
    # Now that we figure out the label, initialize the super-class.
    super(PTransformWithSideInputs, self).__init__()

    # kwargs.values() (not the Python 2-only itervalues()) keeps this check
    # working under both Python 2 and 3, consistent with the rest of this
    # module; generator expressions avoid building intermediate lists.
    if (any(isinstance(v, pvalue.PCollection) for v in args) or
        any(isinstance(v, pvalue.PCollection) for v in kwargs.values())):
      raise error.SideInputError(
          'PCollection used directly as side input argument. Specify '
          'AsIter(pcollection) or AsSingleton(pcollection) to indicate how the '
          'PCollection is to be used.')
    self.args, self.kwargs, self.side_inputs = util.remove_objects_from_args(
        args, kwargs, pvalue.AsSideInput)
    self.raw_side_inputs = args, kwargs

    # Prevent name collisions with fns of the form '<function <lambda> at ...>'
    self._cached_fn = self.fn

    # Ensure fn and side inputs are picklable for remote execution.
    self.fn = pickler.loads(pickler.dumps(self.fn))
    self.args = pickler.loads(pickler.dumps(self.args))
    self.kwargs = pickler.loads(pickler.dumps(self.kwargs))

    # For type hints, because loads(dumps(class)) != class.
    self.fn = self._cached_fn

  def with_input_types(
      self, input_type_hint, *side_inputs_arg_hints, **side_input_kwarg_hints):
    """Annotates the types of main inputs and side inputs for the PTransform.

    Args:
      input_type_hint: An instance of an allowed built-in type, a custom class,
        or an instance of a typehints.TypeConstraint.
      *side_inputs_arg_hints: A variable length argument composed of
        of an allowed built-in type, a custom class, or a
        typehints.TypeConstraint.
      **side_input_kwarg_hints: A dictionary argument composed of
        of an allowed built-in type, a custom class, or a
        typehints.TypeConstraint.

    Example of annotating the types of side-inputs:
      FlatMap().with_input_types(int, int, bool)

    Raises:
      TypeError: If 'type_hint' is not a valid type-hint. See
        typehints.validate_composite_type_param for further details.

    Returns:
      A reference to the instance of this particular PTransform object. This
      allows chaining type-hinting related methods.
    """
    super(PTransformWithSideInputs, self).with_input_types(input_type_hint)

    for si in side_inputs_arg_hints:
      validate_composite_type_param(si, 'Type hints for a PTransform')
    for si in side_input_kwarg_hints.values():
      validate_composite_type_param(si, 'Type hints for a PTransform')

    self.side_inputs_types = side_inputs_arg_hints
    return WithTypeHints.with_input_types(
        self, input_type_hint, *side_inputs_arg_hints, **side_input_kwarg_hints)

  def type_check_inputs(self, pvalueish):
    """Type-checks the main input and side inputs against declared hints.

    Raises:
      TypeCheckError: If a bound argument's type is inconsistent with the
        corresponding declared hint.
    """
    type_hints = self.get_type_hints().input_types
    if type_hints:
      args, kwargs = self.raw_side_inputs

      def element_type(side_input):
        # Deferred side inputs carry their own element type; concrete
        # values have their type inferred.
        if isinstance(side_input, pvalue.AsSideInput):
          return side_input.element_type
        return instance_to_type(side_input)

      arg_types = [pvalueish.element_type] + [element_type(v) for v in args]
      kwargs_types = {k: element_type(v) for (k, v) in kwargs.items()}
      argspec_fn = self._process_argspec_fn()
      bindings = getcallargs_forhints(argspec_fn, *arg_types, **kwargs_types)
      hints = getcallargs_forhints(argspec_fn, *type_hints[0], **type_hints[1])
      for arg, hint in hints.items():
        if arg.startswith('%unknown%'):
          continue
        if hint is None:
          continue
        if not typehints.is_consistent_with(
            bindings.get(arg, typehints.Any), hint):
          raise TypeCheckError(
              'Type hint violation for \'%s\': requires %s but got %s for %s'
              % (self.label, hint, bindings[arg], arg))

  def _process_argspec_fn(self):
    """Returns an argspec of the function actually consuming the data.
    """
    raise NotImplementedError

  def make_fn(self, fn):
    # TODO(silviuc): Add comment describing that this is meant to be overriden
    # by methods detecting callables and wrapping them in DoFns.
    return fn

  def default_label(self):
    return '%s(%s)' % (self.__class__.__name__, self.fn.default_label())
606
607
class _PTransformFnPTransform(PTransform):
  """Adapts a plain function plus bound arguments into a PTransform."""

  def __init__(self, fn, *args, **kwargs):
    super(_PTransformFnPTransform, self).__init__()
    self._fn = fn
    self._args = args
    self._kwargs = kwargs

  def display_data(self):
    # Report the wrapped callable by name when it has one, else by class.
    fn_value = getattr(self._fn, '__name__', self._fn.__class__)
    return {
        'fn': fn_value,
        'args': DisplayDataItem(str(self._args)).drop_if_default('()'),
        'kwargs': DisplayDataItem(str(self._kwargs)).drop_if_default('{}'),
    }

  def expand(self, pcoll):
    # Since the PTransform is implemented entirely by the wrapped function,
    # any type hints set via .with_input_types()/.with_output_types() are
    # forwarded when the function declares a 'type_hints' parameter.
    call_kwargs = dict(self._kwargs)
    call_args = tuple(self._args)
    try:
      argspec = inspect.getargspec(self._fn)
    except TypeError:
      # Might not be a function.
      argspec = None
    if argspec is not None and 'type_hints' in argspec.args:
      call_args = (self.get_type_hints(),) + call_args
    return self._fn(pcoll, *call_args, **call_kwargs)

  def default_label(self):
    if not self._args:
      return label_from_callable(self._fn)
    return '%s(%s)' % (
        label_from_callable(self._fn), label_from_callable(self._args[0]))
645
646
def ptransform_fn(fn):
  """A decorator for a function-based PTransform.

  Experimental; no backwards-compatibility guarantees.

  Args:
    fn: A function implementing a custom PTransform.

  Returns:
    A CallablePTransform instance wrapping the function-based PTransform.

  This wrapper provides an alternative, simpler way to define a PTransform.
  The standard method is to subclass from PTransform and override the expand()
  method. An equivalent effect can be obtained by defining a function that
  takes an input PCollection and additional optional arguments and returns a
  resulting PCollection. For example::

    @ptransform_fn
    def CustomMapper(pcoll, mapfn):
      return pcoll | ParDo(mapfn)

  The equivalent approach using PTransform subclassing::

    class CustomMapper(PTransform):

      def __init__(self, mapfn):
        super(CustomMapper, self).__init__()
        self.mapfn = mapfn

      def expand(self, pcoll):
        return pcoll | ParDo(self.mapfn)

  With either method the custom PTransform can be used in pipelines as if
  it were one of the "native" PTransforms::

    result_pcoll = input_pcoll | 'Label' >> CustomMapper(somefn)

  Note that for both solutions the underlying implementation of the pipe
  operator (i.e., `|`) will inject the pcoll argument in its proper place
  (first argument if no label was specified and second argument otherwise).
  """
  # TODO(robertwb): Consider removing staticmethod to allow for self parameter.

  def callable_ptransform_factory(*args, **kwargs):
    # Defer construction until call time so decoration is cheap.
    return _PTransformFnPTransform(fn, *args, **kwargs)
  return callable_ptransform_factory
693
694
def label_from_callable(fn):
  """Returns a human-readable label for a callable.

  Prefers, in order: the callable's default_label() method, its __name__
  (with file and line info for lambdas), and finally its str() form.
  """
  if hasattr(fn, 'default_label'):
    return fn.default_label()
  elif hasattr(fn, '__name__'):
    if fn.__name__ == '<lambda>':
      # fn.__code__ is the Python 2.6+/3.x spelling; the previous
      # fn.func_code attribute exists only on Python 2.
      return '<lambda at %s:%s>' % (
          os.path.basename(fn.__code__.co_filename),
          fn.__code__.co_firstlineno)
    return fn.__name__
  return str(fn)
705
706
class _NamedPTransform(PTransform):
  """Wraps a transform with the label assigned via `'Label' >> transform`."""

  def __init__(self, transform, label):
    super(_NamedPTransform, self).__init__(label)
    self.transform = transform

  def __ror__(self, pvalueish, _unused=None):
    # Apply the wrapped transform, passing along this wrapper's label.
    return self.transform.__ror__(pvalueish, self.label)

  def expand(self, pvalue):
    # Application always goes through __ror__ to the wrapped transform.
    raise RuntimeError("Should never be expanded directly.")