import logging
from copy import deepcopy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.engine.url import make_url
from sqlalchemy import and_, insert, create_engine
from indra.statements import stmts_from_json, stmts_to_json
from indra.util import batch_iter
from . import schema as wms_schema
logger = logging.getLogger(__name__)
class DbManager:
    """Manages transactions with the assembly database and exposes an API
    for various operations."""

    def __init__(self, url):
        self.url = make_url(url)
        logger.info('Starting DB manager with URL: %s' % str(self.url))
        self.engine = create_engine(self.url)
        # Session is created lazily on first use; see get_session().
        self.session = None

    def get_session(self):
        """Return the current active session or create one if not available."""
        if self.session is None:
            session_maker = sessionmaker(bind=self.engine)
            self.session = session_maker()
        return self.session

    def create_all(self):
        """Create all the database tables in the schema."""
        wms_schema.Base.metadata.create_all(self.engine)
        # Index on record_key to speed up per-record statement lookups.
        self.engine.execute('create index record_key_idx on '
                            'prepared_statements (record_key)')

    def query(self, *query_args):
        """Run and return results of a generic query."""
        session = self.get_session()
        return session.query(*query_args)

    def sql_query(self, query_str):
        """Run and return results of a generic SQL query."""
        return self.engine.execute(query_str)

    def execute(self, operation):
        """Execute an insert operation on the current session and return
        results.

        Returns a dict with 'rowcount' and 'inserted_primary_key' keys on
        success, or None if the operation failed (the session is rolled
        back in that case).
        """
        session = self.get_session()
        try:
            res = session.execute(operation)
            session.commit()
            return {'rowcount': res.rowcount,
                    'inserted_primary_key': res.inserted_primary_key}
        except SQLAlchemyError as e:
            # Log and roll back rather than propagate: callers check for
            # a None return to detect failure.
            logger.error(e)
            session.rollback()
            return None

    def add_project(self, project_id, name, corpus_id=None):
        """Add a new project.

        Parameters
        ----------
        project_id : str
            The project ID.
        name : str
            The project name
        corpus_id : Optional[str]
            The corpus ID from which the project was derived, if
            available.
        """
        op = insert(wms_schema.Projects).values(id=project_id,
                                                name=name,
                                                corpus_id=corpus_id)
        return self.execute(op)

    def add_records_for_project(self, project_id, record_keys):
        """Add document IDs for a project with the given ID."""
        op = insert(wms_schema.ProjectRecords).values(
            [
                {'project_id': project_id,
                 'record_key': rec_key}
                for rec_key in record_keys
            ]
        )
        return self.execute(op)

    def get_records_for_project(self, project_id):
        """Return a list of record keys for the given project ID."""
        qfilter = and_(wms_schema.ProjectRecords.project_id.like(project_id))
        q = self.query(wms_schema.ProjectRecords.record_key).filter(qfilter)
        record_keys = [r[0] for r in q.all()]
        return record_keys

    def get_documents_for_project(self, project_id):
        """Return a sorted list of unique document IDs for a project."""
        # Join project records to DART records on the storage key to map
        # record keys to document IDs.
        qfilter = and_(
            wms_schema.ProjectRecords.project_id.like(project_id),
            (wms_schema.DartRecords.storage_key ==
             wms_schema.ProjectRecords.record_key))
        q = self.query(wms_schema.DartRecords.document_id).filter(qfilter)
        doc_ids = sorted(set(r[0] for r in q.all()))
        return doc_ids

    def get_projects(self):
        """Return a list of all projects."""
        q = self.query(wms_schema.Projects)
        projects = [{'id': p.id, 'name': p.name} for p in q.all()]
        return projects

    def get_corpus_for_project(self, project_id):
        """Return the corpus ID that a project was derived from, if available."""
        q = self.query(wms_schema.Projects.corpus_id).filter(
            wms_schema.Projects.id.like(project_id))
        res = list(q.all())
        if res:
            return res[0][0]
        else:
            return None

    def get_tenant_for_corpus(self, corpus_id):
        """Return the tenant for a given corpus, if available."""
        q = self.query(wms_schema.Corpora.meta_data).filter(
            wms_schema.Corpora.id.like(corpus_id))
        res = list(q.all())
        if res:
            metadata = res[0][0]
            return metadata.get('tenant')
        else:
            return None

    def add_corpus(self, corpus_id, metadata):
        """Add a new corpus with the given ID and metadata."""
        op = insert(wms_schema.Corpora).values(id=corpus_id,
                                               meta_data=metadata)
        return self.execute(op)

    def add_records_for_corpus(self, corpus_id, record_keys):
        """Add record keys for a corpus with the given ID."""
        op = insert(wms_schema.CorpusRecords).values(
            [
                {'corpus_id': corpus_id,
                 'record_key': rec_key}
                for rec_key in record_keys
            ]
        )
        return self.execute(op)

    def get_records_for_corpus(self, corpus_id):
        """Return a list of record keys for the given corpus ID."""
        qfilter = and_(wms_schema.CorpusRecords.corpus_id.like(corpus_id))
        q = self.query(wms_schema.CorpusRecords.record_key).filter(qfilter)
        record_keys = [r[0] for r in q.all()]
        return record_keys

    def get_documents_for_corpus(self, corpus_id):
        """Return a sorted list of unique document IDs for a corpus."""
        # Join corpus records to DART records on the storage key to map
        # record keys to document IDs.
        qfilter = and_(
            wms_schema.CorpusRecords.corpus_id.like(corpus_id),
            (wms_schema.DartRecords.storage_key ==
             wms_schema.CorpusRecords.record_key))
        q = self.query(wms_schema.DartRecords.document_id).filter(qfilter)
        doc_ids = sorted(set(r[0] for r in q.all()))
        return doc_ids

    def add_statements_for_record(self, record_key, stmts, indra_version):
        """Add a set of prepared statements for a given document."""
        if not stmts:
            return None
        op = insert(wms_schema.PreparedStatements).values(
            [
                {
                    'record_key': record_key,
                    'indra_version': indra_version,
                    'stmt': stmt
                }
                # Note: the deepcopy here is done because when dumping
                # statements into JSON, the hash is overwritten, potentially
                # with an inadequate one (due to a custom matches_fun not being
                # given here).
                for stmt in stmts_to_json(deepcopy(stmts))
            ]
        )
        return self.execute(op)

    def add_curation_for_project(self, project_id, stmt_hash, curation):
        """Add curations for a given project."""
        op = insert(wms_schema.Curations).values(project_id=project_id,
                                                 stmt_hash=stmt_hash,
                                                 curation=curation)
        return self.execute(op)

    def get_statements_for_record(self, record_key):
        """Return prepared statements for given record key."""
        qfilter = wms_schema.PreparedStatements.record_key.like(record_key)
        q = self.query(wms_schema.PreparedStatements.stmt).filter(qfilter)
        stmts = stmts_from_json([r[0] for r in q.all()])
        return stmts

    def get_statements_for_records(self, record_keys, batch_size=1000):
        """Return prepared statements for given list of record keys."""
        stmts = []
        # Query in batches to avoid overly large IN clauses.
        for record_key_batch in batch_iter(record_keys, batch_size, list):
            qfilter = wms_schema.PreparedStatements.record_key.in_(
                record_key_batch)
            q = self.query(wms_schema.PreparedStatements.stmt).filter(qfilter)
            stmts += stmts_from_json([r[0] for r in q.all()])
        return stmts

    def get_statements(self):
        """Return all prepared statements in the DB."""
        q = self.query(wms_schema.PreparedStatements.stmt)
        stmts = stmts_from_json([r[0] for r in q.all()])
        return stmts

    def get_statements_for_document(self, document_id, reader=None,
                                    reader_version=None, indra_version=None):
        """Return prepared statements for a given document.

        Optional reader, reader_version and indra_version constraints
        are added to the filter only when provided.
        """
        qfilter = and_(
            wms_schema.DartRecords.document_id.like(document_id),
            wms_schema.DartRecords.storage_key ==
            wms_schema.PreparedStatements.record_key)
        if reader:
            qfilter = and_(
                qfilter,
                wms_schema.DartRecords.reader.like(reader)
            )
        if reader_version:
            qfilter = and_(
                qfilter,
                wms_schema.DartRecords.reader_version.like(reader_version)
            )
        if indra_version:
            qfilter = and_(
                qfilter,
                wms_schema.PreparedStatements.indra_version.like(indra_version)
            )
        q = self.query(wms_schema.PreparedStatements.stmt).filter(qfilter)
        stmts = stmts_from_json([r[0] for r in q.all()])
        return stmts

    def get_curations_for_project(self, project_id):
        """Return curations for a given project"""
        qfilter = wms_schema.Curations.project_id.like(project_id)
        q = self.query(wms_schema.Curations).filter(qfilter)
        # Build a dict of stmt_hash: curation records
        curations = {res.stmt_hash: res.curation for res in q.all()}
        return curations

    def add_dart_record(self, reader, reader_version, document_id, storage_key,
                        date, output_version=None, labels=None, tenants=None):
        """Insert a DART record into the database."""
        op = insert(wms_schema.DartRecords).values(
            **{
                'reader': reader,
                'reader_version': reader_version,
                'document_id': document_id,
                'storage_key': storage_key,
                'date': date,
                'output_version': output_version,
                'labels': labels,
                'tenants': tenants
            }
        )
        return self.execute(op)

    def get_dart_records(self, reader=None, document_id=None,
                         reader_version=None, output_version=None, labels=None,
                         tenants=None):
        """Return storage keys for DART records given constraints."""
        records = self.get_full_dart_records(
            reader=reader, document_id=document_id,
            reader_version=reader_version,
            output_version=output_version,
            labels=labels, tenants=tenants)
        return [r['storage_key'] for r in records]

    def get_full_dart_records(self, reader=None, document_id=None,
                              reader_version=None, output_version=None,
                              labels=None, tenants=None):
        """Return full DART records given constraints."""
        # Build up the filter incrementally from whichever constraints
        # were provided; None means an unconstrained query.
        qfilter = None
        if document_id:
            qfilter = extend_filter(
                qfilter,
                wms_schema.DartRecords.document_id.like(document_id))
        if reader:
            qfilter = extend_filter(qfilter,
                                    wms_schema.DartRecords.reader.like(reader))
        if reader_version:
            qfilter = extend_filter(qfilter, wms_schema.DartRecords.
                                    reader_version.like(reader_version))
        if output_version:
            qfilter = extend_filter(qfilter, wms_schema.DartRecords.
                                    output_version.like(output_version))
        if qfilter is not None:
            q = self.query(wms_schema.DartRecords).filter(qfilter)
        else:
            q = self.query(wms_schema.DartRecords)
        record_keys = ['reader', 'reader_version', 'document_id', 'storage_key',
                       'date', 'output_version', 'labels', 'tenants']
        records = [{k: r.__dict__.get(k) for k in record_keys} for r in q.all()]
        # If some labels are given, we retain only records where there are
        # some labels given and all the given labels are contained in those.
        # Labels/tenants are stored as '|'-delimited strings.
        if labels:
            records = [r for r in records if r['labels'] and
                       set(labels) <= set(r['labels'].split('|'))]
        if tenants:
            records = [r for r in records if r['tenants'] and
                       set(tenants) <= set(r['tenants'].split('|'))]
        return records
def extend_filter(qfilter, constraint):
    """Combine a new constraint with an existing filter.

    If no filter exists yet (qfilter is None), the constraint itself
    becomes the filter; otherwise the two are conjoined with and_.
    """
    return constraint if qfilter is None else and_(qfilter, constraint)