Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions elasticdata/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
from importlib import import_module
from django.conf import settings
from elasticsearch import Elasticsearch, helpers, TransportError
from elasticsearch import Elasticsearch, helpers, TransportError, NotFoundError
from datetime import datetime

from .repository import BaseRepository
Expand Down Expand Up @@ -194,11 +194,19 @@ def flush(self, refresh=False):
if persisted_entity.is_action_needed():
actions.append(persisted_entity)
self._execute_callbacks(actions, 'pre')
bulk_results = helpers.streaming_bulk(self.es, [a.stmt for a in actions], refresh=refresh)
# TODO: checking exceptions in bulk_results
for persisted_entity, result in zip(actions, bulk_results):
if 'create' in result[1]:
persisted_entity.set_id(result[1]['create']['_id'])
try:
bulk_results = helpers.streaming_bulk(self.es, [a.stmt for a in actions], refresh=refresh, raise_on_error=False)
except TransportError as e:
raise RepositoryError('Transport returned error', cause=e)
for persisted_entity, (success, result) in zip(actions, bulk_results):
entity = persisted_entity._entity
if not success:
if 'delete' in result:
if result['delete']['found'] == False:
raise EntityNotFound(self.entity_not_found_message(entity.type, entity['id']))
raise RepositoryError(six.text_type(result))
if 'create' in result:
persisted_entity.set_id(result['create']['_id'])
for action in actions:
action.reset_state()
self._execute_callbacks(actions, 'post')
Expand All @@ -210,10 +218,10 @@ def find(self, _id, _type, scope=None, **kwargs):
params.update(kwargs)
try:
_data = self.es.get(**params)
except TransportError as e: # TODO: the might be other errors like server unavaliable
raise EntityNotFound(self.entity_not_found_message(_type.get_type(), _id), e)
if not _data['found']:
except NotFoundError:
raise EntityNotFound(self.entity_not_found_message(_type.get_type(), _id))
except TransportError as e:
raise RepositoryError('Transport returned error', cause=e)
source = _data['_source']
source['id'] = _data['_id']
entity = _type(source, scope)
Expand All @@ -232,8 +240,8 @@ def find_many(self, _ids, _type, scope=None, complete_data=True, **kwargs):
params.update(kwargs)
try:
_data = self.es.mget(**params)
except TransportError as e: # TODO: the might be other errors like server unavaliable
raise EntityNotFound(self.entity_not_found_message(_type.get_type(), ', '.join(_ids)), e)
except TransportError as e:
raise RepositoryError('Transport returned error', cause=e)
entities = []
if complete_data:
invalid_items = [elem['_id'] for elem in _data['docs'] if not elem['found']]
Expand Down
88 changes: 54 additions & 34 deletions tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,11 @@ def tearDown(self):
es = Elasticsearch()
es.indices.delete(index=self._index, ignore=[404])

@property
def em(self):
return EntityManager(index=self._index)

def test_persist(self):
em = self.em
em = self.em()
e = ManagerTestType({'foo': 'bar'})
em.persist(e)
self.assertEqual(len(em._registry), 1)
Expand All @@ -171,15 +170,15 @@ def test_persist(self):
self.assertRaises(TypeError, em.persist, dict())

def test_remove(self):
em = self.em
em = self.em()
e = ManagerTestType({'foo': 'bar'})
em.persist(e)
self.assertEqual(list(em._registry.values())[0].state, ADD)
em.remove(e)
self.assertEqual(list(em._registry.values())[0].state, REMOVE)

def test_flush(self):
em = self.em
em = self.em()
e = ManagerTestType({'foo': 'bar'})
em.persist(e)
e2 = ManagerTestType({'bar': 'baz'})
Expand All @@ -192,9 +191,30 @@ def test_flush(self):
self.assertTrue(all(map(lambda pe: pe.state == UPDATE, em._registry.values())))
em.flush()

def test_delete_error(self):
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'bar'})
em.persist(e)
em.flush()
fe = em2.find(e['id'], ManagerTestType)
em2.remove(fe)
em2.flush()
em.remove(fe)
self.assertRaises(EntityNotFound, em.flush)

def test_create_error(self):
em = self.em()
e1 = ManagerTestType({'foo': 123})
em.persist(e1)
em.flush()
e2 = ManagerTestType({'foo': 'bar'})
em.persist(e2)
self.assertRaises(RepositoryError, em.flush)

def test_find(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'bar'})
em.persist(e)
em.flush()
Expand All @@ -203,8 +223,8 @@ def test_find(self):
self.assertRaises(EntityNotFound, em2.find, 'non-exists', ManagerTestType)

def test_find_updated(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'bar'})
em.persist(e)
em.flush()
Expand All @@ -214,8 +234,8 @@ def test_find_updated(self):
self.assertDictEqual(e.to_representation(), fe.to_representation())

