path: root/ext/lightgbm/compat.py
# coding: utf-8
"""Compatibility library."""

from typing import List
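
# Each optional dependency below is imported when available; otherwise a no-op
# placeholder is defined under the same name so that annotations and
# isinstance checks against these names keep working, and (for most sections)
# a *_INSTALLED flag records whether the real package was found.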

"""pandas"""
try:
    from pandas import DataFrame as pd_DataFrame
    from pandas import Series as pd_Series
    from pandas import concat
    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
    PANDAS_INSTALLED = True
except ImportError:
    PANDAS_INSTALLED = False

    class pd_Series:  # type: ignore
        """Dummy class for pandas.Series."""

        def __init__(self, *args, **kwargs):
            pass

    class pd_DataFrame:  # type: ignore
        """Dummy class for pandas.DataFrame."""

        def __init__(self, *args, **kwargs):
            pass

    class pd_CategoricalDtype:  # type: ignore
        """Dummy class for pandas.CategoricalDtype."""

        def __init__(self, *args, **kwargs):
            pass

    concat = None

"""numpy"""
try:
    from numpy.random import Generator as np_random_Generator
except ImportError:
    class np_random_Generator:  # type: ignore
        """Dummy class for np.random.Generator."""

        def __init__(self, *args, **kwargs):
            pass

"""matplotlib"""
try:
    import matplotlib  # noqa: F401
    MATPLOTLIB_INSTALLED = True
except ImportError:
    MATPLOTLIB_INSTALLED = False

"""graphviz"""
try:
    import graphviz  # noqa: F401
    GRAPHVIZ_INSTALLED = True
except ImportError:
    GRAPHVIZ_INSTALLED = False

"""datatable"""
try:
    import datatable
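    # newer datatable releases expose Frame; older ones only had the DataTable name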
    if hasattr(datatable, "Frame"):
        dt_DataTable = datatable.Frame
    else:
        dt_DataTable = datatable.DataTable
    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args, **kwargs):
            pass


"""sklearn"""
try:
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
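    # the CV splitters and NotFittedError moved between modules across
    # scikit-learn releases; try the modern import locations first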
    try:
        from sklearn.exceptions import NotFittedError
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
    except ImportError:
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
        from sklearn.utils.validation import NotFittedError
    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older versions of scikit-learn
        def _check_sample_weight(sample_weight, X, dtype=None):
            check_consistent_length(sample_weight, X)
            return sample_weight

    SKLEARN_INSTALLED = True
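    # re-export scikit-learn objects under LightGBM-internal aliases so callers
    # import them from compat rather than from sklearn directly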
    _LGBMBaseCrossValidator = BaseCrossValidator
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
except ImportError:
    SKLEARN_INSTALLED = False

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

    _LGBMBaseCrossValidator = None
    _LGBMLabelEncoder = None
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None

"""dask"""
try:
    from dask import delayed
    from dask.array import Array as dask_Array
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
    from dask.distributed import Client, Future, default_client, wait
    DASK_INSTALLED = True
except ImportError:
    DASK_INSTALLED = False

    dask_array_from_delayed = None  # type: ignore[assignment]
    dask_bag_from_delayed = None  # type: ignore[assignment]
    delayed = None
    default_client = None  # type: ignore[assignment]
    wait = None  # type: ignore[assignment]

    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

        def __init__(self, *args, **kwargs):
            pass

    class Future:  # type: ignore
        """Dummy class for dask.distributed.Future."""

        def __init__(self, *args, **kwargs):
            pass

    class dask_Array:  # type: ignore
        """Dummy class for dask.array.Array."""

        def __init__(self, *args, **kwargs):
            pass

    class dask_DataFrame:  # type: ignore
        """Dummy class for dask.dataframe.DataFrame."""

        def __init__(self, *args, **kwargs):
            pass

    class dask_Series:  # type: ignore
        """Dummy class for dask.dataframe.Series."""

        def __init__(self, *args, **kwargs):
            pass

"""pyarrow"""
try:
    import pyarrow.compute as pa_compute
    from pyarrow import Array as pa_Array
    from pyarrow import ChunkedArray as pa_ChunkedArray
    from pyarrow import Table as pa_Table
    from pyarrow import chunked_array as pa_chunked_array
    from pyarrow.cffi import ffi as arrow_cffi
    from pyarrow.types import is_floating as arrow_is_floating
    from pyarrow.types import is_integer as arrow_is_integer
    PYARROW_INSTALLED = True
except ImportError:
    PYARROW_INSTALLED = False

    class pa_Array:  # type: ignore
        """Dummy class for pa.Array."""

        def __init__(self, *args, **kwargs):
            pass

    class pa_ChunkedArray:  # type: ignore
        """Dummy class for pa.ChunkedArray."""

        def __init__(self, *args, **kwargs):
            pass

    class pa_Table:  # type: ignore
        """Dummy class for pa.Table."""

        def __init__(self, *args, **kwargs):
            pass

    class arrow_cffi:  # type: ignore
        """Dummy class for pyarrow.cffi.ffi."""

        CData = None
        addressof = None
        cast = None
        new = None

        def __init__(self, *args, **kwargs):
            pass

    class pa_compute:  # type: ignore
        """Dummy class for pyarrow.compute."""

        all = None
        equal = None

    pa_chunked_array = None
    arrow_is_integer = None
    arrow_is_floating = None

"""cpu_count()"""
try:
    from joblib import cpu_count

    def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
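            # psutil.cpu_count() can return None, hence the "or 1" fallback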
            return cpu_count(logical=not only_physical_cores) or 1
    except ImportError:
        from multiprocessing import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            return cpu_count()

__all__: List[str] = []