ARROW-439: [Python] Add option in "to_pandas" conversions to yield Categorical from...
[arrow.git] / python / pyarrow / pandas_compat.py
1 # Licensed to the Apache Software Foundation (ASF) under one
2 # or more contributor license agreements. See the NOTICE file
3 # distributed with this work for additional information
4 # regarding copyright ownership. The ASF licenses this file
5 # to you under the Apache License, Version 2.0 (the
6 # "License"); you may not use this file except in compliance
7 # with the License. You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing,
12 # software distributed under the License is distributed on an
13 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 # KIND, either express or implied. See the License for the
15 # specific language governing permissions and limitations
16 # under the License.
17
18 import re
19 import json
20 import numpy as np
21 import pandas as pd
22
23 import six
24
25 import pyarrow as pa
26 from pyarrow.compat import PY2 # noqa
27
28
# Matches auto-generated index level column names ('__index_level_0__', ...)
# produced by index_level_name() for unnamed pandas index levels.
INDEX_LEVEL_NAME_REGEX = re.compile(r'^__index_level_\d+__$')
30
31
def is_unnamed_index_level(name):
    """Return True if *name* is an auto-generated index level column name."""
    return bool(INDEX_LEVEL_NAME_REGEX.match(name))
34
35
def infer_dtype(column):
    """Infer the pandas logical dtype name of *column*.

    Falls back to the pre-0.20 ``pd.lib`` location of ``infer_dtype`` when
    the modern ``pd.api.types`` namespace is unavailable.
    """
    try:
        infer = pd.api.types.infer_dtype
    except AttributeError:
        # Older pandas exposed infer_dtype under pd.lib
        infer = pd.lib.infer_dtype
    return infer(column)
41
42
# Lazily-populated cache mapping Arrow type ids to pandas logical type
# names; filled on the first call to get_logical_type_map().
_logical_type_map = {}
44
45
def get_logical_type_map():
    """Return the mapping from Arrow type ids to pandas logical type names.

    The mapping is built lazily on first use and cached in the module-level
    ``_logical_type_map`` dict.
    """
    global _logical_type_map

    if not _logical_type_map:
        _logical_type_map.update([
            (pa.lib.Type_NA, 'float64'),  # NaNs
            (pa.lib.Type_BOOL, 'bool'),
            (pa.lib.Type_INT8, 'int8'),
            (pa.lib.Type_INT16, 'int16'),
            (pa.lib.Type_INT32, 'int32'),
            (pa.lib.Type_INT64, 'int64'),
            (pa.lib.Type_UINT8, 'uint8'),
            (pa.lib.Type_UINT16, 'uint16'),
            (pa.lib.Type_UINT32, 'uint32'),
            (pa.lib.Type_UINT64, 'uint64'),
            (pa.lib.Type_HALF_FLOAT, 'float16'),
            (pa.lib.Type_FLOAT, 'float32'),
            (pa.lib.Type_DOUBLE, 'float64'),
            (pa.lib.Type_DATE32, 'date'),
            (pa.lib.Type_DATE64, 'date'),
            (pa.lib.Type_TIME32, 'time'),
            (pa.lib.Type_TIME64, 'time'),
            (pa.lib.Type_BINARY, 'bytes'),
            (pa.lib.Type_FIXED_SIZE_BINARY, 'bytes'),
            (pa.lib.Type_STRING, 'unicode'),
        ])
    return _logical_type_map
73
74
def get_logical_type(arrow_type):
    """Return the pandas logical type name for *arrow_type*.

    Parametric types (dictionary, list, timestamp, decimal) are handled
    specially; anything else not in the id map raises NotImplementedError.
    """
    pandas_type = get_logical_type_map().get(arrow_type.id)
    if pandas_type is not None:
        return pandas_type

    if isinstance(arrow_type, pa.lib.DictionaryType):
        return 'categorical'
    if isinstance(arrow_type, pa.lib.ListType):
        return 'list[{}]'.format(get_logical_type(arrow_type.value_type))
    if isinstance(arrow_type, pa.lib.TimestampType):
        if arrow_type.tz is not None:
            return 'datetimetz'
        return 'datetime'
    if isinstance(arrow_type, pa.lib.DecimalType):
        return 'decimal'
    raise NotImplementedError(str(arrow_type))
