diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-06 08:05:12 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-06 08:05:12 +0100 |
commit | 08636303e6bef99ea5f997a8ba4a41256aabee6b (patch) | |
tree | b8480c4d5289dedb06b3e3338953666900a849ee /ext/lightgbm/compat.py | |
parent | 112ebbe21ff330c68faf883b125ba4932e007544 (diff) |
import lightgbm
Diffstat (limited to 'ext/lightgbm/compat.py')
-rw-r--r-- | ext/lightgbm/compat.py | 269 |
1 file changed, 269 insertions, 0 deletions
# coding: utf-8
"""Compatibility library.

Centralises optional third-party imports (pandas, numpy, matplotlib,
graphviz, datatable, scikit-learn, dask, pyarrow, joblib/psutil) behind
``try/except ImportError`` blocks.  For most packages a ``*_INSTALLED``
flag records availability, and dummy stand-in classes / ``None``
placeholders are defined on failure so that the rest of the code base
can reference these names unconditionally.

NOTE: the bare string literals such as ``"""pandas"""`` below are
no-op expression statements used as section markers — this is the
upstream file's convention and is kept as-is.
"""

from typing import List

"""pandas"""
try:
    from pandas import DataFrame as pd_DataFrame
    from pandas import Series as pd_Series
    from pandas import concat
    # Newer pandas exposes CategoricalDtype at top level; fall back to
    # pandas.api.types for releases where only that location exists.
    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
    PANDAS_INSTALLED = True
except ImportError:
    PANDAS_INSTALLED = False

    # Dummy stand-ins: these exist only so the names are importable and
    # usable in isinstance()-style checks; they carry no functionality.
    class pd_Series:  # type: ignore
        """Dummy class for pandas.Series."""

        def __init__(self, *args, **kwargs):
            pass

    class pd_DataFrame:  # type: ignore
        """Dummy class for pandas.DataFrame."""

        def __init__(self, *args, **kwargs):
            pass

    class pd_CategoricalDtype:  # type: ignore
        """Dummy class for pandas.CategoricalDtype."""

        def __init__(self, *args, **kwargs):
            pass

    # Placeholder so the name exists; presumably callers gate on
    # PANDAS_INSTALLED before calling it — verify against call sites.
    concat = None

"""numpy"""
try:
    from numpy.random import Generator as np_random_Generator
except ImportError:
    # No NUMPY_INSTALLED flag here (unlike the other sections): only the
    # Generator name itself is needed as a stand-in.
    class np_random_Generator:  # type: ignore
        """Dummy class for np.random.Generator."""

        def __init__(self, *args, **kwargs):
            pass

"""matplotlib"""
try:
    import matplotlib  # noqa: F401
    MATPLOTLIB_INSTALLED = True
except ImportError:
    MATPLOTLIB_INSTALLED = False

"""graphviz"""
try:
    import graphviz  # noqa: F401
    GRAPHVIZ_INSTALLED = True
except ImportError:
    GRAPHVIZ_INSTALLED = False

"""datatable"""
try:
    import datatable
    # Newer datatable names its frame class ``Frame``; older releases
    # exposed it as ``DataTable``.  Support both spellings.
    if hasattr(datatable, "Frame"):
        dt_DataTable = datatable.Frame
    else:
        dt_DataTable = datatable.DataTable
    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args, **kwargs):
            pass


"""sklearn"""
try:
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
    # Modern sklearn keeps these in sklearn.exceptions / model_selection;
    # very old releases used sklearn.cross_validation and kept
    # NotFittedError under utils.validation.
    try:
        from sklearn.exceptions import NotFittedError
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
    except ImportError:
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
        from sklearn.utils.validation import NotFittedError
    # _check_sample_weight is a private sklearn helper that may not exist.
    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older version of scikit-learn
        # (only validates lengths; performs no dtype coercion)
        def _check_sample_weight(sample_weight, X, dtype=None):
            check_consistent_length(sample_weight, X)
            return sample_weight

    SKLEARN_INSTALLED = True
    # Re-export everything under _LGBM* aliases so the sklearn wrapper
    # module depends only on these names, never on sklearn directly.
    _LGBMBaseCrossValidator = BaseCrossValidator
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
except ImportError:
    SKLEARN_INSTALLED = False

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

    _LGBMBaseCrossValidator = None
    _LGBMLabelEncoder = None
    # Fallback exception type when sklearn is absent; NOTE(review):
    # presumably chosen because sklearn's NotFittedError is
    # ValueError-compatible — confirm against sklearn docs.
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None

"""dask"""
try:
    from dask import delayed
    from dask.array import Array as dask_Array
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
    from dask.distributed import Client, Future, default_client, wait
    DASK_INSTALLED = True
except ImportError:
    DASK_INSTALLED = False

    # Callables become None placeholders; classes get dummy stand-ins.
    dask_array_from_delayed = None  # type: ignore[assignment]
    dask_bag_from_delayed = None  # type: ignore[assignment]
    delayed = None
    default_client = None  # type: ignore[assignment]
    wait = None  # type: ignore[assignment]

    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

        def __init__(self, *args, **kwargs):
            pass

    class Future:  # type: ignore
        """Dummy class for dask.distributed.Future."""

        def __init__(self, *args, **kwargs):
            pass

    class dask_Array:  # type: ignore
        """Dummy class for dask.array.Array."""

        def __init__(self, *args, **kwargs):
            pass

    class dask_DataFrame:  # type: ignore
        """Dummy class for dask.dataframe.DataFrame."""

        def __init__(self, *args, **kwargs):
            pass

    class dask_Series:  # type: ignore
        """Dummy class for dask.dataframe.Series."""

        def __init__(self, *args, **kwargs):
            pass

"""pyarrow"""
try:
    import pyarrow.compute as pa_compute
    from pyarrow import Array as pa_Array
    from pyarrow import ChunkedArray as pa_ChunkedArray
    from pyarrow import Table as pa_Table
    from pyarrow import chunked_array as pa_chunked_array
    from pyarrow.cffi import ffi as arrow_cffi
    from pyarrow.types import is_floating as arrow_is_floating
    from pyarrow.types import is_integer as arrow_is_integer
    PYARROW_INSTALLED = True
except ImportError:
    PYARROW_INSTALLED = False

    class pa_Array:  # type: ignore
        """Dummy class for pa.Array."""

        def __init__(self, *args, **kwargs):
            pass

    class pa_ChunkedArray:  # type: ignore
        """Dummy class for pa.ChunkedArray."""

        def __init__(self, *args, **kwargs):
            pass

    class pa_Table:  # type: ignore
        """Dummy class for pa.Table."""

        def __init__(self, *args, **kwargs):
            pass

    class arrow_cffi:  # type: ignore
        """Dummy class for pyarrow.cffi.ffi."""

        # Class-level placeholders mirroring the cffi attributes the
        # real object exposes (CData, addressof, cast, new).
        CData = None
        addressof = None
        cast = None
        new = None

        def __init__(self, *args, **kwargs):
            pass

    class pa_compute:  # type: ignore
        """Dummy class for pyarrow.compute."""

        # Placeholders for the only pa_compute attributes referenced.
        all = None
        equal = None

    pa_chunked_array = None
    arrow_is_integer = None
    arrow_is_floating = None

"""cpu_count()"""
# Preference order: joblib (understands physical cores), then psutil,
# then multiprocessing as a last resort.
try:
    from joblib import cpu_count

    def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            # ``or 1`` guards against psutil returning None here.
            return cpu_count(logical=not only_physical_cores) or 1
    except ImportError:
        from multiprocessing import cpu_count

        # multiprocessing.cpu_count() takes no arguments, so the
        # only_physical_cores flag is necessarily ignored in this branch.
        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            return cpu_count()

# Deliberately empty: nothing is exported via star-import; consumers
# import the compatibility names explicitly.
__all__: List[str] = []