Use bytes instead of Any in RunnerApi.FunctionSpec
sdks/python/apache_beam/runners/worker/bundle_processor.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""SDK harness for executing Python Fns via the Fn API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import collections
import json
import logging

from google.protobuf import wrappers_pb2

import apache_beam as beam
from apache_beam.coders import coder_impl
from apache_beam.coders import WindowedValueCoder
from apache_beam.internal import pickler
from apache_beam.io import iobase
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.dataflow.native_io import iobase as native_iobase
from apache_beam.runners import pipeline_context
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import operations
from apache_beam.utils import counters
from apache_beam.utils import proto_utils
from apache_beam.utils import urns

# This module is experimental. No backwards-compatibility guarantees.


try:
  from apache_beam.runners.worker import statesampler
except ImportError:
  from apache_beam.runners.worker import statesampler_fake as statesampler


DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1'
DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1'
IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1'
PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1'
PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1'
# TODO(vikasrk): Fix this once runner sends appropriate python urns.
PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1'
PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1'


def side_input_tag(transform_id, tag):
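  # e.g. side_input_tag('pardo_3', 'side0') -> '7[pardo_3][side0]'
  # (illustrative ids; the length prefix comes from len(transform_id)).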
  return str("%d[%s][%s]" % (len(transform_id), transform_id, tag))


class RunnerIOOperation(operations.Operation):
  """Common baseclass for runner harness IO operations."""

  def __init__(self, operation_name, step_name, consumers, counter_factory,
               state_sampler, windowed_coder, target, data_channel):
    super(RunnerIOOperation, self).__init__(
        operation_name, None, counter_factory, state_sampler)
    self.windowed_coder = windowed_coder
    self.step_name = step_name
    # target represents the consumer for the bytes in the data plane for a
    # DataInputOperation or a producer of these bytes for a DataOutputOperation.
    self.target = target
    self.data_channel = data_channel
    for _, consumer_ops in consumers.items():
      for consumer in consumer_ops:
        self.add_receiver(consumer, 0)


class DataOutputOperation(RunnerIOOperation):
  """A sink-like operation that gathers outputs to be sent back to the runner.
  """

  def set_output_stream(self, output_stream):
    self.output_stream = output_stream

  def process(self, windowed_value):
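    # Elements are written with nested encoding (the trailing True below), so
    # multiple elements can be concatenated on the same data plane stream.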
    self.windowed_coder.get_impl().encode_to_stream(
        windowed_value, self.output_stream, True)

  def finish(self):
    self.output_stream.close()
    super(DataOutputOperation, self).finish()


class DataInputOperation(RunnerIOOperation):
  """A source-like operation that gathers input from the runner.
  """

  def __init__(self, operation_name, step_name, consumers, counter_factory,
               state_sampler, windowed_coder, input_target, data_channel):
    super(DataInputOperation, self).__init__(
        operation_name, step_name, consumers, counter_factory, state_sampler,
        windowed_coder, target=input_target, data_channel=data_channel)
    # We must do this manually as we don't have a spec or spec.output_coders.
    self.receivers = [
        operations.ConsumerSet(self.counter_factory, self.step_name, 0,
                               consumers.itervalues().next(),
                               self.windowed_coder)]

  def process(self, windowed_value):
    self.output(windowed_value)

  def process_encoded(self, encoded_windowed_values):
    input_stream = coder_impl.create_InputStream(encoded_windowed_values)
    while input_stream.size() > 0:
      decoded_value = self.windowed_coder.get_impl().decode_from_stream(
          input_stream, True)
      self.output(decoded_value)


# TODO(robertwb): Revise side input API to not be in terms of native sources.
# This will enable lookups, but there's an open question as to how to handle
# custom sources without forcing intermediate materialization. This seems very
# related to the desire to inject key and window preserving [Splittable]DoFns
# into the view computation.
class SideInputSource(native_iobase.NativeSource,
                      native_iobase.NativeSourceReader):
  """A 'source' for reading side inputs via state API calls.
  """

  def __init__(self, state_handler, state_key, coder):
    self._state_handler = state_handler
    self._state_key = state_key
    self._coder = coder

  def reader(self):
    return self

  @property
  def returns_windowed_values(self):
    return True

  def __enter__(self):
    return self

  def __exit__(self, *exn_info):
    pass

  def __iter__(self):
    # TODO(robertwb): Support pagination.
    input_stream = coder_impl.create_InputStream(
        self._state_handler.Get(self._state_key).data)
    while input_stream.size() > 0:
      yield self._coder.get_impl().decode_from_stream(input_stream, True)


def memoize(func):
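  # Simple unbounded cache keyed on the positional arguments, which must be
  # hashable; keyword arguments are not supported.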
  cache = {}
  missing = object()

  def wrapper(*args):
    result = cache.get(args, missing)
    if result is missing:
      result = cache[args] = func(*args)
    return result
  return wrapper


def only_element(iterable):
  element, = iterable
  return element


class BundleProcessor(object):
  """A class for processing bundles of elements.
  """
  def __init__(
      self, process_bundle_descriptor, state_handler, data_channel_factory):
    self.process_bundle_descriptor = process_bundle_descriptor
    self.state_handler = state_handler
    self.data_channel_factory = data_channel_factory

  def create_execution_tree(self, descriptor):
    # TODO(robertwb): Figure out the correct prefix to use for output counters
    # from StateSampler.
    counter_factory = counters.CounterFactory()
    state_sampler = statesampler.StateSampler(
        'fnapi-step%s-' % descriptor.id, counter_factory)

    transform_factory = BeamTransformFactory(
        descriptor, self.data_channel_factory, counter_factory, state_sampler,
        self.state_handler)

    pcoll_consumers = collections.defaultdict(list)
    for transform_id, transform_proto in descriptor.transforms.items():
      for pcoll_id in transform_proto.inputs.values():
        pcoll_consumers[pcoll_id].append(transform_id)

    @memoize
    def get_operation(transform_id):
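      # Downstream (consumer) operations are created recursively before this
      # one; memoization ensures one operation instance per transform id.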
      transform_consumers = {
          tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
          for tag, pcoll_id
          in descriptor.transforms[transform_id].outputs.items()
      }
      return transform_factory.create_operation(
          transform_id, transform_consumers)

    # Operations must be started (hence returned) in order.
    @memoize
    def topological_height(transform_id):
      return 1 + max(
          [0] +
          [topological_height(consumer)
           for pcoll in descriptor.transforms[transform_id].outputs.values()
           for consumer in pcoll_consumers[pcoll]])
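    # Sinks have height 1 and sources the greatest height, so sorting by
    # descending height returns producers before their consumers.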

    return [get_operation(transform_id)
            for transform_id in sorted(
                descriptor.transforms, key=topological_height, reverse=True)]

  def process_bundle(self, instruction_id):
    ops = self.create_execution_tree(self.process_bundle_descriptor)

    expected_inputs = []
    for op in ops:
      if isinstance(op, DataOutputOperation):
        # TODO(robertwb): Is there a better way to pass the instruction id to
        # the operation?
        op.set_output_stream(op.data_channel.output_stream(
            instruction_id, op.target))
      elif isinstance(op, DataInputOperation):
        # We must wait until we receive "end of stream" for each of these ops.
        expected_inputs.append(op)

    # Start all operations.
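    # Starting in reverse (topological) order ensures every consumer is running
    # before its producers begin emitting elements.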
    for op in reversed(ops):
      logging.info('start %s', op)
      op.start()

    # Inject inputs from data plane.
    for input_op in expected_inputs:
      for data in input_op.data_channel.input_elements(
          instruction_id, [input_op.target]):
        # ignores input name
        input_op.process_encoded(data.data)

    # Finish all operations.
    for op in ops:
      logging.info('finish %s', op)
      op.finish()


class BeamTransformFactory(object):
  """Factory for turning transform_protos into executable operations."""
  def __init__(self, descriptor, data_channel_factory, counter_factory,
               state_sampler, state_handler):
    self.descriptor = descriptor
    self.data_channel_factory = data_channel_factory
    self.counter_factory = counter_factory
    self.state_sampler = state_sampler
    self.state_handler = state_handler
    self.context = pipeline_context.PipelineContext(descriptor)

  _known_urns = {}

  @classmethod
  def register_urn(cls, urn, parameter_type):
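    # Used as a decorator: registers func as the creator for transforms whose
    # spec.urn matches urn; payloads are parsed as parameter_type, or passed
    # through as raw bytes when parameter_type is None.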
    def wrapper(func):
      cls._known_urns[urn] = func, parameter_type
      return func
    return wrapper

  def create_operation(self, transform_id, consumers):
    transform_proto = self.descriptor.transforms[transform_id]
    creator, parameter_type = self._known_urns[transform_proto.spec.urn]
    payload = proto_utils.parse_Bytes(
        transform_proto.spec.payload, parameter_type)
    return creator(self, transform_id, transform_proto, payload, consumers)

  def get_coder(self, coder_id):
    coder_proto = self.descriptor.coders[coder_id]
    if coder_proto.spec.spec.urn:
      return self.context.coders.get_by_id(coder_id)
    else:
      # No URN, assume cloud object encoding json bytes.
      return operation_specs.get_coder_from_spec(
          json.loads(coder_proto.spec.spec.payload))

  def get_output_coders(self, transform_proto):
    return {
        tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
        for tag, pcoll_id in transform_proto.outputs.items()
    }

  def get_only_output_coder(self, transform_proto):
    return only_element(self.get_output_coders(transform_proto).values())

  def get_input_coders(self, transform_proto):
    return {
        tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
        for tag, pcoll_id in transform_proto.inputs.items()
    }

  def get_only_input_coder(self, transform_proto):
    return only_element(self.get_input_coders(transform_proto).values())

  # TODO(robertwb): Update all operations to take these in the constructor.
  @staticmethod
  def augment_oldstyle_op(op, step_name, consumers, tag_list=None):
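    # Wires consumers into an operation built from legacy operation_specs; the
    # receiver index is the tag's position in tag_list (0 when no list given).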
    op.step_name = step_name
    for tag, op_consumers in consumers.items():
      for consumer in op_consumers:
        op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0)
    return op


@BeamTransformFactory.register_urn(
    DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
def create(factory, transform_id, transform_proto, grpc_port, consumers):
  target = beam_fn_api_pb2.Target(
      primitive_transform_reference=transform_id,
      name=only_element(transform_proto.outputs.keys()))
  return DataInputOperation(
      transform_proto.unique_name,
      transform_proto.unique_name,
      consumers,
      factory.counter_factory,
      factory.state_sampler,
      factory.get_only_output_coder(transform_proto),
      input_target=target,
      data_channel=factory.data_channel_factory.create_data_channel(grpc_port))


@BeamTransformFactory.register_urn(
    DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
def create(factory, transform_id, transform_proto, grpc_port, consumers):
  target = beam_fn_api_pb2.Target(
      primitive_transform_reference=transform_id,
      name=only_element(transform_proto.inputs.keys()))
  return DataOutputOperation(
      transform_proto.unique_name,
      transform_proto.unique_name,
      consumers,
      factory.counter_factory,
      factory.state_sampler,
      # TODO(robertwb): Perhaps this could be distinct from the input coder?
      factory.get_only_input_coder(transform_proto),
      target=target,
      data_channel=factory.data_channel_factory.create_data_channel(grpc_port))


@BeamTransformFactory.register_urn(PYTHON_SOURCE_URN, None)
def create(factory, transform_id, transform_proto, parameter, consumers):
  # The Dataflow runner harness strips the base64 encoding, so re-apply it
  # here: pickler.loads expects base64-encoded pickled data.
  source = pickler.loads(base64.b64encode(parameter))
  spec = operation_specs.WorkerRead(
      iobase.SourceBundle(1.0, source, None, None),
      [WindowedValueCoder(source.default_output_coder())])
  return factory.augment_oldstyle_op(
      operations.ReadOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers)


@BeamTransformFactory.register_urn(
    urns.READ_TRANSFORM, beam_runner_api_pb2.ReadPayload)
def create(factory, transform_id, transform_proto, parameter, consumers):
  source = iobase.SourceBase.from_runner_api(parameter.source, factory.context)
  spec = operation_specs.WorkerRead(
      iobase.SourceBundle(1.0, source, None, None),
      [WindowedValueCoder(source.default_output_coder())])
  return factory.augment_oldstyle_op(
      operations.ReadOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers)


@BeamTransformFactory.register_urn(PYTHON_DOFN_URN, None)
def create(factory, transform_id, transform_proto, parameter, consumers):
  dofn_data = pickler.loads(parameter)
  if len(dofn_data) == 2:
    # Has side input data.
    serialized_fn, side_input_data = dofn_data
  else:
    # No side input data.
    serialized_fn, side_input_data = parameter, []
  return _create_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      serialized_fn, side_input_data)


@BeamTransformFactory.register_urn(
    urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload)
def create(factory, transform_id, transform_proto, parameter, consumers):
  assert parameter.do_fn.spec.urn == urns.PICKLED_DO_FN_INFO
  serialized_fn = parameter.do_fn.spec.payload
  dofn_data = pickler.loads(serialized_fn)
  if len(dofn_data) == 2:
    # Has side input data.
    serialized_fn, side_input_data = dofn_data
  else:
    # No side input data.
    side_input_data = []
  return _create_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      serialized_fn, side_input_data)


def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_input_data):
  def create_side_input(tag, coder):
    # TODO(robertwb): Extract windows (and keys) out of element data.
    # TODO(robertwb): Extract state key from ParDoPayload.
    return operation_specs.WorkerSideInputSource(
        tag=tag,
        source=SideInputSource(
            factory.state_handler,
            beam_fn_api_pb2.StateKey.MultimapSideInput(
                key=side_input_tag(transform_id, tag)),
            coder=coder))
  output_tags = list(transform_proto.outputs.keys())

  # Hack to match the 'out' prefix injected by the Dataflow runner, e.g. tags
  # ['None', 'other'] become ['out', 'out_other'].
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag
  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    pcoll_id, = transform_proto.inputs.values()
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=[
          create_side_input(tag, coder) for tag, coder in side_input_data],
      output_coders=[output_coders[tag] for tag in output_tags])
  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers,
      output_tags)


def _create_simple_pardo_operation(
    factory, transform_id, transform_proto, consumers, dofn):
  serialized_fn = pickler.dumps((dofn, (), {}, [], None))
  side_input_data = []
  return _create_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      serialized_fn, side_input_data)


@BeamTransformFactory.register_urn(
    urns.GROUP_ALSO_BY_WINDOW_TRANSFORM, wrappers_pb2.BytesValue)
def create(factory, transform_id, transform_proto, parameter, consumers):
  # Perhaps this hack can go away once all apply overloads are gone.
  from apache_beam.transforms.core import _GroupAlsoByWindowDoFn
  return _create_simple_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      _GroupAlsoByWindowDoFn(
          factory.context.windowing_strategies.get_by_id(parameter.value)))


@BeamTransformFactory.register_urn(
    urns.WINDOW_INTO_TRANSFORM, beam_runner_api_pb2.WindowingStrategy)
def create(factory, transform_id, transform_proto, parameter, consumers):
  class WindowIntoDoFn(beam.DoFn):
    def __init__(self, windowing):
      self.windowing = windowing

    def process(self, element, timestamp=beam.DoFn.TimestampParam):
      new_windows = self.windowing.windowfn.assign(
          WindowFn.AssignContext(timestamp, element=element))
      yield WindowedValue(element, timestamp, new_windows)
  from apache_beam.transforms.core import Windowing
  from apache_beam.transforms.window import WindowFn, WindowedValue
  windowing = Windowing.from_runner_api(parameter, factory.context)
  return _create_simple_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      WindowIntoDoFn(windowing))


@BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None)
def create(factory, transform_id, transform_proto, unused_parameter, consumers):
  return factory.augment_oldstyle_op(
      operations.FlattenOperation(
          transform_proto.unique_name,
          operation_specs.WorkerFlatten(
              None, [factory.get_only_output_coder(transform_proto)]),
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers)