Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions comtypes/_post_coinit/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,12 @@ def CoGetClassObject(
) -> _T_IUnknown: ...


def CoGetClassObject(clsid, clsctx=None, pServerInfo=None, interface=None):
# type: (GUID, Optional[int], Optional[COSERVERINFO], Optional[Type[IUnknown]]) -> IUnknown
def CoGetClassObject(
clsid: GUID,
clsctx: Optional[int] = None,
pServerInfo: "Optional[COSERVERINFO]" = None,
interface: Optional[type[IUnknown]] = None,
) -> IUnknown:
if clsctx is None:
clsctx = CLSCTX_SERVER
if interface is None:
Expand Down
21 changes: 19 additions & 2 deletions comtypes/server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ctypes
from ctypes import HRESULT, POINTER, byref
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, overload

import comtypes
import comtypes.client
Expand All @@ -15,6 +15,9 @@
from comtypes import hints # type: ignore


_T_IUnknown = TypeVar("_T_IUnknown", bound=IUnknown)


################################################################
# Interfaces
class IClassFactory(IUnknown):
Expand All @@ -28,9 +31,23 @@ class IClassFactory(IUnknown):
STDMETHOD(HRESULT, "LockServer", [ctypes.c_int]),
]

@overload
def CreateInstance(
self,
punkouter: Optional["_Pointer[IUnknown]"] = None,
interface: type[_T_IUnknown] = IUnknown,
dynamic: Literal[False] = False,
) -> _T_IUnknown: ...
@overload
def CreateInstance(
self,
punkouter: Optional["_Pointer[IUnknown]"] = None,
interface: None = None,
dynamic: Literal[True] = True,
) -> Any: ...
def CreateInstance(
self,
punkouter: Optional[type["_Pointer[IUnknown]"]] = None,
punkouter: Optional["_Pointer[IUnknown]"] = None,
interface: Optional[type[IUnknown]] = None,
dynamic: bool = False,
) -> Any:
Expand Down
43 changes: 43 additions & 0 deletions comtypes/test/test_classfactory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import unittest as ut

from comtypes import GUID, CoGetClassObject, IUnknown, shelllink
from comtypes.server import IClassFactory

CLSID_ShellLink = GUID("{00021401-0000-0000-C000-000000000046}")

REGDB_E_CLASSNOTREG = -2147221164 # 0x80040154


class Test_CreateInstance(ut.TestCase):
def test_returns_specified_interface_type_instance(self):
class_factory = CoGetClassObject(CLSID_ShellLink)
self.assertIsInstance(class_factory, IClassFactory)
shlnk = class_factory.CreateInstance(interface=shelllink.IShellLinkW)
self.assertIsInstance(shlnk, shelllink.IShellLinkW)
shlnk.SetDescription("sample")
self.assertEqual(shlnk.GetDescription(), "sample")

def test_returns_iunknown_type_instance(self):
class_factory = CoGetClassObject(CLSID_ShellLink)
self.assertIsInstance(class_factory, IClassFactory)
punk = class_factory.CreateInstance()
self.assertIsInstance(punk, IUnknown)
self.assertNotIsInstance(punk, shelllink.IShellLinkW)
shlnk = punk.QueryInterface(shelllink.IShellLinkW)
shlnk.SetDescription("sample")
self.assertEqual(shlnk.GetDescription(), "sample")

def test_raises_valueerror_if_takes_dynamic_true_and_interface_explicitly(self):
class_factory = CoGetClassObject(CLSID_ShellLink)
self.assertIsInstance(class_factory, IClassFactory)
with self.assertRaises(ValueError):
class_factory.CreateInstance( # type: ignore
interface=shelllink.IShellLinkW,
dynamic=True, # type: ignore
)

def test_raises_class_not_reg_error_if_non_existent_clsid(self):
# calling `CoGetClassObject` with a non-existent CLSID raises an `OSError`.
with self.assertRaises(OSError) as cm:
CoGetClassObject(GUID.create_new())
self.assertEqual(cm.exception.winerror, REGDB_E_CLASSNOTREG)