diff --git a/comtypes/_post_coinit/misc.py b/comtypes/_post_coinit/misc.py index 39b1b0d3..1607f80b 100644 --- a/comtypes/_post_coinit/misc.py +++ b/comtypes/_post_coinit/misc.py @@ -160,8 +160,12 @@ def CoGetClassObject( ) -> _T_IUnknown: ... -def CoGetClassObject(clsid, clsctx=None, pServerInfo=None, interface=None): - # type: (GUID, Optional[int], Optional[COSERVERINFO], Optional[Type[IUnknown]]) -> IUnknown +def CoGetClassObject( + clsid: GUID, + clsctx: Optional[int] = None, + pServerInfo: "Optional[COSERVERINFO]" = None, + interface: Optional[type[IUnknown]] = None, +) -> IUnknown: if clsctx is None: clsctx = CLSCTX_SERVER if interface is None: diff --git a/comtypes/server/__init__.py b/comtypes/server/__init__.py index 02fbd868..4f9473dc 100644 --- a/comtypes/server/__init__.py +++ b/comtypes/server/__init__.py @@ -1,6 +1,6 @@ import ctypes from ctypes import HRESULT, POINTER, byref -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, overload import comtypes import comtypes.client @@ -15,6 +15,9 @@ from comtypes import hints # type: ignore +_T_IUnknown = TypeVar("_T_IUnknown", bound=IUnknown) + + ################################################################ # Interfaces class IClassFactory(IUnknown): @@ -28,9 +31,23 @@ class IClassFactory(IUnknown): STDMETHOD(HRESULT, "LockServer", [ctypes.c_int]), ] + @overload + def CreateInstance( + self, + punkouter: Optional["_Pointer[IUnknown]"] = None, + interface: type[_T_IUnknown] = IUnknown, + dynamic: Literal[False] = False, + ) -> _T_IUnknown: ... + @overload + def CreateInstance( + self, + punkouter: Optional["_Pointer[IUnknown]"] = None, + interface: None = None, + dynamic: Literal[True] = True, + ) -> Any: ... def CreateInstance( self, - punkouter: Optional[type["_Pointer[IUnknown]"]] = None, + punkouter: Optional["_Pointer[IUnknown]"] = None, interface: Optional[type[IUnknown]] = None, dynamic: bool = False, ) -> Any: diff --git a/comtypes/test/test_classfactory.py b/comtypes/test/test_classfactory.py new file mode 100644 index 00000000..5bd247dd --- /dev/null +++ b/comtypes/test/test_classfactory.py @@ -0,0 +1,43 @@ +import unittest as ut + +from comtypes import GUID, CoGetClassObject, IUnknown, shelllink +from comtypes.server import IClassFactory + +CLSID_ShellLink = GUID("{00021401-0000-0000-C000-000000000046}") + +REGDB_E_CLASSNOTREG = -2147221164 # 0x80040154 + + +class Test_CreateInstance(ut.TestCase): + def test_returns_specified_interface_type_instance(self): + class_factory = CoGetClassObject(CLSID_ShellLink) + self.assertIsInstance(class_factory, IClassFactory) + shlnk = class_factory.CreateInstance(interface=shelllink.IShellLinkW) + self.assertIsInstance(shlnk, shelllink.IShellLinkW) + shlnk.SetDescription("sample") + self.assertEqual(shlnk.GetDescription(), "sample") + + def test_returns_iunknown_type_instance(self): + class_factory = CoGetClassObject(CLSID_ShellLink) + self.assertIsInstance(class_factory, IClassFactory) + punk = class_factory.CreateInstance() + self.assertIsInstance(punk, IUnknown) + self.assertNotIsInstance(punk, shelllink.IShellLinkW) + shlnk = punk.QueryInterface(shelllink.IShellLinkW) + shlnk.SetDescription("sample") + self.assertEqual(shlnk.GetDescription(), "sample") + + def test_raises_valueerror_if_takes_dynamic_true_and_interface_explicitly(self): + class_factory = CoGetClassObject(CLSID_ShellLink) + self.assertIsInstance(class_factory, IClassFactory) + with self.assertRaises(ValueError): + class_factory.CreateInstance( # type: ignore + interface=shelllink.IShellLinkW, + dynamic=True, # type: ignore + ) + + def test_raises_class_not_reg_error_if_non_existent_clsid(self): + # calling `CoGetClassObject` with a non-existent CLSID raises an `OSError`. + with self.assertRaises(OSError) as cm: + CoGetClassObject(GUID.create_new()) + self.assertEqual(cm.exception.winerror, REGDB_E_CLASSNOTREG)