90
91
def get_column_metadata(column, name, arrow_type):
    """Construct the metadata dict for a given column.

    Parameters
    ----------
    column : pandas.Series
    name : str
    arrow_type : pyarrow.DataType

    Returns
    -------
    dict
        Keys: 'name', 'pandas_type', 'numpy_type', 'metadata'.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    """
    # Fail fast: validate the name before doing any type-inspection work,
    # instead of raising only after the metadata has been computed.
    if not isinstance(name, six.string_types):
        raise TypeError(
            'Column name must be a string. Got column {} of type {}'.format(
                name, type(name).__name__
            )
        )

    dtype = column.dtype
    logical_type = get_logical_type(arrow_type)

    if hasattr(dtype, 'categories'):
        # Categorical dtype: record category count and ordering.
        assert logical_type == 'categorical'
        extra_metadata = {
            'num_categories': len(column.cat.categories),
            'ordered': column.cat.ordered,
        }
    elif hasattr(dtype, 'tz'):
        # Timezone-aware datetime dtype: record the timezone.
        assert logical_type == 'datetimetz'
        extra_metadata = {'timezone': str(dtype.tz)}
    elif logical_type == 'decimal':
        extra_metadata = {
            'precision': arrow_type.precision,
            'scale': arrow_type.scale,
        }
    else:
        extra_metadata = None

    return {
        'name': name,
        'pandas_type': logical_type,
        'numpy_type': str(dtype),
        'metadata': extra_metadata,
    }
138
139
def index_level_name(index, i):
    """Return the name of an index level, or a generated default when the
    level is unnamed.

    Parameters
    ----------
    index : pandas.Index
    i : int
        Position of the level; used to build the '__index_level_{i}__'
        fallback name.

    Returns
    -------
    name : str
    """
    name = index.name
    return name if name is not None else '__index_level_{:d}__'.format(i)
157
158
def construct_metadata(df, column_names, index_levels, preserve_index, types):
    """Returns a dictionary containing enough metadata to reconstruct a pandas
    DataFrame as an Arrow Table, including index columns.

    Parameters
    ----------
    df : pandas.DataFrame
    column_names : List[str]
        Sanitized column names, parallel to ``df.columns``.
    index_levels : List[pd.Index]
    preserve_index : bool
        If False, no index column metadata is emitted.
    types : List[pyarrow.DataType]
        Arrow types of the data columns followed by the index levels.

    Returns
    -------
    dict
        Single-entry mapping ``b'pandas'`` -> UTF-8 encoded JSON blob.
    """
    # `types` holds data-column types first, then index-level types.
    ncolumns = len(column_names)
    df_types = types[:ncolumns]
    index_types = types[ncolumns:ncolumns + len(index_levels)]

    column_metadata = [
        get_column_metadata(df[col_name], name=sanitized_name,
                            arrow_type=arrow_type)
        for col_name, sanitized_name, arrow_type in
        zip(df.columns, column_names, df_types)
    ]

    if preserve_index:
        index_column_names = [index_level_name(level, i)
                              for i, level in enumerate(index_levels)]
        index_column_metadata = [
            get_column_metadata(level, name=index_level_name(level, i),
                                arrow_type=arrow_type)
            for i, (level, arrow_type) in enumerate(zip(index_levels,
                                                        index_types))
        ]
    else:
        index_column_names = index_column_metadata = []

    return {
        b'pandas': json.dumps({
            'index_columns': index_column_names,
            'columns': column_metadata + index_column_metadata,
            'pandas_version': pd.__version__
        }).encode('utf8')
    }
204
205
def dataframe_to_arrays(df, timestamps_to_ms, schema, preserve_index):
    """Convert a pandas DataFrame into names, Arrow arrays and metadata.

    Parameters
    ----------
    df : pandas.DataFrame
    timestamps_to_ms : bool
        Passed through to ``pa.Array.from_pandas``.
    schema : pyarrow.Schema or None
        Optional schema used to look up explicit Arrow types by column name.
    preserve_index : bool
        If True, the index level(s) are appended as extra columns.

    Returns
    -------
    names : list of str
    arrays : list of pyarrow.Array
    metadata : dict
        Output of ``construct_metadata``.
    """
    names = []
    arrays = []
    index_columns = []
    types = []
    # Renamed from `type` to avoid shadowing the builtin; holds the explicit
    # Arrow type for the current column when a schema was supplied.
    arrow_type = None

    if preserve_index:
        # MultiIndex exposes `levels`; a flat index is treated as one level.
        n = len(getattr(df.index, 'levels', [df.index]))
        index_columns.extend(df.index.get_level_values(i) for i in range(n))

    for name in df.columns:
        col = df[name]
        if not isinstance(name, six.string_types):
            name = str(name)

        if schema is not None:
            field = schema.field_by_name(name)
            arrow_type = getattr(field, "type", None)

        array = pa.Array.from_pandas(
            col, type=arrow_type, timestamps_to_ms=timestamps_to_ms
        )
        arrays.append(array)
        names.append(name)
        types.append(array.type)

    # Index levels are appended after the data columns, with generated
    # names for unnamed levels.
    for i, column in enumerate(index_columns):
        array = pa.Array.from_pandas(column, timestamps_to_ms=timestamps_to_ms)
        arrays.append(array)
        names.append(index_level_name(column, i))
        types.append(array.type)

    metadata = construct_metadata(
        df, names, index_columns, preserve_index, types
    )
    return names, arrays, metadata
