diff --git a/doc/source/v0.14.1.txt b/doc/source/v0.14.1.txt index 850e7e13db..6188f182f1 100644 --- a/doc/source/v0.14.1.txt +++ b/doc/source/v0.14.1.txt @@ -176,6 +176,8 @@ Bug Fixes - Bug in ``to_timedelta`` that accepted invalid units and misinterpreted 'm/h' (:issue:`7611`, :issue: `6423`) - Bug in grouped ``hist`` and ``scatter`` plots use old ``figsize`` default (:issue:`7394`) +- Bug in plotting subplots with ``DataFrame.plot``, ``hist`` clears passed ``ax`` even if the number of subplots is one (:issue:`7391`). +- Bug in plotting subplots with ``DataFrame.boxplot`` with ``by`` kw raises ``ValueError`` if the number of subplots exceeds 1 (:issue:`7391`). - Bug in ``Panel.apply`` with a multi-index as an axis (:issue:`7469`) diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index d19d071833..729aa83647 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -859,6 +859,13 @@ def test_plot(self): axes = _check_plot_works(df.plot, kind='bar', subplots=True) self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) + # When ax is supplied and required number of axes is 1, + # passed ax should be used: + fig, ax = self.plt.subplots() + axes = df.plot(kind='bar', subplots=True, ax=ax) + self.assertEqual(len(axes), 1) + self.assertIs(ax.get_axes(), axes[0]) + def test_nonnumeric_exclude(self): df = DataFrame({'A': ["x", "y", "z"], 'B': [1, 2, 3]}) ax = df.plot() @@ -1419,17 +1426,23 @@ def test_boxplot(self): df = DataFrame(np.random.rand(10, 2), columns=['Col1', 'Col2']) df['X'] = Series(['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B']) + df['Y'] = Series(['A'] * 10) _check_plot_works(df.boxplot, by='X') - # When ax is supplied, existing axes should be used: + # When ax is supplied and required number of axes is 1, + # passed ax should be used: fig, ax = self.plt.subplots() axes = df.boxplot('Col1', by='X', ax=ax) self.assertIs(ax.get_axes(), axes) - # Multiple columns with an ax argument is not supported fig, ax = self.plt.subplots() - with tm.assertRaisesRegexp(ValueError, 'existing axis'): - df.boxplot(column=['Col1', 'Col2'], by='X', ax=ax) + axes = df.groupby('Y').boxplot(ax=ax, return_type='axes') + self.assertIs(ax.get_axes(), axes['A']) + + # Multiple columns with an ax argument should use same figure + fig, ax = self.plt.subplots() + axes = df.boxplot(column=['Col1', 'Col2'], by='X', ax=ax, return_type='axes') + self.assertIs(axes['Col1'].get_figure(), fig) # When by is None, check that all relevant lines are present in the dict fig, ax = self.plt.subplots() @@ -2180,32 +2193,32 @@ class TestDataFrameGroupByPlots(TestPlotBase): @slow def test_boxplot(self): grouped = self.hist_df.groupby(by='gender') - box = _check_plot_works(grouped.boxplot, return_type='dict') - self._check_axes_shape(self.plt.gcf().axes, axes_num=2, layout=(1, 2)) + axes = _check_plot_works(grouped.boxplot, return_type='axes') + self._check_axes_shape(axes.values(), axes_num=2, layout=(1, 2)) - box = _check_plot_works(grouped.boxplot, subplots=False, - return_type='dict') - self._check_axes_shape(self.plt.gcf().axes, axes_num=2, layout=(1, 2)) + axes = _check_plot_works(grouped.boxplot, subplots=False, + return_type='axes') + self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) tuples = lzip(string.ascii_letters[:10], range(10)) df = DataFrame(np.random.rand(10, 3), index=MultiIndex.from_tuples(tuples)) grouped = df.groupby(level=1) - box = _check_plot_works(grouped.boxplot, return_type='dict') - self._check_axes_shape(self.plt.gcf().axes, axes_num=10, layout=(4, 3)) + axes = _check_plot_works(grouped.boxplot, return_type='axes') + self._check_axes_shape(axes.values(), axes_num=10, layout=(4, 3)) - box = _check_plot_works(grouped.boxplot, subplots=False, - return_type='dict') - self._check_axes_shape(self.plt.gcf().axes, axes_num=10, layout=(4, 3)) + axes = _check_plot_works(grouped.boxplot, subplots=False, + return_type='axes') + self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) grouped = df.unstack(level=1).groupby(level=0, axis=1) - box = _check_plot_works(grouped.boxplot, return_type='dict') - self._check_axes_shape(self.plt.gcf().axes, axes_num=3, layout=(2, 2)) + axes = _check_plot_works(grouped.boxplot, return_type='axes') + self._check_axes_shape(axes.values(), axes_num=3, layout=(2, 2)) - box = _check_plot_works(grouped.boxplot, subplots=False, - return_type='dict') - self._check_axes_shape(self.plt.gcf().axes, axes_num=3, layout=(2, 2)) + axes = _check_plot_works(grouped.boxplot, subplots=False, + return_type='axes') + self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) def test_series_plot_color_kwargs(self): # GH1890 diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 2b02523c14..779aa328e8 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -2665,7 +2665,8 @@ def plot_group(group, ax): def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, - rot=0, grid=True, figsize=None, layout=None, **kwds): + rot=0, grid=True, ax=None, figsize=None, + layout=None, **kwds): """ Make box plots from DataFrameGroupBy data. @@ -2712,7 +2713,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, naxes = len(grouped) nrows, ncols = _get_layout(naxes, layout=layout) fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, squeeze=False, - sharex=False, sharey=True) + ax=ax, sharex=False, sharey=True, figsize=figsize) axes = _flatten(axes) ret = compat.OrderedDict() @@ -2733,7 +2734,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, else: df = frames[0] ret = df.boxplot(column=column, fontsize=fontsize, rot=rot, - grid=grid, figsize=figsize, layout=layout, **kwds) + grid=grid, ax=ax, figsize=figsize, layout=layout, **kwds) return ret @@ -2779,17 +2780,10 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None, by = [by] columns = data._get_numeric_data().columns - by naxes = len(columns) - - if ax is None: - nrows, ncols = _get_layout(naxes, layout=layout) - fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, - sharex=True, sharey=True, - figsize=figsize, ax=ax) - else: - if naxes > 1: - raise ValueError("Using an existing axis is not supported when plotting multiple columns.") - fig = ax.get_figure() - axes = ax.get_axes() + nrows, ncols = _get_layout(naxes, layout=layout) + fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, + sharex=True, sharey=True, + figsize=figsize, ax=ax) ravel_axes = _flatten(axes) @@ -2974,12 +2968,6 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze= if subplot_kw is None: subplot_kw = {} - if ax is None: - fig = plt.figure(**fig_kw) - else: - fig = ax.get_figure() - fig.clear() - # Create empty object array to hold all axes. It's easiest to make it 1-d # so we can just append subplots upon creation, and then nplots = nrows * ncols @@ -2989,6 +2977,21 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze= elif nplots < naxes: raise ValueError("naxes {0} is larger than layour size defined by nrows * ncols".format(naxes)) + if ax is None: + fig = plt.figure(**fig_kw) + else: + fig = ax.get_figure() + # if ax is passed and a number of subplots is 1, return ax as it is + if naxes == 1: + if squeeze: + return fig, ax + else: + return fig, _flatten(ax) + else: + warnings.warn("To output multiple subplots, the figure containing the passed axes " + "is being cleared", UserWarning) + fig.clear() + axarr = np.empty(nplots, dtype=object) def on_right(i):