From aa39ef02e0bf7d4bf47349cd59ddf4d92de03df8 Mon Sep 17 00:00:00 2001 From: 苏向夜 Date: Sat, 27 Jan 2024 16:25:17 +0800 Subject: :sparkles: feat(internal): return a Register class instead of ModuleType in require function --- src/infini/const.py | 3 --- src/infini/internal.py | 28 +++++++++++++++++----------- tests/test_internal.py | 11 +++++++++++ 3 files changed, 28 insertions(+), 14 deletions(-) delete mode 100644 src/infini/const.py create mode 100644 tests/test_internal.py diff --git a/src/infini/const.py b/src/infini/const.py deleted file mode 100644 index 2ec42488..00000000 --- a/src/infini/const.py +++ /dev/null @@ -1,3 +0,0 @@ -from pathlib import Path - -SRC_HOME = Path.home() / ".ipm" / "src" \ No newline at end of file diff --git a/src/infini/internal.py b/src/infini/internal.py index 4f800297..460407a8 100644 --- a/src/infini/internal.py +++ b/src/infini/internal.py @@ -1,18 +1,24 @@ -from infini.typing import List, ModuleType -from infini.const import SRC_HOME +from infini.loader import Loader +from infini.register import Register +from infini.typing import List +from pathlib import Path -import importlib import sys +import inspect -def require(name: str, paths: List | None = None) -> ModuleType: +def require(name: str, paths: List | None = None) -> Register: + caller_frame = inspect.stack()[1] + caller_file = caller_frame[0].f_globals["__file__"] + + default_paths = [Path(caller_file).resolve().parent] paths = [ str(path) - for path in ( - (list(paths) + [str(SRC_HOME / name)]) if paths else [str(SRC_HOME / name)] - ) + for path in ((list(paths) + default_paths) if paths else default_paths) ] - sys.path.extend(paths) - module = importlib.import_module(name) - (sys.path.remove(path) for path in paths) - return module + (sys.path.insert(0, path) for path in paths) + + with Loader() as loader: + loader.load(name) + sys.path = sys.path[len(paths) - 1 :] + return loader.into_register() diff --git a/tests/test_internal.py b/tests/test_internal.py new file mode 100644 index 00000000..321c9f5c --- /dev/null +++ b/tests/test_internal.py @@ -0,0 +1,11 @@ +from infini.internal import require +from ipm import api + +import shutil + + +def test_internal(): + api.new("test_ipk") + registers = require("test_ipk") + assert registers + shutil.rmtree("test_ipk") -- cgit v1.2.3-70-g09d2