diff --git a/doc/source/v0.14.1.txt b/doc/source/v0.14.1.txt index 3159bbfc34..a92cb54b90 100644 --- a/doc/source/v0.14.1.txt +++ b/doc/source/v0.14.1.txt @@ -280,3 +280,5 @@ Bug Fixes - Bug in ``pandas.core.strings.str_contains`` does not properly match in a case insensitive fashion when ``regex=False`` and ``case=False`` (:issue:`7505`) - Bug in ``expanding_cov``, ``expanding_corr``, ``rolling_cov``, and ``rolling_corr`` for two arguments with mismatched index (:issue:`7512`) +- Bug in grouped ``hist`` where the ``rot`` and ``sharex`` keywords were not handled properly (:issue:`7234`) + diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 2f631e28bf..ddd0477a46 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -2272,6 +2272,8 @@ def test_time_series_plot_color_with_empty_kwargs(self): def test_grouped_hist(self): df = DataFrame(randn(500, 2), columns=['A', 'B']) df['C'] = np.random.randint(0, 4, 500) + df['D'] = ['X'] * 500 + axes = plotting.grouped_hist(df.A, by=df.C) self._check_axes_shape(axes, axes_num=4, layout=(2, 2)) @@ -2280,14 +2282,24 @@ def test_grouped_hist(self): self._check_axes_shape(axes, axes_num=4, layout=(2, 2)) tm.close() + # group by a key with single value + axes = df.hist(by='D', rot=30) + self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) + self._check_ticks_props(axes, xrot=30) + + tm.close() # make sure kwargs to hist are handled + xf, yf = 20, 18 + xrot, yrot = 30, 40 axes = plotting.grouped_hist(df.A, by=df.C, normed=True, - cumulative=True, bins=4) - + cumulative=True, bins=4, + xlabelsize=xf, xrot=xrot, ylabelsize=yf, yrot=yrot) # height of last bin (index 5) must be 1.0 for ax in axes.ravel(): height = ax.get_children()[5].get_height() self.assertAlmostEqual(height, 1.0) + self._check_ticks_props(axes, xlabelsize=xf, xrot=xrot, + ylabelsize=yf, yrot=yrot) tm.close() axes = plotting.grouped_hist(df.A, by=df.C, log=True) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 
98c802ac10..f4e9b1a0f7 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -2501,23 +2501,12 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, kwds : other plotting keyword arguments To be passed to hist function """ - import matplotlib.pyplot as plt if by is not None: axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize, sharex=sharex, sharey=sharey, layout=layout, bins=bins, + xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot, **kwds) - - for ax in axes.ravel(): - if xlabelsize is not None: - plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) - if xrot is not None: - plt.setp(ax.get_xticklabels(), rotation=xrot) - if ylabelsize is not None: - plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) - if yrot is not None: - plt.setp(ax.get_yticklabels(), rotation=yrot) - return axes if column is not None: @@ -2533,21 +2522,12 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, for i, col in enumerate(com._try_sort(data.columns)): ax = axes[i // ncols, i % ncols] - ax.xaxis.set_visible(True) - ax.yaxis.set_visible(True) ax.hist(data[col].dropna().values, bins=bins, **kwds) ax.set_title(col) ax.grid(grid) - if xlabelsize is not None: - plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) - if xrot is not None: - plt.setp(ax.get_xticklabels(), rotation=xrot) - if ylabelsize is not None: - plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) - if yrot is not None: - plt.setp(ax.get_yticklabels(), rotation=yrot) - + _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot, + ylabelsize=ylabelsize, yrot=yrot) fig.subplots_adjust(wspace=0.3, hspace=0.3) return axes @@ -2607,23 +2587,18 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None, ax.hist(values, bins=bins, **kwds) ax.grid(grid) axes = np.array([ax]) + + _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot, + ylabelsize=ylabelsize, yrot=yrot) + else: if 'figure' in kwds: raise 
ValueError("Cannot pass 'figure' when using the " "'by' argument, since a new 'Figure' instance " "will be created") - axes = grouped_hist(self, by=by, ax=ax, grid=grid, figsize=figsize, - bins=bins, **kwds) - - for ax in axes.ravel(): - if xlabelsize is not None: - plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) - if xrot is not None: - plt.setp(ax.get_xticklabels(), rotation=xrot) - if ylabelsize is not None: - plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) - if yrot is not None: - plt.setp(ax.get_yticklabels(), rotation=yrot) + axes = grouped_hist(self, by=by, ax=ax, grid=grid, figsize=figsize, bins=bins, + xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot, + **kwds) if axes.ndim == 1 and len(axes) == 1: return axes[0] @@ -2632,6 +2607,7 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None, def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, layout=None, sharex=False, sharey=False, rot=90, grid=True, + xlabelsize=None, xrot=None, ylabelsize=None, yrot=None, **kwargs): """ Grouped histogram @@ -2658,9 +2634,15 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, def plot_group(group, ax): ax.hist(group.dropna().values, bins=bins, **kwargs) + xrot = xrot or rot + fig, axes = _grouped_plot(plot_group, data, column=column, by=by, sharex=sharex, sharey=sharey, figsize=figsize, layout=layout, rot=rot) + + _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot, + ylabelsize=ylabelsize, yrot=yrot) + fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, hspace=0.5, wspace=0.3) return axes @@ -3094,6 +3076,22 @@ def _get_xlim(lines): return left, right +def _set_ticks_props(axes, xlabelsize=None, xrot=None, + ylabelsize=None, yrot=None): + import matplotlib.pyplot as plt + + for ax in _flatten(axes): + if xlabelsize is not None: + plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) + if xrot is not None: + plt.setp(ax.get_xticklabels(), rotation=xrot) + if ylabelsize 
is not None: + plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) + if yrot is not None: + plt.setp(ax.get_yticklabels(), rotation=yrot) + return axes + + if __name__ == '__main__': # import pandas.rpy.common as com # sales = com.load_data('sanfrancisco.home.sales', package='nutshell')