Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions openml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
# License: BSD 3-Clause
from __future__ import annotations

from openml._get import get

from . import (
_api_calls,
config,
Expand Down Expand Up @@ -120,4 +122,5 @@ def populate_cache(
"utils",
"_api_calls",
"__version__",
"get",
]
11 changes: 11 additions & 0 deletions openml/_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Global get dispatch utility."""

# Currently this module just forwards to ``openml.models.get``.
# To discuss, and possibly:
# todo: add a global get utility here.
# In general, e.g., datasets will not have the same names as models, etc.
from __future__ import annotations

from openml.models import get

__all__ = ["get"]
6 changes: 6 additions & 0 deletions openml/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Module of base classes."""

from openml.base._base import OpenMLBase
from openml.base._base_pkg import _BasePkg

__all__ = ["_BasePkg", "OpenMLBase"]
3 changes: 1 addition & 2 deletions openml/base.py → openml/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

import openml._api_calls
import openml.config

from .utils import _get_rest_api_type_alias, _tag_openml_base
from openml.utils import _get_rest_api_type_alias, _tag_openml_base


class OpenMLBase(ABC):
Expand Down
117 changes: 117 additions & 0 deletions openml/base/_base_pkg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Base Packager class."""

from __future__ import annotations

import inspect
import sys
import textwrap
from pathlib import Path

from skbase.base import BaseObject
from skbase.utils.dependencies import _check_estimator_deps


class _BasePkg(BaseObject):
    """Base class for package objects.

    A package wraps an object behind a set of tags and exposes two
    operations: ``materialize``, which checks the declared dependencies and
    then delegates to ``_materialize``, and ``serialize``, which returns the
    source of the package class, optionally compressed.

    Subclasses are expected to override ``_materialize``.
    """

    _tags = {
        "python_dependencies": None,
        "python_version": None,
        # package register and manifest
        "pkg_id": None,  # object id contained, "__multiple" if multiple
        "pkg_obj": "reference",  # or "code"
        "pkg_obj_type": None,  # openml API type
        "pkg_compression": "zlib",  # compression
    }

    def __init__(self):
        super().__init__()

    def materialize(self):
        """Check dependencies, then return the result of ``_materialize``.

        Returns
        -------
        object
            whatever the subclass implementation of ``_materialize`` returns

        Raises
        ------
        ModuleNotFoundError
            if the dependency check via ``_check_estimator_deps`` fails
        """
        try:
            _check_estimator_deps(obj=self)
        except ModuleNotFoundError as e:
            # prettier message, so the reference is to the pkg_id
            # currently, we cannot simply pass the object name to skbase
            # in the error message, so this is a hack
            # todo: fix this in scikit-base
            msg = str(e)
            if len(msg) > 11:
                msg = msg[11:]
            raise ModuleNotFoundError(msg) from e

        return self._materialize()

    def _materialize(self):
        """Return the materialized object; to be overridden by subclasses."""
        # NotImplementedError is a RuntimeError subclass, so callers that
        # catch RuntimeError keep working.
        raise NotImplementedError("abstract method")

    def serialize(self):
        """Return the source of this package class, optionally compressed.

        Returns
        -------
        str or bytes
            the class source as ``str`` if the ``pkg_compression`` tag is
            ``None`` or ``"None"``; otherwise the UTF-8 encoded source passed
            through ``compress`` of the module named by the tag
        """
        cls_str = class_to_source(type(self))
        compress_method = self.get_tag("pkg_compression")
        if compress_method in [None, "None"]:
            return cls_str

        # Import the compression module by name. The previous exec/eval pair
        # relied on exec writing into the function's locals, which is not
        # guaranteed to be visible to a subsequent eval.
        import importlib

        compressor = importlib.import_module(compress_method)
        return compressor.compress(cls_str.encode("utf-8"))


def _has_source(obj) -> bool:
    """Return True if inspect.getsource(obj) should succeed.

    This is a cheap heuristic: it only checks that the module named by
    ``obj.__module__`` is imported and was loaded from a ``.py`` file.
    It does not guarantee that the definition of ``obj`` is in that file.
    """
    module_name = getattr(obj, "__module__", None)
    if not module_name or module_name not in sys.modules:
        return False

    module = sys.modules[module_name]
    file = getattr(module, "__file__", None)
    if not file:
        return False

    return Path(file).suffix == ".py"


def _try_getsource(obj):
    """Return dedented source of obj, or None if it cannot be retrieved."""
    if not _has_source(obj):
        return None
    try:
        return textwrap.dedent(inspect.getsource(obj))
    except (OSError, TypeError):
        # The module has a .py file, but obj is not defined in it,
        # e.g., it was created via type(...) or exec.
        return None


def class_to_source(cls) -> str:
    """Return full source definition of python class as string.

    If the source of ``cls`` can be retrieved via ``inspect``, it is returned
    dedented. Otherwise, e.g., for dynamically created classes, a class
    definition is reconstructed from the non-dunder entries of
    ``cls.__dict__``: functions are emitted from their source if available
    and as ``def name(self): ...`` stubs otherwise; all other attributes
    are emitted as ``name = repr(value)``.

    Parameters
    ----------
    cls : class to serialize

    Returns
    -------
    str : complete definition of cls, as str.
        Imports are not contained or serialized.
    """
    # Fast path: class has retrievable source
    source = _try_getsource(cls)
    if source is not None:
        return source

    # Fallback for dynamically created classes
    indent = " "  # prefix for every line of the generated class body

    bases = [base.__name__ for base in cls.__bases__ if base is not object]
    base_str = f"({', '.join(bases)})" if bases else ""
    lines = [f"class {cls.__name__}{base_str}:"]

    body_added = False

    for name, value in cls.__dict__.items():
        if name.startswith("__") and name.endswith("__"):
            continue

        if inspect.isfunction(value):
            method_src = _try_getsource(value)
            if method_src is not None:
                lines.append(textwrap.indent(method_src, indent))
            else:
                lines.append(f"{indent}def {name}(self): ...")
        else:
            lines.append(f"{indent}{name} = {value!r}")
        body_added = True

    if not body_added:
        lines.append(f"{indent}pass")

    return "\n".join(lines)
5 changes: 5 additions & 0 deletions openml/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Module with packaging adapters."""

from openml.models._get import get

__all__ = ["get"]
54 changes: 54 additions & 0 deletions openml/models/_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Model retrieval utility."""

from __future__ import annotations

from functools import lru_cache


def get(id: str):
    """Look up a packaged model by its identifier and materialize it.

    Parameters
    ----------
    id : str
        unique package identifier of the object to retrieve

    Returns
    -------
    class
        result of calling ``materialize`` on an instance of the package
        registered under ``id``

    Raises
    ------
    ValueError
        if no package is registered under ``id``
    ModuleNotFoundError
        if dependencies of object to retrieve are not satisfied
    """
    pkg_cls = _id_lookup().get(id)
    if pkg_cls is not None:
        # instantiate the package and let it resolve the wrapped object
        return pkg_cls().materialize()

    raise ValueError(f"Error in openml.get, object with package id {id} does not exist.")


# todo: need to generalize this later to more types
# currently intentionally retrieves only classifiers
# todo: replace this, optionally, by database backend
def _id_lookup(obj_type=None):
    """Return a fresh dict copy of the cached package lookup table.

    A copy is returned so that callers cannot mutate the cached dict.
    """
    cached_table = _id_lookup_cached(obj_type=obj_type)
    return dict(cached_table)


@lru_cache
def _id_lookup_cached(obj_type=None):
    """Build, once per ``obj_type``, a dict mapping ``pkg_id`` tag to package class."""
    # todo: generalize that pkg can contain more than one object
    table = {}
    for pkg_cls in _all_objects(obj_type=obj_type):
        table[pkg_cls.get_class_tag("pkg_id")] = pkg_cls
    return table


@lru_cache
def _all_objects(obj_type=None):
    """Return all ``_ModelPkgClassifier`` objects found in the ``openml`` package.

    The ``obj_type`` argument is currently not used in the lookup itself;
    it only acts as part of the cache key.
    """
    # imports are local to this function
    from skbase.lookup import all_objects

    from openml.models.apis._classifier import _ModelPkgClassifier

    found = all_objects(
        object_types=_ModelPkgClassifier,
        package_name="openml",
        return_names=False,
    )
    return found
5 changes: 5 additions & 0 deletions openml/models/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Module with packaging adapters."""

from openml.models.apis._classifier import _ModelPkgClassifier

__all__ = ["_ModelPkgClassifier"]
25 changes: 25 additions & 0 deletions openml/models/apis/_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Base package for sklearn classifiers."""

from __future__ import annotations

from openml.models.base import _OpenmlModelPkg


class _ModelPkgClassifier(_OpenmlModelPkg):
    """Package base class for objects of API type "classifier"."""

    # tags specific to API type
    _tags = {"pkg_obj_type": "classifier"}

    def get_obj_tags(self):
        """Return the tags of the packaged object as a dict.

        Currently a placeholder that always returns an empty dict.
        """
        # this needs to be implemented
        return {}

    def get_obj_param_names(self):
        """Return the parameter names of the packaged object.

        The package is materialized, the result is called without arguments,
        and the keys of its ``get_params()`` result are collected.

        Returns
        -------
        list: names of object parameters
        """
        obj_cls = self.materialize()
        params = obj_cls().get_params()
        return list(params.keys())
5 changes: 5 additions & 0 deletions openml/models/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Module with packaging adapters."""

from openml.models.base._base import _OpenmlModelPkg

__all__ = ["_OpenmlModelPkg"]
40 changes: 40 additions & 0 deletions openml/models/base/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Base model package class."""

from __future__ import annotations

from openml.base import _BasePkg


class _OpenmlModelPkg(_BasePkg):
    """Base class for model packages defined via the ``_obj`` attribute.

    How ``_obj`` is interpreted depends on the ``pkg_obj`` tag:

    * ``"reference"``: ``_obj`` is passed to skbase ``_safe_import`` and the
      result is returned.
    * ``"code"``: ``_obj`` is executed as python source; the source is
      expected to bind the name ``obj``, which is returned.
    """

    _obj = None

    def _materialize(self):
        """Return the object described by ``_obj`` and the ``pkg_obj`` tag.

        Raises
        ------
        ValueError
            if ``_obj`` is None, if the ``pkg_obj`` tag has an unsupported
            value, or if ``pkg_obj`` is ``"code"`` and the executed source
            does not define the name ``obj``
        """
        pkg_obj = self.get_tag("pkg_obj")

        _obj = self._obj

        if _obj is None:
            raise ValueError(
                "Error in materialize. "
                "Either _materialize must be implemented, or "
                "the _obj attribute must be not None."
            )

        if pkg_obj == "reference":
            from skbase.utils.dependencies import _safe_import

            return _safe_import(_obj)

        if pkg_obj == "code":
            # Execute into an explicit namespace: names bound by exec are not
            # visible as local variables of this function, so a bare
            # ``return obj`` would always raise NameError.
            # NOTE(review): exec runs arbitrary code - _obj must be trusted.
            namespace = {}
            exec(_obj, namespace)

            if "obj" not in namespace:
                raise ValueError(
                    "Error in materialize. "
                    'For package tag "pkg_obj" equal to "code", '
                    'the code in _obj must define the name "obj".'
                )
            return namespace["obj"]

        # elif pkg_obj == "craft":
        #    identify and call appropriate craft method

        raise ValueError(
            'Error in package tag "pkg_obj", '
            'must be one of "reference", "code", "craft", '
            f"but found value {pkg_obj}, of type {type(pkg_obj)}"
        )
1 change: 1 addition & 0 deletions openml/models/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Sklearn classification models."""
14 changes: 14 additions & 0 deletions openml/models/classification/auto_sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Auto-sklearn classifier."""

from __future__ import annotations

from openml.models.apis import _ModelPkgClassifier


class OpenmlPkg__AutoSklearnClassifier(_ModelPkgClassifier):
    """Package registered under the id "AutoSklearnClassifier"."""

    # dotted location of the packaged object
    _obj = "autosklearn.classification.AutoSklearnClassifier"

    _tags = dict(
        pkg_id="AutoSklearnClassifier",
        python_dependencies="auto-sklearn",
    )
14 changes: 14 additions & 0 deletions openml/models/classification/xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Xgboost classifier."""

from __future__ import annotations

from openml.models.apis import _ModelPkgClassifier


class OpenmlPkg__XGBClassifier(_ModelPkgClassifier):
    """Package registered under the id "XGBClassifier"."""

    # dotted location of the packaged object
    _obj = "xgboost.XGBClassifier"

    _tags = dict(
        pkg_id="XGBClassifier",
        python_dependencies="xgboost",
    )
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
"minio",
"pyarrow",
"tqdm", # For MinIO download progress bars
"scikit-base",
]
requires-python = ">=3.8"
maintainers = [
Expand Down