243
244
def maybe_coerce_datetime64(values, dtype, type_, timestamps_to_ms=False):
    """Optionally coerce datetime64 values to millisecond resolution and
    derive the matching Arrow timestamp type.

    Returns the (possibly converted) values and the Arrow type to use;
    non-datetime inputs are passed through untouched.
    """
    if timestamps_to_ms:
        import warnings
        warnings.warn('timestamps_to_ms=True is deprecated', FutureWarning)

    from pyarrow.compat import DatetimeTZDtype

    if values.dtype.type != np.datetime64:
        # Not datetime data; nothing to coerce.
        return values, type_

    needs_ms = timestamps_to_ms and values.dtype != 'datetime64[ms]'
    if needs_ms:
        values = values.astype('datetime64[ms]')
        type_ = pa.timestamp('ms')

    if isinstance(dtype, DatetimeTZDtype):
        # Timezone-aware pandas dtype: carry the tz into the Arrow type.
        type_ = pa.timestamp('ms' if needs_ms else dtype.unit, dtype.tz)
    elif type_ is None:
        # Trust the NumPy dtype
        type_ = pa.from_numpy_dtype(values.dtype)

    return values, type_
270
271
def table_to_blockmanager(options, table, memory_pool, nthreads=1):
    """Convert an Arrow Table into a pandas BlockManager.

    Parameters
    ----------
    options : conversion options forwarded to ``lib.table_to_blocks``
    table : pyarrow.Table
    memory_pool : pyarrow.MemoryPool
    nthreads : int, default 1

    Returns
    -------
    pandas.core.internals.BlockManager
    """
    import pandas.core.internals as _int
    from pyarrow.compat import DatetimeTZDtype
    import pyarrow.lib as lib

    block_table = table

    index_columns = []
    index_arrays = []
    index_names = []
    schema = table.schema
    row_count = table.num_rows
    metadata = schema.metadata

    # If pandas metadata was stored with the table, use it to reconstruct
    # the original index column(s) and strip them from the data columns.
    if metadata is not None and b'pandas' in metadata:
        pandas_metadata = json.loads(metadata[b'pandas'].decode('utf8'))
        index_columns = pandas_metadata['index_columns']

        for name in index_columns:
            i = schema.get_field_index(name)
            if i != -1:
                col = table.column(i)
                # Auto-generated '__index_level_N__' names mean the level
                # was originally unnamed.
                index_name = (None if is_unnamed_index_level(name)
                              else name)
                values = col.to_pandas().values
                if not values.flags.writeable:
                    # ARROW-1054: in pandas 0.19.2, factorize will reject
                    # non-writeable arrays when calling MultiIndex.from_arrays
                    values = values.copy()

                index_arrays.append(values)
                index_names.append(index_name)
                block_table = block_table.remove_column(
                    block_table.schema.get_field_index(name)
                )

    # Each result item describes one pandas block: the array plus its
    # column placement, and optional dictionary/timezone information.
    result = lib.table_to_blocks(options, block_table, nthreads, memory_pool)

    blocks = []
    for item in result:
        block_arr = item['block']
        placement = item['placement']
        if 'dictionary' in item:
            # Dictionary-encoded column -> pandas Categorical block.
            cat = pd.Categorical(block_arr,
                                 categories=item['dictionary'],
                                 ordered=item['ordered'], fastpath=True)
            block = _int.make_block(cat, placement=placement,
                                    klass=_int.CategoricalBlock,
                                    fastpath=True)
        elif 'timezone' in item:
            # Timezone-aware timestamps -> DatetimeTZBlock.
            dtype = DatetimeTZDtype('ns', tz=item['timezone'])
            block = _int.make_block(block_arr, placement=placement,
                                    klass=_int.DatetimeTZBlock,
                                    dtype=dtype, fastpath=True)
        else:
            block = _int.make_block(block_arr, placement=placement)
        blocks.append(block)

    # Rebuild the index: MultiIndex for several levels, a plain Index for
    # one, or a default RangeIndex when no index columns were stored.
    if len(index_arrays) > 1:
        index = pd.MultiIndex.from_arrays(index_arrays, names=index_names)
    elif len(index_arrays) == 1:
        index = pd.Index(index_arrays[0], name=index_names[0])
    else:
        index = pd.RangeIndex(row_count)

    axes = [
        [column.name for column in block_table.itercolumns()],
        index
    ]

    return _int.BlockManager(blocks, axes)