try:
    from collections.abc import Mapping, Sequence
except ImportError:  # Python 2 has no collections.abc
    from collections import Mapping, Sequence
from numbers import Number

import pandas

from .space import Space
class Connection(object):
    """Abstract connection class that defines the database connection API.

    Every method except :meth:`results_as_dataframe` raises
    :class:`NotImplementedError` here and has to be overridden by a concrete
    backend. :meth:`lock` is used as a context manager by
    :meth:`results_as_dataframe`, so its return value must support ``with``.
    """

    def lock(self):
        raise NotImplementedError

    def all_results(self):
        raise NotImplementedError

    def find_results(self, filter):
        raise NotImplementedError

    def insert_result(self, entry):
        raise NotImplementedError

    def update_result(self, entry, value):
        raise NotImplementedError

    def count_results(self):
        raise NotImplementedError

    def all_complementary(self):
        raise NotImplementedError

    def insert_complementary(self, document):
        raise NotImplementedError

    def find_complementary(self, filter):
        raise NotImplementedError

    def get_space(self):
        raise NotImplementedError

    def insert_space(self, space):
        raise NotImplementedError

    def clear(self):
        raise NotImplementedError

    def pop_id(self, document):
        raise NotImplementedError

    def results_as_dataframe(self):
        """Compile all the results and transform them using the space specified in the database.

        The lock is held only while the space and the stored results are
        fetched; the transformation itself runs after it is released.

        Returns:
            A :class:`pandas.DataFrame` indexed by the ``"_chocolate_id"`` of
            each result (the index is named ``"id"``), containing the
            parameters produced by the space and every ``"_loss*"`` entry.
            Losses are copied only for results whose losses are all set;
            other results are left without loss values. When the database
            holds no result an empty DataFrame with an ``"id"`` index is
            returned.
        """
        with self.lock():
            s = self.get_space()
            results = self.all_results()

        rows = []
        for r in results:
            # The space is called with the stored values in the order given by names().
            row = s([r[k] for k in s.names()])

            # Find all losses
            losses = {k: v for k, v in r.items() if k.startswith("_loss")}
            if all(loss is not None for loss in losses.values()):
                row.update(losses)

            row["id"] = r["_chocolate_id"]
            rows.append(row)

        if not rows:
            # Without rows there is no "id" column to promote to the index.
            return pandas.DataFrame(index=pandas.Index([], name="id"))

        return pandas.DataFrame(rows).set_index("id")
class SearchAlgorithm(object):
    """Base class for search algorithms. Other than providing the :meth:`update` method
    it ensures the provided space fits with the one in the database.
    """

    def __init__(self, connection, space=None, crossvalidation=None, clear_db=False):
        """Bind the algorithm to *connection* and reconcile *space* with the database.

        Args:
            connection: Object implementing the connection API (``lock``,
                ``get_space``, ``insert_space``, ``clear``, ...).
            space: The search space, either a :class:`Space` or anything
                accepted by the :class:`Space` constructor. When omitted the
                space stored in the database is used.
            crossvalidation: Optional object providing ``wrap_connection`` and
                ``next``; when given, it is asked to wrap *connection*.
            clear_db: When true and the provided space differs from the one in
                the database, the database is cleared and the provided space
                is stored instead.

        Raises:
            RuntimeError: If neither *space* nor the database provides a
                space, or if they differ and *clear_db* is not set.
        """
        if space is not None and not isinstance(space, Space):
            space = Space(space)

        self.conn = connection
        with self.conn.lock():
            db_space = self.conn.get_space()

            if space is None:
                if db_space is None:
                    raise RuntimeError("The database does not contain any space, please provide one through "
                                       "the 'space' argument")
                space = db_space
            elif db_space is None:
                self.conn.insert_space(space)
            elif space != db_space:
                # Truthiness (not identity with True/False) so that values such
                # as 0 or 1 cannot slip past both branches.
                if not clear_db:
                    raise RuntimeError("The provided space and database space are different. To overwrite "
                                       "the space contained in the database set the 'clear_db' argument")
                self.conn.clear()
                self.conn.insert_space(space)

        self.space = space

        self.crossvalidation = crossvalidation
        if crossvalidation is not None:
            crossvalidation.wrap_connection(connection)

    def update(self, token, values):
        """Update the loss of the parameters associated with *token*.

        Args:
            token: A token generated by the sampling algorithm for the current
                parameters
            values: The loss of the current parameter set. The values can be a
                single :class:`Number`, a :class:`Sequence` or a :class:`Mapping`.
                When a sequence is given, the column name is set to "_loss_i" where
                "i" is the index of the value. When a mapping is given, each key
                is prefixed with the string "_loss_". Any other type is handed
                to the connection unchanged.
        """
        if isinstance(values, Sequence):
            values = {"_loss_{}".format(i): v for i, v in enumerate(values)}
        elif isinstance(values, Mapping):
            values = {"_loss_{}".format(k): v for k, v in values.items()}
        elif isinstance(values, Number):
            values = {"_loss": values}

        with self.conn.lock():
            self.conn.update_result(token, values)

    def next(self):
        """Retrieve the next point to evaluate based on available data in the
        database.

        When a cross-validation object is set it is consulted first: if it
        returns both a token and parameters they are returned as is; if it
        returns only a token, that token is forwarded to :meth:`_next`.
        Otherwise :meth:`_next` is called without a token.

        Returns:
            A tuple containing a unique token and a fully qualified parameter set.
        """
        with self.conn.lock():
            if self.crossvalidation is not None:
                reps_token, params = self.crossvalidation.next()
                if reps_token is not None:
                    if params is not None:
                        return reps_token, params
                    return self._next(reps_token)

            return self._next()

    def _next(self, token=None):
        """Produce the next ``(token, parameters)`` pair; must be overridden by subclasses."""
        raise NotImplementedError