def test_find_many(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'bar'})
e2 = ManagerTestType({'bar': 'baz'})
em.persist(e)
Expand All @@ -230,8 +250,8 @@ def _ids_set(ents):
self.assertEqual(_ids_set(fe), _ids_set(fe2))

def test_query(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'value', 'bar': 'baz', 'baz': 'foo'})
e2 = ManagerTestType({'foo': 'value', 'bar': 'baz', 'baz': 'foo'})
e3 = ManagerTestType({'foo': 'value', 'bar': 'baz', 'baz': 'foo'})
Expand All @@ -244,8 +264,8 @@ def test_query(self):
self.assertEqual(len(fe), 3)

def test_query_one(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'bar'})
em.persist(e)
em.flush()
Expand All @@ -259,8 +279,8 @@ def test_query_one(self):
self.assertRaises(RepositoryError, em2.query_one, {'query': {'term': {'foo': {'value': 'bar'}}}}, ManagerTestType)

def test_find_scope(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'bar', 'bar': 'baz'})
em.persist(e)
em.flush()
Expand All @@ -273,8 +293,8 @@ def test_find_scope(self):
self.assertFalse(key in fe2.keys())

def test_find_many_scope(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'bar', 'bar': 'baz'})
e2 = ManagerTestType({'foo': 'bar', 'bar': 'baz'})
em.persist(e)
Expand All @@ -289,8 +309,8 @@ def test_find_many_scope(self):
self.assertFalse(key in e.keys())

def test_query_scope(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'value', 'bar': 'baz', 'baz': 'foo'})
e2 = ManagerTestType({'foo': 'value', 'bar': 'baz', 'baz': 'foo'})
e3 = ManagerTestType({'foo': 'value', 'bar': 'baz', 'baz': 'foo'})
Expand All @@ -309,8 +329,8 @@ def test_query_scope(self):
self.assertFalse(key in e.keys())

def test_query_one_scope(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'value', 'bar': 'baz', 'baz': 'foo'})
em.persist(e)
em.flush()
Expand All @@ -322,7 +342,7 @@ def test_query_one_scope(self):
self.assertFalse(key in fe.keys())

def test_timestamps(self):
em = self.em
em = self.em()
e = TimestampedType({'foo': 'bar'})
em.persist(e)
em.flush()
Expand All @@ -337,36 +357,36 @@ def test_timestamps(self):
self.assertIsInstance(values['updated_at'], datetime)

def test_pre_create_callback(self):
em = self.em
em = self.em()
e = ManagerCallbacksTestType({'foo': 'bar'})
em.persist(e)
em.flush()
em2 = self.em
em2 = self.em()
e2 = em2.find(e['id'], ManagerCallbacksTestType)
self.assertEqual(e2['pre_create'], 'bar')

def test_post_create_callback(self):
with patch.object(ManagerCallbacksTestType, 'post_create') as mock:
em = self.em
em = self.em()
e = ManagerCallbacksTestType({'foo': 'bar'})
em.persist(e)
em.flush()
mock.assert_called_with(em)

def test_pre_update_callback(self):
em = self.em
em = self.em()
e = ManagerCallbacksTestType({'foo': 'bar'})
em.persist(e)
em.flush()
e['bar'] = 'baz'
em.flush()
em2 = self.em
em2 = self.em()
e2 = em2.find(e['id'], ManagerCallbacksTestType)
self.assertEqual(e2['pre_update'], 'bar')

def test_post_update_callback(self):
with patch.object(ManagerCallbacksTestType, 'post_update') as mock:
em = self.em
em = self.em()
e = ManagerCallbacksTestType({'foo': 'bar'})
em.persist(e)
em.flush()
Expand All @@ -376,7 +396,7 @@ def test_post_update_callback(self):

def test_pre_delete_callback(self):
with patch.object(ManagerCallbacksTestType, 'pre_delete') as mock:
em = self.em
em = self.em()
e = ManagerCallbacksTestType({'foo': 'bar'})
em.persist(e)
em.flush()
Expand All @@ -386,7 +406,7 @@ def test_pre_delete_callback(self):

def test_post_delete_callback(self):
with patch.object(ManagerCallbacksTestType, 'post_delete') as mock:
em = self.em
em = self.em()
e = ManagerCallbacksTestType({'foo': 'bar'})
em.persist(e)
em.flush()
Expand All @@ -395,16 +415,16 @@ def test_post_delete_callback(self):
mock.assert_called_with(em)

def test_clear(self):
em = self.em
em = self.em()
e = ManagerTestType({'foo': 'bar'})
em.persist(e)
self.assertEqual(len(em._registry), 1)
em.clear()
self.assertEqual(len(em._registry), 0)

def test_highlight_query(self):
em = self.em
em2 = self.em
em = self.em()
em2 = self.em()
e = ManagerTestType({'foo': 'bar foo'})
em.persist(e)
em.flush()
Expand Down