diff --git a/pandas/core/frame.py b/pandas/core/frame.py
index fad348aed0..4811aff3a3 100644
--- a/pandas/core/frame.py
+++ b/pandas/core/frame.py
@@ -3003,10 +3003,9 @@ def combiner(x, y, needs_i8_conversion=False):
         return self.combine(other, combiner, overwrite=False)
 
     def update(self, other, join='left', overwrite=True, filter_func=None,
-               raise_conflict=False):
+               raise_conflict=False, on=None):
         """
-        Modify DataFrame in place using non-NA values from passed
-        DataFrame. Aligns on indices
+        Modify DataFrame in place using non-NA values from passed DataFrame.
 
         Parameters
         ----------
@@ -3020,6 +3019,12 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
         raise_conflict : boolean
             If True, will raise an error if the DataFrame and other both
             contain data in the same place.
+        on : label or list, optional
+            Column(s) whose values are used to match rows of other to rows
+            of self instead of the index. The key values in other must be
+            unique, otherwise a ValueError is raised. If None, other is
+            reindexed to the index of self, so the indexes must match to
+            get a meaningful result.
         """
         # TODO: Support other joins
         if join != 'left':  # pragma: no cover
@@ -3028,9 +3033,22 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
         if not isinstance(other, DataFrame):
             other = DataFrame(other)
 
-        other = other.reindex_like(self)
+        if on is None:
+            other = other.reindex(index=self.index)
+        else:
+            # Align on the key column(s) without touching self's index or
+            # column order and without mutating the caller's frame:
+            # set_index returns a new object in both cases.
+            keys = self.set_index(on).index
+            other = other.set_index(on)
+            if not other.index.is_unique:
+                raise ValueError("Cannot update on %r: key values in "
+                                 "other are not unique" % (on,))
+            other = other.reindex(index=keys)
 
-        for col in self.columns:
+        for col in other.columns:
+            if col not in self:  # don't update what doesn't exist
+                continue
             this = self[col].values
             that = other[col].values
             if filter_func is not None:
diff --git a/pandas/tests/test_frame.py b/pandas/tests/test_frame.py
index 3c39d610c1..0326a73e52 100644
--- a/pandas/tests/test_frame.py
+++ b/pandas/tests/test_frame.py
@@ -24,7 +24,7 @@ from numpy.random import randn
 
 import numpy as np
 import numpy.ma as ma
-from numpy.testing import assert_array_equal
+from numpy.testing import assert_array_equal, assert_
 import numpy.ma.mrecords as mrecords
 
 import pandas.core.nanops as nanops
@@ -9974,6 +9974,47 @@ def test_update(self):
                               [1.5, nan, 7.]])
         assert_frame_equal(df, expected)
 
+    def test_update_on(self):
+        df = DataFrame([[np.nan, 'A'],
+                        [np.nan, 'A'],
+                        [np.nan, 'A'],
+                        [1.5, 'B'],
+                        [2.2, 'C'],
+                        [3.1, 'C'],
+                        [1.2, 'B']], columns=['number', 'name'])
+
+        df2 = DataFrame([[3.5, 'A']], columns=['number', 'name'])
+
+        expected = DataFrame([[3.5, 'A'],
+                              [3.5, 'A'],
+                              [3.5, 'A'],
+                              [1.5, 'B'],
+                              [2.2, 'C'],
+                              [3.1, 'C'],
+                              [1.2, 'B']], columns=['number', 'name'])
+        df.update(df2, on='name')
+        assert_frame_equal(df, expected)
+
+        # aligning on ``name`` must not modify the frame passed as other
+        assert_(list(df2.columns) == ['number', 'name'])
+
+        df = DataFrame([[np.nan, 'A'],
+                        [np.nan, 'A'],
+                        [np.nan, 'A'],
+                        [1.5, 'B'],
+                        [2.2, 'C'],
+                        [3.1, 'C'],
+                        [1.2, 'B']], columns=['number', 'name'])
+
+        df2 = DataFrame([[3.5, 'A'], [2.5, 'A']],
+                        columns=['number', 'name'])
+
+        # duplicate key values in other make the match ambiguous
+        self.assertRaises(ValueError, df.update, df2, on='name')
+
+        # a failed update must leave the index of self untouched
+        assert_(df.index.equals(expected.index))
+
     def test_update_dtypes(self):
 
         # gh 3016